

cuTile Kernels

60 MIN

Run cuTile kernel benchmarks, FMHA implementation, and LLM inference on DGX Spark and B300

Tags: Benchmarking, Cross-Platform, DeepSeek, Docker, FMHA, Flash Attention, GPU Development, LLM Inference, Qwen2, TileGym, cuTile
View on GitHub

FMHA Implementation Guide

NOTE

This is a guide to understanding FMHA implementation in cuTile, not a complete reference. For comprehensive documentation, see the cuTile Python Documentation.

Attention Basics

Attention allows a neural network to focus on relevant parts of the input. In transformers (GPT, LLaMA, Qwen), each position computes how much to attend to every other position using three vectors:

  • Query (Q): "What am I looking for?"
  • Key (K): "What do I contain?"
  • Value (V): "Here is my content"
Attention(Q, K, V) = softmax(Q Ɨ K^T / √d) Ɨ V

Shapes:
  Q, K, V = [batch, heads, seq_len, head_dim]
  Q Ɨ K^T = [batch, heads, seq_len, seq_len]  # Attention scores
  Output  = [batch, heads, seq_len, head_dim]

For autoregressive models, causal masking ensures each token only attends to previous tokens by setting future scores to -infinity before softmax.
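The formula and the causal mask can be checked with a small NumPy reference (a sketch for intuition only; this is plain attention, not the cuTile API):

```python
import numpy as np

def reference_attention(Q, K, V, causal=True):
    """Naive attention: materializes the full [seq, seq] score matrix."""
    d = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d)              # [seq, seq] attention scores
    if causal:
        seq = Q.shape[0]
        future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
        scores = np.where(future, -np.inf, scores)  # hide future tokens
    # Numerically stable softmax over the key axis
    scores = scores - scores.max(axis=-1, keepdims=True)
    p = np.exp(scores)
    p = p / p.sum(axis=-1, keepdims=True)
    return p @ V                                # [seq, head_dim]

rng = np.random.default_rng(0)
q = rng.standard_normal((8, 4))
k = rng.standard_normal((8, 4))
v = rng.standard_normal((8, 4))
out = reference_attention(q, k, v)
```

With the causal mask, token 0 can only attend to itself, so its output row is exactly `v[0]` — a quick sanity check for the masking logic.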

Flash Attention Algorithm

Standard attention materializes a [seq_len Ɨ seq_len] score matrix (roughly 2 GB per batch·head at FP16 for seq_len=32768). Flash Attention avoids this by processing K and V in tiles with an online softmax:

m = -infinity    # Running maximum
l = 0            # Running sum of exp(x - m)
acc = 0          # Running weighted sum of values

FOR each K,V tile:
    scores = Q_tile @ K_tile.T * scale
    m_new = max(m, max(scores))
    correction = exp(m - m_new)
    l = l * correction + sum(exp(scores - m_new))
    acc = acc * correction + exp(scores - m_new) @ V_tile
    m = m_new

output = acc / l
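The online-softmax loop above can be verified against full-matrix attention in plain NumPy (single head, no masking, for clarity — a sketch, not kernel code):

```python
import numpy as np

def flash_attention_numpy(Q, K, V, tile_n=4):
    """Tiled attention with online softmax: never builds the full score matrix."""
    seq, d = Q.shape
    scale = 1.0 / np.sqrt(d)
    m = np.full((seq, 1), -np.inf)   # running row maximum
    l = np.zeros((seq, 1))           # running sum of exp(scores - m)
    acc = np.zeros((seq, d))         # running weighted sum of values

    for j in range(0, seq, tile_n):
        k_tile, v_tile = K[j:j + tile_n], V[j:j + tile_n]
        scores = Q @ k_tile.T * scale
        m_new = np.maximum(m, scores.max(axis=-1, keepdims=True))
        correction = np.exp(m - m_new)                 # rescale old state
        p = np.exp(scores - m_new)
        l = l * correction + p.sum(axis=-1, keepdims=True)
        acc = acc * correction + p @ v_tile
        m = m_new
    return acc / l

def full_attention(Q, K, V):
    s = Q @ K.T / np.sqrt(Q.shape[-1])
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ V

rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
out_tiled = flash_attention_numpy(q, k, v)
out_full = full_attention(q, k, v)
```

Both paths produce the same result to floating-point precision, which is the whole point: the tiled version touches only one [seq, tile_n] score slab at a time.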

cuTile Pseudocode → Actual Mapping

| Concept              | Pseudocode                 | cuTile                         |
|----------------------|----------------------------|--------------------------------|
| Define kernel        | KERNEL fmha(...)           | @ct.kernel()                   |
| Get block ID         | block_x = BLOCK_ID_X       | bid_x = ct.bid(0)              |
| Create indices       | range(0, N)                | ct.arange(N, dtype=ct.int32)   |
| Create constant tile | tile = zeros(M, N)         | ct.full((M, N), 0.0, dtype)    |
| Load from memory     | tile = LOAD(ptr, shape)    | ct.load(tensor, index, shape)  |
| Store to memory      | STORE(ptr, tile)           | ct.store(tensor, index, tile)  |
| Matrix multiply      | C = A @ B + C              | ct.mma(A, B, C)                |
| Reduction            | max_val = MAX(tile, axis)  | ct.max(tile, axis, keepdims)   |

Kernel Pseudocode

KERNEL fmha(Q, K, V, Out, scale, TILE_M, TILE_N):
    tile_row = BLOCK_ID_X
    batch_head = BLOCK_ID_Y
    batch = batch_head // num_heads
    head = batch_head % num_heads

    m_i = full(TILE_M, -infinity)
    l_i = full(TILE_M, 0)
    acc = zeros(TILE_M, head_dim)

    q = LOAD(Q[batch, head, tile_row*TILE_M : (tile_row+1)*TILE_M, :])

    FOR j = 0 to num_k_tiles:
        k = LOAD(K[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        v = LOAD(V[batch, head, j*TILE_N : (j+1)*TILE_N, :])
        scores = MMA(q, transpose(k)) * scale
        IF causal AND in_mask_region:
            scores = WHERE(valid_mask, scores, -infinity)
        m_new = max(m_i, row_max(scores))
        correction = exp(m_i - m_new)
        p = exp(scores - m_new)
        l_i = l_i * correction + row_sum(p)
        acc = acc * correction + MMA(p, v)
        m_i = m_new

    out = acc / l_i
    STORE(Out[batch, head, tile_row*TILE_M :, :], out)

cuTile Implementation

import cuda.tile as ct
import math
ConstInt = ct.Constant[int]
ConstBool = ct.Constant[bool]

@ct.kernel()
def fmha_kernel(Q, K, V, Out, qk_scale: float, TILE_D: ConstInt, H: ConstInt,
                TILE_M: ConstInt, TILE_N: ConstInt, CAUSAL: ConstBool):
    bid_x, bid_y = ct.bid(0), ct.bid(1)
    batch_idx, head_idx = bid_y // H, bid_y % H

    offs_m = (bid_x * TILE_M + ct.arange(TILE_M, dtype=ct.int32))[:, None]
    offs_n_tile = ct.arange(TILE_N, dtype=ct.int32)[None, :]

    m_i = ct.full((TILE_M, 1), -math.inf, dtype=ct.float32)
    l_i = ct.full((TILE_M, 1), 0.0, dtype=ct.float32)
    acc = ct.full((TILE_M, TILE_D), 0.0, dtype=ct.float32)

    q = ct.load(Q, index=(batch_idx, head_idx, bid_x, 0),
                shape=(1, 1, TILE_M, TILE_D)).reshape((TILE_M, TILE_D))

    k_seqlen = K.shape[2]
    if CAUSAL:
        Tc = ct.cdiv(min((bid_x + 1) * TILE_M, k_seqlen), TILE_N)
        mask_start = (bid_x * TILE_M) // TILE_N
    else:
        Tc = ct.cdiv(k_seqlen, TILE_N)
        mask_start = k_seqlen // TILE_N

    for j in range(0, Tc):
        k_tile = ct.load(K, index=(batch_idx, head_idx, j, 0),
                        shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        k_t = ct.permute(k_tile, (1, 0))

        qk = ct.mma(q, k_t, ct.full((TILE_M, TILE_N), 0.0, dtype=ct.float32))
        qk = qk * qk_scale

        if CAUSAL and j >= mask_start:
            offs_n = j * TILE_N + offs_n_tile
            qk = ct.where(offs_m >= offs_n, qk,
                         ct.full((TILE_M, TILE_N), -math.inf, dtype=ct.float32))

        m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True))
        qk = qk - m_ij
        p = ct.exp(qk)
        alpha = ct.exp(m_i - m_ij)
        l_i = l_i * alpha + ct.sum(p, axis=-1, keepdims=True)
        acc = acc * alpha

        v_tile = ct.load(V, index=(batch_idx, head_idx, j, 0),
                        shape=(1, 1, TILE_N, TILE_D)).reshape((TILE_N, TILE_D))
        acc = ct.mma(p.astype(Q.dtype), v_tile, acc)
        m_i = m_ij

    acc = (acc / l_i).reshape((1, 1, TILE_M, TILE_D)).astype(Out.dtype)
    ct.store(Out, index=(batch_idx, head_idx, bid_x, 0), tile=acc)

Launching the Kernel

def run_fmha(q, k, v, sm_scale, is_causal=True):
    import torch
    TILE_M, TILE_N = 64, 64  # Platform-specific (see below)
    batch, num_heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)
    grid = (math.ceil(seq_len / TILE_M), batch * num_heads, 1)
    ct.launch(
        torch.cuda.current_stream(), grid, fmha_kernel,
        (q, k, v, out, sm_scale, head_dim, num_heads, TILE_M, TILE_N, is_causal)
    )
    return out
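The grid geometry assigns one CTA per (query-row tile, batch·head) pair. For concrete shapes this works out as follows (the model dimensions here are illustrative, not from the playbook's benchmarks):

```python
import math

def fmha_grid(batch, num_heads, seq_len, tile_m):
    """One CTA per (query-row tile, batch*head) pair, matching the launcher above."""
    return (math.ceil(seq_len / tile_m), batch * num_heads, 1)

# Example: 28 heads, seq_len 2048, TILE_M=64 (the DGX Spark config)
grid = fmha_grid(batch=1, num_heads=28, seq_len=2048, tile_m=64)
# 2048 / 64 = 32 query tiles along X, 1 * 28 batch*head slices along Y
```

Note that `math.ceil` rounds up, so a sequence length that is not a multiple of TILE_M still gets a (partially filled) final tile.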

Optimizations

exp2 + flush_to_zero

exp2(x) = 2^x maps to a fast hardware instruction and is cheaper than exp(x) on the GPU. Since exp(x) = exp2(x / ln 2), switching to it only requires pre-scaling the softmax argument by 1/ln(2).

# Convert natural-exp scale to base-2 so we can use the faster ct.exp2 intrinsic.
# exp(x) == exp2(x / log(2)) == exp2(x * INV_LOG_2).
INV_LOG_2 = 1.0 / math.log(2)  # ā‰ˆ 1.4427
qk_scale_log2 = qk_scale * INV_LOG_2  # Pre-multiply the softmax scale once

# ... in loop:
# Fuse the running-max update with the scale multiplication.
m_ij = ct.maximum(m_i, ct.max(qk, axis=-1, keepdims=True) * qk_scale_log2)
# Subtract the running max for numerical stability (online softmax).
qk = qk * qk_scale_log2 - m_ij
# flush_to_zero=True: flush denormals to 0 -> avoids slow denormal handling on GPU.
p = ct.exp2(qk, flush_to_zero=True)
alpha = ct.exp2(m_i - m_ij, flush_to_zero=True)  # Correction factor for previous acc/l_i
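The base-2 rewrite is exact up to floating-point rounding, which is easy to confirm numerically:

```python
import math

INV_LOG_2 = 1.0 / math.log(2)  # ā‰ˆ 1.4427, the same constant the kernel pre-multiplies

# exp(x) == 2**(x * INV_LOG_2): scaling the argument once lets every
# exponential in the softmax take the cheaper exp2 path.
vals = [(math.exp(x), 2.0 ** (x * INV_LOG_2)) for x in (-3.7, -1.0, 0.0, 0.5, 2.25)]
```

Because the scale folds into `qk_scale` before the loop, the per-iteration cost of the rewrite is zero.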

Load Order Transpose

Load K already transposed using order parameter, avoiding explicit permute.

# order=(0,1,3,2) swaps the last two axes during the load,
# producing K^T directly in registers -- no extra ct.permute() needed.
# shape is expressed in the transposed layout: (1, 1, TILE_D, TILE_N).
k_t = ct.load(K, index=(..., 0, j), shape=(1,1,TILE_D,TILE_N),
              order=(0,1,3,2)).reshape((TILE_D, TILE_N))

Latency Hints

Prefetch data to overlap memory loads with computation. See the Performance Tuning docs for the full list of load/store hints (e.g. allow_tma, latency).

# latency=N tells the compiler to issue this load N loop iterations in
# advance of its use, so the memory transfer overlaps with the MMA work
# from earlier iterations. Larger latency = deeper software pipeline but
# more register pressure.
k_t = ct.load(K, ..., latency=2)    # Prefetch K 2 iterations ahead
v_tile = ct.load(V, ..., latency=4) # Prefetch V 4 iterations ahead (used later in the loop)

Occupancy

Allow multiple thread blocks per SM to hide memory latency. See the Execution Model docs for details on how occupancy interacts with registers and shared memory.

# occupancy=N is a hint to the compiler to target N concurrent CTAs per SM.
# Higher occupancy -> more warps available to hide memory latency,
# but constrains the per-CTA register/SMEM budget.
@ct.kernel(occupancy=2)  # 2 thread blocks (CTAs) co-resident per SM
def fmha_optimized(...):

Approximate Division

Use fast approximate division for final normalization.

from cuda.tile import RoundingMode as RMd
# RMd.APPROX -> hardware approximate reciprocal/divide (MUFU), much faster
# than IEEE-compliant division. Safe here because it's the final softmax
# normalization step where a small ULP error is acceptable.
# flush_to_zero=True flushes denormals to 0 to avoid the slow path.
acc = ct.truediv(acc, l_i, flush_to_zero=True, rounding_mode=RMd.APPROX)

Platform Configuration

The same kernel code works on all platforms; only configuration parameters change. Use ct.ByTarget to select values per architecture, or ct.autotune to search candidate values automatically.

| Platform           | TILE_M | TILE_N | Occupancy | Rationale                                  |
|--------------------|--------|--------|-----------|--------------------------------------------|
| DGX Spark (sm_121) | 64     | 64     | 2         | Smaller tiles, higher occupancy for 48 SMs |
| B300 (sm_103)      | 256    | 128    | 1         | Large tiles maximize HBM3e throughput      |
| B300 alternate     | 128    | 128    | 2         | Higher occupancy, balanced parallelism     |

import cuda.tile as ct

@ct.kernel(
    # TILE_M / TILE_N: rows/cols of the Q and K/V tiles processed per CTA.
    # Larger tiles -> more arithmetic intensity; smaller tiles -> higher occupancy.
    # occupancy: target concurrent CTAs per SM (latency hiding vs. register pressure).
    occupancy=ct.ByTarget({
        "sm_121": 2,   # DGX Spark (48 SMs): 2 CTAs/SM for latency hiding
        "sm_100": 1,   # B300: larger tiles already saturate the SM
        "default": 1,  # Conservative fallback for other architectures
    }),
    opt_level=3        # Maximum compiler optimization level
)
def fmha_kernel(...):
    ...

Performance Results

Note: PyTorch SDPA is used for correctness verification only, not performance comparison.

DGX Spark (sm_121) — Seq 2048

| Step | Optimization         | Latency (ms) | TFLOPS |
|------|----------------------|--------------|--------|
| 1    | Basic cuTile         | 2.19         | 62.8   |
| 2    | + exp2               | 2.07         | 66.5   |
| 3    | + Load Order         | 2.07         | 66.3   |
| 4    | + Latency Hints      | 2.07         | 66.5   |
| 5    | + Occupancy=2        | 1.73         | 79.5   |
| 6    | + Approx Div (Final) | 1.69         | 81.1   |

B300 (sm_103) — Various Seq Lengths

| Seq Len | Latency (ms) | TFLOPS | vs Spark |
|---------|--------------|--------|----------|
| 1024    | 0.074        | 465    | 5.7x     |
| 2048    | 0.178        | 770    | 9.5x     |
| 4096    | 0.550        | 999    | 15.1x    |
| 8192    | 1.897        | 1159   | 14.6x    |
| 16384   | 7.014        | 1254   | 14.2x    |

Common Issues

| Issue                         | Solution                                                      |
|-------------------------------|---------------------------------------------------------------|
| Shape mismatch in ct.mma      | Ensure A is (M,K), B is (K,N), C is (M,N)                     |
| dtype errors                  | Use .astype() before mma; accumulator should be float32       |
| Incorrect results with causal | Check mask_start calculation and offs_m >= offs_n logic       |
| Low performance               | Try different TILE_M/N, check occupancy, verify latency hints |
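The first two rows boil down to the usual GEMM shape and dtype contract. A quick NumPy check makes it concrete (illustrative only; `ct.mma(A, B, C)` computes C += A @ B on tiles, not NumPy arrays):

```python
import numpy as np

M, K, N = 64, 32, 128
A = np.random.rand(M, K).astype(np.float16)   # e.g. the probabilities tile p
B = np.random.rand(K, N).astype(np.float16)   # e.g. the V tile
C = np.zeros((M, N), dtype=np.float32)        # accumulator stays float32

# Shapes must satisfy (M,K) x (K,N) -> (M,N); inputs are cast down for the
# tensor cores while the accumulation happens in float32.
C += A.astype(np.float32) @ B.astype(np.float32)
```

If the inner dimensions do not match, the multiply fails immediately — the same class of error `ct.mma` reports as a shape mismatch.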

Companion Scripts

The following scripts are included in this playbook and can be run on DGX Spark or B300:

  • assets/fmha_optimization_tutorial.py — Step-by-step optimization tutorial. Builds the FMHA kernel from basic to fully optimized, matching the progression in this guide.
  • assets/fmha_scaling_analysis.py — Scaling analysis across sequence lengths. Benchmarks each optimization level and generates performance data.
# Run the optimization tutorial (DGX Spark)
python assets/fmha_optimization_tutorial.py --correctness-check

# Run the scaling analysis
python assets/fmha_scaling_analysis.py --iterations 100

References

  • cuTile Python Documentation
  • Tile IR Specification
  • TileGym (pre-optimized kernels)
  • NVIDIA Blog: Tuning Flash Attention for Peak Performance in CUDA Tile
  • Flash Attention Paper

Resources

  • TileGym Repository
  • cuTile Python Documentation
  • Tile IR Specification
  • DGX Spark Documentation
  • DGX Spark Forum
  • Qwen2 on HuggingFace
  • DeepSeek-V2-Lite on HuggingFace
  • NVIDIA Blog - Tuning Flash Attention in CUDA Tile
  • Flash Attention Paper

Copyright Ā© 2026 NVIDIA Corporation