Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300
NOTE
This is a guide to understanding FMHA implementation in cuTile, not a complete reference. For comprehensive documentation, see the cuTile Python Documentation.
Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three projections: queries (Q), keys (K), and values (V):
Attention(Q, K, V) = softmax(Q × K^T / √d) × V
Shapes:
Q, K, V = [batch, heads, seq_len, head_dim]
Q × K^T = [batch, heads, seq_len, seq_len] # Attention scores
Output = [batch, heads, seq_len, head_dim]
For autoregressive models, causal masking ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.
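For reference, the standard computation above can be written as a short PyTorch sketch (illustrative only, not part of the playbook scripts):

import math
import torch

def reference_attention(q, k, v, causal=True):
    # q, k, v: [batch, heads, seq_len, head_dim]
    d = q.shape[-1]
    scores = q @ k.transpose(-2, -1) / math.sqrt(d)  # [batch, heads, seq_len, seq_len]
    if causal:
        seq_len = q.shape[-2]
        future = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool,
                                       device=q.device), diagonal=1)
        scores = scores.masked_fill(future, float("-inf"))  # hide future positions
    return torch.softmax(scores, dim=-1) @ v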
Standard attention materializes a [seq_len × seq_len] matrix (e.g., 2 GB for seq_len=32768). Flash Attention avoids this by processing in tiles with online softmax:
m = -infinity # Running maximum
l = 0 # Running sum of exp(x - m)
acc = 0 # Running weighted sum of values
FOR each K,V tile:
    scores = Q_tile @ K_tile.T * scale
    m_new = max(m, max(scores))
    correction = exp(m - m_new)
    l = l * correction + sum(exp(scores - m_new))
    acc = acc * correction + exp(scores - m_new) @ V_tile
    m = m_new
output = acc / l
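The same loop can be checked with a minimal NumPy sketch (single head, no masking; the function and variable names here are illustrative, not part of the cuTile API):

import numpy as np

def online_softmax_attention(q, k, v, tile_size=64, scale=None):
    # q: [seq_q, d]; k, v: [seq_k, d] -- single head for clarity
    scale = scale if scale is not None else 1.0 / np.sqrt(q.shape[-1])
    m = np.full((q.shape[0], 1), -np.inf)     # running maximum
    l = np.zeros((q.shape[0], 1))             # running sum of exp(x - m)
    acc = np.zeros_like(q, dtype=np.float64)  # running weighted sum of values
    for j in range(0, k.shape[0], tile_size):
        scores = q @ k[j:j + tile_size].T * scale
        m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)
        p = np.exp(scores - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v[j:j + tile_size]
        m = m_new
    return acc / l  # matches softmax(q @ k.T * scale) @ v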
| Concept | Pseudocode | cuTile |
|---|---|---|
| Define kernel | KERNEL fmha(...) | @ct.kernel() |
| Get block ID | block_x = BLOCK_ID_X | bid_x = ct.bid(0) |
| Create indices | range(0, N) | ct.arange(N, dtype=ct.int32) |
| Create constant tile | tile = zeros(M, N) | ct.full((M, N), 0.0, dtype) |
| Load from memory | tile = LOAD(ptr, shape) | ct.load(tensor, index, shape) |
| Store to memory | STORE(ptr, tile) | ct.store(tensor, index, tile) |
| Matrix multiply | C = A @ B + C | ct.mma(A, B, C) |
| Reduction | max_val = MAX(tile, axis) | ct.max(tile, axis, keepdims) |
KERNEL fmha(Q, K, V, Out, scale, TILE_M, TILE_N):
    tile_row = BLOCK_ID_X
    batch_head = BLOCK_ID_Y
    batch = batch_head // num_heads
    head = batch_head % num_heads
    m_i = full(TILE_M, -infinity)
    l_i = full(TILE_M, 0)
    acc = zeros(TILE_M, head_dim)
    q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])
    FOR j = 0 to num_k_tiles:
        k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        scores = MMA(q, transpose(k)) * scale
        IF causal AND in_mask_region:
            scores = WHERE(valid_mask, scores, -infinity)
        m_new = max(m_i, row_max(scores))
        correction = exp(m_i - m_new)
        p = exp(scores - m_new)
        l_i = l_i * correction + row_sum(p)
        acc = acc * correction + MMA(p, v)
        m_i = m_new
    out = acc / l_i
    STORE(Out[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :], out)
import cuda.tile as ct
import math
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]
@ct.kernel()
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    bid_x, bid_y = ct.bid(0), ct.bid(1)
    batch_idx, head_idx = bid_y // H, bid_y % H
    offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)
    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
                shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))
    k_seqlen = K.shape[2]
    if CAUSAL:
        Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N
    for j in range(0, Tc):
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))
        qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
        qk = qk * qk_scale
        if CAUSAL and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            qk = ct.where(offs_m >= offs_n, qk,
                          ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))
        m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
        qk = qk - m_ij
        p = ct.exp(qk)
        alpha = ct.exp(m_i - m_ij)
        l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
        acc = acc * alpha
        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
        m_i = m_ij
    acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
def run_fmha(q, k, v, sm_scale, is_causal=True):
    import torch
    TILE_M, TILE_N = 64, 64  # Platform-specific (see below)
    batch, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(), grid, fmha_kernel,
        (q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
    )
    return out
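A typical call site, sketched here under the assumption of a CUDA device and fp16 inputs, together with the SDPA correctness check mentioned below:

import math
import torch

batch, heads, seq_len, head_dim = 2, 16, 4096, 64
q, k, v = (torch.randn(batch, heads, seq_len, head_dim,
                       device="cuda", dtype=torch.float16) for _ in range(3))
sm_scale = 1.0 / math.sqrt(head_dim)

out = run_fmha(q, k, v, sm_scale, is_causal=True)

# Correctness check only -- PyTorch SDPA is not the performance baseline here.
ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
torch.testing.assert_close(out, ref, atol=2e-2, rtol=2e-2)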
exp2(x) = 2^x is faster than exp(x) on the GPU; using it only requires adjusting the softmax scale by 1/log(2).
# Convert natural-exp scale to base-2 so we can use the faster ct.exp2 intrinsic.
# exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2)  # ≈ 1.4427
qk_scale_log2 = qk_scale * INV_LOG_2 # Pre-multiply the softmax scale once
# ... in loop:
# Fuse the running-max update with the scale multiplication.
m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2)
# Subtract the running max for numerical stability (online softmax).
qk = qk * qk_scale_log2 - m_ij
# flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # Correction factor for previous acc/l_i
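The identity is easy to sanity-check in isolation (plain NumPy, illustrative only); because the numerator p and the denominator l_i switch to base 2 together, the final acc / l_i is unchanged:

import numpy as np

INV_LOG_2 = 1.0 / np.log(2)
x = np.linspace(-20.0, 5.0, 7)
# exp(x) and exp2(x * INV_LOG_2) agree to floating-point precision.
assert np.allclose(np.exp(x), np.exp2(x * INV_LOG_2))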
Load K already transposed using order parameter, avoiding explicit permute.
# order=(0,1,3,2) swaps the last two axes during the load,
# producing K^T directly in registers -- no extra ct.permute() needed.
# shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N).
k_t = ct.load(K, index=(..., 0, j), shape=(1, 1, TILE_D, TILE_N),
              order=(0, 1, 3, 2)).reshape((TILE_D, TILE_N))
Prefetch data to overlap memory loads with computation. See the Performance Tuning docs for the full list of load/store hints (e.g. allow_tma, latency).
# latency=N tells the compiler to issue this load N loop iterations in
# advance of its use, so the memory transfer overlaps with the MMA work
# from earlier iterations. Larger latency = deeper software pipeline but
# more register pressure.
k_t = ct.load(K, ..., latency=2) # Prefetch K 2 iterations ahead
v_tile = ct.load(V, ..., latency=4) # Prefetch V 4 iterations ahead (used later in the loop)
Allow multiple thread blocks per SM to hide memory latency. See the Execution Model docs for details on how occupancy interacts with registers and shared memory.
# occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
# Higher occupancy -> more warps available to hide memory latency,
# but constrains the per-CTA register/SMEM budget.
@ct.kernel(occupancy=2) # 2 thread blocks (CTAs) co-resident per SM
def fmha_optimized(...):
Use fast approximate division for final normalization.
from cuda.tile import RoundingMode as RMd
# RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
# than IEEE-compliant division. Safe here because it's the final softmax
# normalization step where a small ULP error is acceptable.
# flush_to_zero=True flushes denormals to 0 to avoid the slow path.
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
The same kernel code works on all platforms; only configuration parameters change. Use ct.ByTarget to select values per architecture, or ct.autotune to search candidate values automatically.
| Platform | TILE_M | TILE_N | Occupancy | Rationale |
|---|---|---|---|---|
| DGX Spark (sm_121) | 64 | 64 | 2 | Smaller tiles, higher occupancy for 48 SMs |
| B300 (sm_103) | 256 | 128 | 1 | Large tiles maximize HBM3e throughput |
| B300 alternate | 128 | 128 | 2 | Higher occupancy, balanced parallelism |
import cuda.tile as ct
@ct.kernel(
    # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
    # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
    # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_121": 2,   # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
        "sm_100": 1,   # B300: larger tiles already saturate the SM
        "default": 1,  # Conservative fallback for other architectures
    }),
    opt_level=3,  # Maximum compiler optimization level
)
def fmha_kernel(...):
    ...
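The host-side TILE_M/TILE_N values from the table can be selected the same way at launch time; a minimal sketch (the helper name and the capability mapping are illustrative, taken from the table above):

import torch

def pick_tile_sizes():
    # Map the device's compute capability to (TILE_M, TILE_N) from the table above.
    major, minor = torch.cuda.get_device_capability()
    if (major, minor) == (12, 1):  # DGX Spark (sm_121)
        return 64, 64
    if major == 10:                # B300-class (sm_10x)
        return 256, 128
    return 64, 64                  # conservative fallback

TILE_M, TILE_N = pick_tile_sizes()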
Note: PyTorch SDPA is used for correctness verification only, not performance comparison.
| Step | Optimization | Latency (ms) | TFLOPS |
|---|---|---|---|
| 1 | Basic cuTile | 2.19 | 62.8 |
| 2 | + exp2 | 2.07 | 66.5 |
| 3 | + Load Order | 2.07 | 66.3 |
| 4 | + Latency Hints | 2.07 | 66.5 |
| 5 | + Occupancy=2 | 1.73 | 79.5 |
| 6 | + Approx Div (Final) | 1.69 | 81.1 |
| Seq Len | Latency (ms) | TFLOPS | vs Spark |
|---|---|---|---|
| 1024 | 0.074 | 465 | 5.7x |
| 2048 | 0.178 | 770 | 9.5x |
| 4096 | 0.550 | 999 | 15.1x |
| 8192 | 1.897 | 1159 | 14.6x |
| 16384 | 7.014 | 1254 | 14.2x |
| Issue | Solution |
|---|---|
| Shape mismatch in ct.mma | Ensure A is (M,K), B is (K,N), C is (M,N) |
| dtype errors | Use .astype() before mma; accumulator should be float32 |
| Incorrect results with causal | Check mask_start calculation and offs_m >= offs_n logic |
| Low performance | Try different TILE_M/N, check occupancy, verify latency hints |
The following scripts are included in this playbook and can be run on DGX Spark or B300:
- assets/fmha_optimization_tutorial.py – Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
- assets/fmha_scaling_analysis.py – Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.

# Run the optimization tutorial (DGX Spark)
python assets/fmha_optimization_tutorial.py --correctness-check
# Run the scaling analysis
python assets/fmha_scaling_analysis.py --iterations 100