cuTile Kernels
Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300
FMHA Implementation Guide
NOTE
This is a guide to understanding the fused multi-head attention (FMHA) implementation in cuTile, not a complete reference. For comprehensive documentation, see the cuTile Python Documentation.
Attention Basics
Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three vectors:
- Query (Q): "What am I looking for?"
- Key (K): "What do I contain?"
- Value (V): "Here is my content"
Attention(Q, K, V) = softmax(Q × K^T / √d) × V
Shapes:
Q, K, V = [batch, heads, seq_len, head_dim]
Q × K^T = [batch, heads, seq_len, seq_len] # Attention scores
Output = [batch, heads, seq_len, head_dim]
For autoregressive models, causal masking ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.
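The formula and the causal rule above can be checked on a toy input. The following is a minimal pure-Python sketch for a single head (no batch or head dimensions). The function names `softmax` and `attention` are illustrative and are not part of cuTile.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats (entries may be -inf)."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(q, k, v, causal=False):
    """Single-head attention; q, k, v are [seq_len][head_dim] nested lists."""
    d = len(q[0])
    scale = 1.0 / math.sqrt(d)
    out = []
    for i, q_i in enumerate(q):
        # One row of Q x K^T / sqrt(d)
        scores = [scale * sum(a * b for a, b in zip(q_i, k_j)) for k_j in k]
        if causal:
            # Future positions (j > i) get -inf, so softmax gives them weight 0.
            scores = [s if j <= i else -math.inf for j, s in enumerate(scores)]
        w = softmax(scores)
        out.append([sum(w_j * v_j[c] for w_j, v_j in zip(w, v))
                    for c in range(len(v[0]))])
    return out

q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
causal_out = attention(q, k, v, causal=True)
full_out = attention(q, k, v, causal=False)
```

With causal masking, the first token can only attend to itself, so `causal_out[0]` equals `v[0]`.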
Flash Attention Algorithm
Standard attention materializes a [seq_len × seq_len] score matrix per head (e.g., 2 GiB in FP16 for seq_len=32768, since 32768² × 2 bytes = 2 GiB). Flash Attention avoids this by processing K and V in tiles and maintaining an online softmax:
m = -infinity    # Running maximum
l = 0            # Running sum of exp(x - m)
acc = 0          # Running weighted sum of values
FOR each K,V tile:
    scores = Q_tile @ K_tile.T * scale
    m_new = max(m, max(scores))
    correction = exp(m - m_new)
    l = l * correction + sum(exp(scores - m_new))
    acc = acc * correction + exp(scores - m_new) @ V_tile
    m = m_new
output = acc / l
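To see that the tiled recurrence produces the same result as a one-shot softmax, here is a small pure-Python sketch. It uses scalar values instead of value vectors to keep it short. The helper names are illustrative and not part of cuTile.

```python
import math

def online_softmax_weighted_sum(scores, values, tile):
    """Softmax-weighted sum of `values`, computed `tile` elements at a time."""
    m = -math.inf   # running maximum
    l = 0.0         # running sum of exp(score - m)
    acc = 0.0       # running weighted sum of values
    for start in range(0, len(scores), tile):
        s_tile = scores[start:start + tile]
        v_tile = values[start:start + tile]
        m_new = max(m, max(s_tile))
        correction = math.exp(m - m_new)  # exp(-inf) == 0.0 on the first tile
        p = [math.exp(s - m_new) for s in s_tile]
        l = l * correction + sum(p)
        acc = acc * correction + sum(pi * vi for pi, vi in zip(p, v_tile))
        m = m_new
    return acc / l

def reference(scores, values):
    """Same quantity computed in one pass over the full score row."""
    m = max(scores)
    p = [math.exp(s - m) for s in scores]
    return sum(pi * vi for pi, vi in zip(p, values)) / sum(p)

scores = [0.5, 2.0, -1.0, 3.5, 0.0, 1.25, -2.0]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]
tiled = online_softmax_weighted_sum(scores, values, tile=3)
full = reference(scores, values)
```

The two results agree up to floating-point rounding, regardless of the tile size, because each `correction` factor rescales the earlier partial sums to the new running maximum.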
cuTile Pseudocode → Actual Mapping
| Concept | Pseudocode | cuTile |
|---|---|---|
| Define kernel | KERNEL fmha(...) | @ct.kernel() |
| Get block ID | block_x = BLOCK_ID_X | bid_x = ct.bid(0) |
| Create indices | range(0, N) | ct.arange(N, dtype=ct.int32) |
| Create constant tile | tile = zeros(M, N) | ct.full((M, N), 0.0, dtype) |
| Load from memory | tile = LOAD(ptr, shape) | ct.load(tensor, index, shape) |
| Store to memory | STORE(ptr, tile) | ct.store(tensor, index, tile) |
| Matrix multiply | C = A @ B + C | ct.mma(A, B, C) |
| Reduction | max_val = MAX(tile, axis) | ct.max(tile, axis, keepdims) |
Kernel Pseudocode
KERNEL fmha(Q, K, V, Out, scale, num_heads, causal, TILE_M, TILE_N):
    tile_row = BLOCK_ID_X
    batch_head = BLOCK_ID_Y
    batch = batch_head // num_heads
    head = batch_head % num_heads

    m_i = full(TILE_M, -infinity)
    l_i = full(TILE_M, 0)
    acc = zeros(TILE_M, head_dim)
    q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])

    FOR j = 0 to num_k_tiles - 1:
        k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        scores = MMA(q, transpose(k)) * scale
        IF causal AND in_mask_region:
            scores = WHERE(valid_mask, scores, -infinity)
        m_new = max(m_i, row_max(scores))
        correction = exp(m_i - m_new)
        p = exp(scores - m_new)
        l_i = l_i * correction + row_sum(p)
        acc = acc * correction + MMA(p, v)
        m_i = m_new

    out = acc / l_i
    STORE(Out[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :], out)
cuTile Implementation
import cuda.tile as ct
import math

ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]


@ct.kernel()
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    # One block per (query tile, batch * head) pair.
    bid_x, bid_y = ct.bid(0), ct.bid(1)
    batch_idx, head_idx = bid_y // H, bid_y % H

    # Absolute query rows (column vector) and in-tile key columns (row vector).
    offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]

    # Online-softmax state: running max, running sum, output accumulator.
    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
                shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        # Stop at the last K tile that overlaps this query tile; tiles from
        # mask_start onward can contain future positions and need masking.
        Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        # Visit every K tile; mask_start is not consulted when CAUSAL is False.
        Tc = ct.cdiv(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))

        # Scaled scores: (TILE_M, TILE_D) @ (TILE_D, TILE_N) -> (TILE_M, TILE_N)
        qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
        qk = qk * qk_scale

        if CAUSAL and j >= mask_start:
            # Keep scores where the query row index >= the key column index.
            offs_n = j * TILE_N + offs_n_tile
            qk = ct.where(offs_m >= offs_n, qk,
                          ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        # Online-softmax update.
        m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
        qk = qk - m_ij
        p = ct.exp(qk)
        alpha = ct.exp(m_i - m_ij)  # Rescales the previous l_i and acc
        l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
                         shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
        m_i = m_ij

    # Final normalization, then write the output tile.
    acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)
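The loop bounds in the kernel determine which K tiles each query tile visits and which of those need the causal mask. The sketch below restates that index arithmetic in plain Python so it can be checked on the host. `cdiv` and `k_tile_plan` are hypothetical helper names, not cuTile functions.

```python
def cdiv(a, b):
    """Ceiling division for positive integers."""
    return (a + b - 1) // b

def k_tile_plan(bid_x, tile_m, tile_n, k_seqlen, causal):
    """Return (tiles_visited, tiles_needing_mask) for query tile `bid_x`."""
    if causal:
        # Same arithmetic as the kernel's CAUSAL branch.
        tc = cdiv(min((bid_x + 1) * tile_m, k_seqlen), tile_n)
        mask_start = (bid_x * tile_m) // tile_n
    else:
        tc = cdiv(k_seqlen, tile_n)
        mask_start = k_seqlen // tile_n
    visited = list(range(tc))
    masked = [j for j in visited if causal and j >= mask_start]
    return visited, masked

# Query tile 2 with 64x64 tiles: K tiles 0 and 1 lie fully below the diagonal,
# and only tile 2 straddles it.
square_plan = k_tile_plan(bid_x=2, tile_m=64, tile_n=64, k_seqlen=256, causal=True)

# With TILE_M=128 and TILE_N=64, two K tiles overlap each query tile's diagonal.
wide_plan = k_tile_plan(bid_x=1, tile_m=128, tile_n=64, k_seqlen=512, causal=True)
```

Under these assumptions, causal mode skips the K tiles that lie entirely in the future and applies `ct.where` only to the tiles that intersect the diagonal.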
Launching the Kernel
def run_fmha(q, k, v, sm_scale, is_causal=True):
    import torch

    TILE_M, TILE_N = 64, 64  # Platform-specific (see below)
    batch, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(), grid, fmha_kernel,
        (q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
    )
    return out
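The grid shape and the block-index decoding used by the kernel can be sanity-checked without a GPU. This is a small sketch with illustrative helper names (`launch_grid`, `decode_block`) that mirror the expressions in `run_fmha` and `fmha_kernel`.

```python
import math

def launch_grid(batch, num_heads, seq_len, tile_m):
    """Grid used by run_fmha: one block per (query tile, batch * head) pair."""
    return (math.ceil(seq_len / tile_m), batch * num_heads, 1)

def decode_block(bid_y, num_heads):
    """Recover (batch_idx, head_idx) from the flattened y block index."""
    return bid_y // num_heads, bid_y % num_heads

grid = launch_grid(batch=2, num_heads=8, seq_len=2048, tile_m=64)
batch_idx, head_idx = decode_block(bid_y=11, num_heads=8)

# A sequence length that is not a multiple of TILE_M still gets a final,
# partially filled query tile.
ragged_grid = launch_grid(batch=1, num_heads=1, seq_len=1000, tile_m=64)
```

For the example shape this yields 32 query tiles times 16 batch-head pairs, i.e. 512 blocks.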
Optimizations
exp2 + flush_to_zero
exp2(x) = 2^x maps to a faster hardware instruction than exp(x) on GPUs. Because exp(x) = exp2(x / ln 2), the softmax scale must be multiplied by 1/ln(2) once up front.
# Convert natural-exp scale to base-2 so we can use the faster ct.exp2 intrinsic.
# exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2) # ≈ 1.4427
qk_scale_log2 = qk_scale * INV_LOG_2 # Pre-multiply the softmax scale once
# ... in loop (qk holds the unscaled Q @ K^T scores here):
# Fuse the running-max update with the scale multiplication.
m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2)
# Subtract the running max for numerical stability (online softmax).
qk = qk * qk_scale_log2 - m_ij
# flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True) # Correction factor for previous acc/l_i
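The base change leaves the softmax result unchanged, which can be confirmed on the host. The sketch below compares a natural-exponent softmax with a base-2 version that uses the pre-multiplied scale. It is plain Python with illustrative function names, and it does not model `flush_to_zero`.

```python
import math

INV_LOG_2 = 1.0 / math.log(2)  # math.log is the natural logarithm

def softmax_exp(scores, scale):
    """Reference softmax using exp with the original scale."""
    m = max(s * scale for s in scores)
    p = [math.exp(s * scale - m) for s in scores]
    total = sum(p)
    return [x / total for x in p]

def softmax_exp2(scores, scale):
    """Same softmax using 2**x with the scale pre-multiplied by 1/ln(2)."""
    scale_log2 = scale * INV_LOG_2
    m = max(s * scale_log2 for s in scores)
    p = [2.0 ** (s * scale_log2 - m) for s in scores]
    total = sum(p)
    return [x / total for x in p]

raw_scores = [1.0, -2.0, 0.5, 3.0]
probs_e = softmax_exp(raw_scores, scale=0.125)
probs_2 = softmax_exp2(raw_scores, scale=0.125)
```

Both lists match to within floating-point rounding, so the switch to `ct.exp2` only changes which instruction computes the exponent, not the attention weights.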
Load Order Transpose
Load K already transposed using order parameter, avoiding explicit permute.
# order=(0,1,3,2) swaps the last two axes during the load,
# producing K^T directly in registers -- no extra ct.permute() needed.
# shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N).
k_t = ct.load(K, index=(..., 0, j), shape=(1, 1, TILE_D, TILE_N),
              order=(0, 1, 3, 2)).reshape((TILE_D, TILE_N))
Latency Hints
Prefetch data to overlap memory loads with computation. See the Performance Tuning docs for the full list of load/store hints (e.g. allow_tma, latency).
# latency=N tells the compiler to issue this load N loop iterations in
# advance of its use, so the memory transfer overlaps with the MMA work
# from earlier iterations. Larger latency = deeper software pipeline but
# more register pressure.
k_t = ct.load(K, ..., latency=2) # Prefetch K 2 iterations ahead
v_tile = ct.load(V, ..., latency=4) # Prefetch V 4 iterations ahead (used later in the loop)
Occupancy
Allow multiple thread blocks per SM to hide memory latency. See the Execution Model docs for details on how occupancy interacts with registers and shared memory.
# occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
# Higher occupancy -> more warps available to hide memory latency,
# but constrains the per-CTA register/SMEM budget.
@ct.kernel(occupancy=2) # 2 thread blocks (CTAs) co-resident per SM
def fmha_optimized(...):
Approximate Division
Use fast approximate division for final normalization.
from cuda.tile import RoundingMode as RMd
# RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
# than IEEE-compliant division. Safe here because it's the final softmax
# normalization step where a small ULP error is acceptable.
# flush_to_zero=True flushes denormals to 0 to avoid the slow path.
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)
Platform Configuration
The same kernel code works on all platforms; only configuration parameters change. Use ct.ByTarget to select values per architecture, or ct.autotune to search candidate values automatically.
| Platform | TILE_M | TILE_N | Occupancy | Rationale |
|---|---|---|---|---|
| DGX Spark (sm_121) | 64 | 64 | 2 | Smaller tiles, higher occupancy for 48 SMs |
| B300 (sm_103) | 256 | 128 | 1 | Large tiles maximize HBM3e throughput |
| B300 alternate | 128 | 128 | 2 | Higher occupancy, balanced parallelism |
import cuda.tile as ct
@ct.kernel(
    # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
    # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
    # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_121": 2,   # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
        "sm_103": 1,   # B300: larger tiles already saturate the SM
        "default": 1,  # Conservative fallback for other architectures
    }),
    opt_level=3  # Maximum compiler optimization level
)
def fmha_kernel(...):
    ...
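Since TILE_M and TILE_N are passed to the kernel from the host, the launcher also needs a per-platform choice. One possible approach is a plain lookup keyed by architecture name, sketched below. The names `TILE_CONFIGS`, `FALLBACK_CONFIG`, `arch_name`, and `pick_tile_config` are hypothetical, the table entries are copied from the platform table above, and the fallback values are an assumption rather than a recommendation from the guide.

```python
# Values for sm_121 and sm_103 come from the platform table above.
TILE_CONFIGS = {
    "sm_121": {"TILE_M": 64, "TILE_N": 64, "occupancy": 2},    # DGX Spark
    "sm_103": {"TILE_M": 256, "TILE_N": 128, "occupancy": 1},  # B300
}

# Assumed conservative fallback for architectures not listed in the table.
FALLBACK_CONFIG = {"TILE_M": 64, "TILE_N": 64, "occupancy": 1}

def arch_name(major, minor):
    """Format a compute capability pair as an 'sm_XY' string."""
    return f"sm_{major}{minor}"

def pick_tile_config(major, minor):
    """Look up tile sizes for a compute capability, with a fallback."""
    return TILE_CONFIGS.get(arch_name(major, minor), FALLBACK_CONFIG)

# In a real launcher the pair could come from
# torch.cuda.get_device_capability(), which returns (major, minor).
spark_cfg = pick_tile_config(12, 1)
b300_cfg = pick_tile_config(10, 3)
other_cfg = pick_tile_config(9, 0)
```

Keeping the table in one place makes it easy to try the "B300 alternate" row by editing a single entry.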
Performance Results
Note: PyTorch SDPA is used for correctness verification only, not performance comparison.
DGX Spark (sm_121) — Seq 2048
| Step | Optimization | Latency (ms) | TFLOPS |
|---|---|---|---|
| 1 | Basic cuTile | 2.19 | 62.8 |
| 2 | + exp2 | 2.07 | 66.5 |
| 3 | + Load Order | 2.07 | 66.3 |
| 4 | + Latency Hints | 2.07 | 66.5 |
| 5 | + Occupancy=2 | 1.73 | 79.5 |
| 6 | + Approx Div (Final) | 1.69 | 81.1 |
B300 (sm_103) — Various Seq Lengths
| Seq Len | Latency (ms) | TFLOPS | Speedup vs. Spark |
|---|---|---|---|
| 1024 | 0.074 | 465 | 5.7x |
| 2048 | 0.178 | 770 | 9.5x |
| 4096 | 0.550 | 999 | 15.1x |
| 8192 | 1.897 | 1159 | 14.6x |
| 16384 | 7.014 | 1254 | 14.2x |
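The ratios implied by the two tables above can be reproduced with a few lines of arithmetic. The sketch below uses only the numbers reported in the tables. The variable names are illustrative.

```python
# Latencies from the DGX Spark table (sequence length 2048).
spark_latency_ms = [
    ("Basic cuTile", 2.19),
    ("+ exp2", 2.07),
    ("+ Load Order", 2.07),
    ("+ Latency Hints", 2.07),
    ("+ Occupancy=2", 1.73),
    ("+ Approx Div (Final)", 1.69),
]

# Cumulative speedup of each step relative to the basic kernel.
baseline_ms = spark_latency_ms[0][1]
cumulative_speedup = {name: baseline_ms / ms for name, ms in spark_latency_ms}

# Cross-platform ratio at sequence length 2048, from the TFLOPS columns.
spark_final_tflops = 81.1
b300_tflops_2048 = 770.0
b300_vs_spark = b300_tflops_2048 / spark_final_tflops

for name, speedup in cumulative_speedup.items():
    print(f"{name:<22} {speedup:.2f}x")
print(f"B300 vs Spark at seq 2048: {b300_vs_spark:.1f}x")
```

On the Spark numbers, most of the improvement appears at the occupancy step, where latency drops from 2.07 ms to 1.73 ms.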
Common Issues
| Issue | Solution |
|---|---|
| Shape mismatch in ct.mma | Ensure A is (M,K), B is (K,N), C is (M,N) |
| dtype errors | Use .astype() before mma; accumulator should be float32 |
| Incorrect results with causal | Check mask_start calculation and offs_m >= offs_n logic |
| Low performance | Try different TILE_M/N, check occupancy, verify latency hints |
Companion Scripts
The following scripts are included in this playbook and can be run on DGX Spark or B300:
- assets/fmha_optimization_tutorial.py: Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
- assets/fmha_scaling_analysis.py: Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.
# Run the optimization tutorial (DGX Spark)
python assets/fmha_optimization_tutorial.py --correctness-check
# Run the scaling analysis
python assets/fmha_scaling_analysis.py --iterations 100