Copyright © 2026 NVIDIA Corporation
Use torch.profiler to find training bottlenecks, then write custom Triton kernels to optimize LLaMA 8B fine-tuning
Before profiling or writing any GPU kernels, you need a development environment with PyTorch, Triton (the GPU programming language we'll use), and the tools to load LLaMA 3.1 8B. We use a Docker container so everything is pre-configured and isolated from your host system.
Clone the playbook repository and navigate to the assets directory:
git clone https://github.com/NVIDIA/dgx-spark-playbooks
cd dgx-spark-playbooks/nvidia/station-kernel-dev-ft/assets
Build the development container. This creates a Docker image based on NVIDIA's PyTorch NGC container with additional libraries for model loading and benchmarking:
docker build -t kernel-dev-ft .
Start the container with GPU access. Pass your Hugging Face token so the container can download LLaMA 3.1 8B:
docker run -it --rm \
--name kernel-dev-ft \
--gpus all \
--ipc host \
-e HF_TOKEN=$HF_TOKEN \
-v "$(pwd):/workspace" \
-v ~/.cache/huggingface:/root/.cache/huggingface \
-w /workspace \
kernel-dev-ft
NOTE
The -v "$(pwd):/workspace" flag mounts the current directory into the container. Any files you create or modify inside /workspace persist on your host machine after the container exits. The -v ~/.cache/huggingface:/root/.cache/huggingface mount persists downloaded model weights across container restarts so you don't need to re-download the 16 GB model each time. Everything outside these mounted paths is discarded when the container stops.
NOTE
If you haven't set HF_TOKEN in your shell, export it first: export HF_TOKEN=hf_your_token_here. You need a Hugging Face token with access to meta-llama/Llama-3.1-8B. You must first accept the LLaMA 3.1 Community License Agreement on the model page before your token can download the weights.
Verify the toolchain inside the container:
python -c "import triton; print(f'Triton {triton.__version__}')"
python -c "import torch; print(f'PyTorch {torch.__version__}, CUDA {torch.version.cuda}')"
nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader
The output should show the Triton version, the PyTorch and CUDA versions, and the GPU name with its compute capability.
NOTE
Unlike the CUDA C++ workflow (which requires the nvcc compiler and a separate compilation step), Triton is a Python library that JIT-compiles GPU code at runtime. There is no build step — you write Python, and Triton compiles it to optimized GPU machine code when you first call the kernel.
Before profiling, let's build a mental model of where GPU time goes during LLaMA 3.1 8B fine-tuning and why certain operations are candidates for custom kernels.
LLaMA 3.1 8B architecture at a glance:
| Property | Value |
|---|---|
| Parameters | 8.03 billion |
| Layers | 32 transformer blocks |
| Hidden size | 4096 |
| Attention heads | 32 |
| Key/value heads | 8 (grouped-query attention) |
| Vocabulary | 128,256 tokens |
| Normalization | RMSNorm (not LayerNorm) |
| Activation | SwiGLU (SiLU-gated MLP) |
Memory budget for full fine-tuning in BF16:
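As a rough back-of-the-envelope sketch of the static part of that budget, the helper below assumes BF16 weights and gradients, no separate FP32 master weights, and an AdamW-style optimizer that keeps two state buffers per parameter. The optimizer and the precision of its state buffers are assumptions, not read from the playbook's scripts, so the sketch reports both FP32 and BF16 state buffers. Activations and the logit tensor come on top of these numbers.

```python
# Back-of-the-envelope static memory for full fine-tuning of LLaMA 3.1 8B.
# Assumptions (hypothetical, not from the playbook scripts): BF16 weights and
# gradients, two optimizer state buffers per parameter (AdamW-style),
# no separate FP32 master weights, activations excluded.
NUM_PARAMS = 8.03e9
GB = 1e9  # decimal gigabytes


def static_memory_gb(num_params: float = NUM_PARAMS, state_bytes: int = 4) -> dict:
    """Return weight / gradient / optimizer-state memory in GB."""
    parts = {
        "weights": num_params * 2 / GB,                          # BF16: 2 bytes each
        "gradients": num_params * 2 / GB,                        # BF16 gradients
        "optimizer_states": num_params * 2 * state_bytes / GB,   # two state buffers
    }
    parts["total"] = sum(parts.values())
    return parts


if __name__ == "__main__":
    for label, state_bytes in [("FP32 optimizer states", 4), ("BF16 optimizer states", 2)]:
        budget = static_memory_gb(state_bytes=state_bytes)
        summary = ", ".join(f"{k} {v:.1f} GB" for k, v in budget.items())
        print(f"{label}: {summary}")
```

The FP32-state case lands around 96 GB and the BF16-state case around 64 GB; which one applies depends on how the training script configures its optimizer.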
Why training is different from inference for kernel optimization:
For inference, torch.compile and serving frameworks like vLLM already fuse most pointwise operations automatically. Writing a custom SiLU or SwiGLU kernel for inference is reinventing what's already solved.
Training is different for three reasons:
torch.compile handles some of these but cannot restructure algorithms (like how loss is computed).
Now let's see where GPU time actually goes. We'll use torch.profiler to capture a detailed trace of a single forward + backward + optimizer step.
Run the profiling script:
python profile_baseline.py
NOTE
The first run downloads LLaMA 3.1 8B weights (~16 GB in BF16) from Hugging Face. This takes several minutes depending on network speed. Subsequent runs use the cached weights and start immediately.
The script loads LLaMA 3.1 8B, runs one training step under torch.profiler, and prints a table like this:
======================================================================
Top 20 CUDA Operations by Total GPU Time
======================================================================
Name Self CUDA Self CUDA % # Calls
------------------------------------ ---------- ----------- --------
aten::mm 152.3ms 42.1% 258
aten::_flash_attention_forward 48.7ms 13.5% 32
aten::_flash_attention_backward 41.2ms 11.4% 32
aten::_scaled_mm 28.1ms 7.8% 2
aten::pow 12.4ms 3.4% 65
aten::mean 11.8ms 3.3% 65
aten::rsqrt 8.2ms 2.3% 65
aten::mul 15.6ms 4.3% 198
aten::_log_softmax 8.9ms 2.5% 1
aten::nll_loss_forward 3.2ms 0.9% 1
...
How to read these results:
- aten::mm (matrix multiplications): The largest single category (~42% of GPU time). These are already highly optimized by cuBLAS. Not a target for custom kernels.
- aten::_flash_attention_forward/backward (~25% combined): Already optimized by FlashAttention. Not a target.
- aten::pow, aten::mean, aten::rsqrt, and some aten::mul calls: These are the RMSNorm operations, broken into separate kernels. Each call is small, but there are many of them (32 layers x 2 norms per layer + 1 final model norm = 65 in the forward pass, plus the corresponding backward operations). In this example, aten::pow, aten::mean, and aten::rsqrt alone account for about 9% of GPU time (3.4% + 3.3% + 2.3%), before counting RMSNorm's share of aten::mul, and every one of these calls makes its own memory round-trip.
- aten::_log_softmax + aten::nll_loss_forward: This is the cross-entropy loss computation. It is called only once, but it operates over the full [batch_size * seq_len, 128256] logit tensor.

The profiler also saves a Chrome trace file, which you can inspect visually:
TIP
Open the Chrome trace JSON in Perfetto UI for an interactive timeline view. Look for sequences of narrow bars (small kernels with gaps between them) — these represent unfused operations where the GPU reads and writes the same data multiple times.
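profile_baseline.py itself is not reproduced here. As a minimal sketch of the same idea, the hypothetical helper below wraps one forward + backward + optimizer step in torch.profiler, prints the top operations, and writes a Chrome trace you can open in Perfetto. The toy model and the placeholder loss are assumptions for illustration; the sketch falls back to CPU-only profiling when no GPU is present.

```python
import torch
from torch import nn
from torch.profiler import ProfilerActivity, profile


def profile_one_step(model, batch, optimizer, trace_path="trace.json", row_limit=20):
    """Profile one training step and export a Chrome trace (hypothetical helper)."""
    activities = [ProfilerActivity.CPU]
    sort_key = "self_cpu_time_total"
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)
        sort_key = "self_cuda_time_total"  # corresponds to the "Self CUDA" column

    with profile(activities=activities, record_shapes=True) as prof:
        loss = model(batch).pow(2).mean()  # placeholder loss for the sketch
        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

    print(prof.key_averages().table(sort_by=sort_key, row_limit=row_limit))
    prof.export_chrome_trace(trace_path)  # open this JSON in Perfetto UI
    return prof


if __name__ == "__main__":
    toy = nn.Sequential(nn.Linear(256, 256), nn.SiLU(), nn.Linear(256, 256))
    opt = torch.optim.AdamW(toy.parameters(), lr=1e-4)
    profile_one_step(toy, torch.randn(32, 256), opt)
```

For a real model, move the model and batch to the GPU first so the CUDA kernels show up in the table and trace.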
Before writing kernels, let's understand why our two targets are slow. This understanding will guide the kernel design.
RMSNorm is memory-bandwidth-bound.
The formula for RMSNorm is:
RMSNorm(x) = (x / sqrt(mean(x^2) + eps)) * weight
PyTorch's default implementation breaks this into separate GPU operations:
1. x.pow(2): square each element → writes the result to memory
2. .mean(-1): reduce across the hidden dimension → reads that result, writes the mean
3. + eps, then .rsqrt(): reads the mean, writes the inverse RMS
4. x * rnorm: reads x again, reads rnorm, writes the normalized output
5. * weight: reads the normalized output, reads weight, writes the final result

Each of these reads from and writes to GPU memory (HBM). For hidden_size=4096 in BF16, a single row is 8 KB. The unfused version reads and writes this data 5+ times. A fused kernel reads it once and writes once.
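To make the op sequence concrete, here is a plain-PyTorch sketch that spells out the five steps one tensor at a time, in the order listed above. The FP32 upcast follows, to our understanding, what Hugging Face's LlamaRMSNorm does; the function name is hypothetical. In eager mode each step becomes its own kernel and its own HBM round-trip on the GPU.

```python
import torch


def rmsnorm_unfused(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Unfused RMSNorm: each step materializes a new tensor (a separate kernel on GPU)."""
    x32 = x.to(torch.float32)                  # upcast for a stable reduction
    sq = x32.pow(2)                            # 1. square every element
    mean = sq.mean(-1, keepdim=True)           # 2. reduce over the hidden dimension
    rnorm = (mean + eps).rsqrt()               # 3. add eps, inverse square root
    normed = (x32 * rnorm).to(x.dtype)         # 4. re-read x, normalize
    return normed * weight                     # 5. scale by the learned weight


if __name__ == "__main__":
    x = torch.randn(2, 4096, dtype=torch.bfloat16)
    w = torch.ones(4096, dtype=torch.bfloat16)
    y = rmsnorm_unfused(x, w)
    # With unit weights, every output row has a root-mean-square close to 1.
    print(y.float().pow(2).mean(-1).sqrt())
```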
The DGX Station GB300's HBM3e provides ~8 TB/s of bandwidth. PyTorch's unfused RMSNorm typically achieves only ~11% of this peak, while a fused kernel can reach ~80-90%. That is a dramatic improvement for an operation that runs 65 times in every forward pass, with a matching backward call for each.
Cross-entropy is memory-capacity-bound.
Standard cross-entropy computes softmax(logits) over the full vocabulary for every token position. For LLaMA 3.1 8B:
logit tensor shape: [batch_size * seq_len, 128256]
For batch_size=1, seq_len=512: [512, 128256]
Memory: 512 * 128256 * 4 bytes (float32) ≈ 250 MB
PyTorch also saves the softmax output for the backward pass, roughly doubling this to ~500 MB. As batch size or sequence length grows, this scales linearly.
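The same arithmetic as a small helper is handy for checking how the logit tensor grows with batch size and sequence length. It assumes FP32 logits as above and reports binary megabytes (MiB); the ≈250 MB figure above is exactly 250.5 MiB. The helper name is ours, not the playbook's.

```python
VOCAB_SIZE = 128_256  # LLaMA 3.1 vocabulary


def logits_mib(batch_size: int, seq_len: int, vocab_size: int = VOCAB_SIZE,
               bytes_per_elem: int = 4) -> float:
    """Size of a [batch_size * seq_len, vocab_size] logit tensor in MiB."""
    return batch_size * seq_len * vocab_size * bytes_per_elem / 2**20


if __name__ == "__main__":
    for batch_size, seq_len in [(1, 512), (4, 512), (4, 2048)]:
        size = logits_mib(batch_size, seq_len)
        # PyTorch also keeps the softmax output for backward: roughly 2x.
        print(f"B={batch_size}, T={seq_len}: logits {size:,.1f} MiB, "
              f"with saved softmax ~{2 * size:,.1f} MiB")
```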
The online softmax trick (Milakov & Gimelshein, 2018) avoids materializing the full softmax tensor. Instead of computing softmax over the whole vocabulary at once, it processes the vocabulary in chunks while maintaining two running values:
- m: the running maximum logit (for numerical stability)
- d: the running sum of exp(logit - m) (the softmax denominator)

Here's the algorithm with a small example. Suppose we have 8 logits [2, 5, 1, 3, 4, 7, 2, 6] and process them in chunks of 4:

Chunk 1: [2, 5, 1, 3]
- m = 5 (max of chunk)
- d = exp(2-5) + exp(5-5) + exp(1-5) + exp(3-5) = 0.050 + 1.0 + 0.018 + 0.135 = 1.203

Chunk 2: [4, 7, 2, 6]
- chunk_max = 7, new_m = max(5, 7) = 7
- Rescale the old sum to the new maximum and add the new terms: d = 1.203 * exp(5 - 7) + exp(4-7) + exp(7-7) + exp(2-7) + exp(6-7)
- d = 1.203 * 0.135 + 0.050 + 1.0 + 0.007 + 0.368 = 1.587

After all chunks: loss = log(d) + m - logit[target]. No [8]-sized softmax tensor was ever allocated; only two scalars (m, d) were carried across chunks.
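The same recurrence in plain Python reproduces the worked example. This is a scalar sketch of the online-softmax update, not the Triton kernel itself; the function names are illustrative.

```python
import math


def online_logsumexp(logits, chunk_size):
    """Return (m, d) with log(sum(exp(logits))) == log(d) + m, one chunk at a time."""
    m, d = float("-inf"), 0.0
    for start in range(0, len(logits), chunk_size):
        chunk = logits[start:start + chunk_size]
        new_m = max(m, max(chunk))
        # Rescale the previous sum to the new maximum, then add this chunk's terms.
        d = d * math.exp(m - new_m) + sum(math.exp(z - new_m) for z in chunk)
        m = new_m
    return m, d


def cross_entropy_online(logits, target, chunk_size=4):
    """Cross-entropy loss for one row: log(d) + m - logit[target]."""
    m, d = online_logsumexp(logits, chunk_size)
    return math.log(d) + m - logits[target]


if __name__ == "__main__":
    logits = [2, 5, 1, 3, 4, 7, 2, 6]
    m, d = online_logsumexp(logits, chunk_size=4)
    print(m, round(d, 3))  # 7 1.587
    direct = math.log(sum(math.exp(z) for z in logits)) - logits[5]
    print(cross_entropy_online(logits, target=5), direct)
```

On the first chunk, m is -inf, so math.exp(m - new_m) evaluates to 0.0 and the (empty) previous sum contributes nothing.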
For V=128,256, this reduces the extra memory the loss computation needs from O(B*T*V) to O(B*T): a few scalars per row instead of a full softmax row. In practice, the input logit tensor is still allocated (PyTorch needs it for the backward pass), so the measured end-to-end reduction is ~6x. That is still a significant saving that frees hundreds of megabytes at realistic batch sizes.
Let's start with the simpler kernel. Open rmsnorm_kernel.py to review the implementation:
cat rmsnorm_kernel.py
This file contains four components:
- _rmsnorm_fwd_kernel: the forward-pass Triton kernel
- _rmsnorm_bwd_kernel: the backward-pass Triton kernel
- TritonRMSNormFunction: a torch.autograd.Function that connects the kernels to PyTorch's autograd
- TritonRMSNorm: a drop-in nn.Module replacement for LlamaRMSNorm

Key Triton concepts in the forward kernel:
@triton.jit
def _rmsnorm_fwd_kernel(X_ptr, W_ptr, Y_ptr, Rnorm_ptr, stride_x, hidden_size, eps, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
...
@triton.jit marks a function for GPU compilation. This is Triton's equivalent of CUDA's __global__ keyword, but instead of writing C++, you write Python-like code. Triton's compiler handles thread management, memory coalescing, and vectorization automatically.
tl.program_id(0) returns a unique index for each "program" (similar to a CUDA thread block). Each program handles one row of the input tensor. For a batch of 512 tokens with hidden_size=4096, we launch 512 programs.
tl.load(X_ptr + row_start + offsets, mask=mask, other=0.0) loads a vector of values from GPU memory into registers. The mask ensures we don't read beyond the row boundary. The other=0.0 provides a default value for masked-out elements.
BLOCK_SIZE: tl.constexpr is a compile-time constant. Triton generates specialized GPU code for each value of BLOCK_SIZE, and because tl.arange requires a power-of-2 length, BLOCK_SIZE must be a power of 2. For hidden_size=4096, we use BLOCK_SIZE=4096 (the smallest power of 2 that is ≥ hidden_size), so each program loads the entire row in a single load.
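Picking BLOCK_SIZE is simple integer arithmetic. The sketch below computes the smallest power of 2 ≥ n (the same rule as Triton's triton.next_power_of_2 helper) and shows why a single block is a poor fit for the 128,256-entry vocabulary that the cross-entropy kernel has to cover. The function names are illustrative.

```python
import math


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n, for n >= 1."""
    return 1 << (n - 1).bit_length()


def num_chunks(length: int, block_size: int) -> int:
    """Loop iterations needed to cover `length` elements in blocks of `block_size`."""
    return math.ceil(length / block_size)


if __name__ == "__main__":
    print(next_power_of_2(4096))      # 4096: one program covers a full hidden-size row
    print(next_power_of_2(128_256))   # 131072: far too large for a single block
    print(num_chunks(128_256, 4096))  # 32 iterations, the last one partially masked
```

Elements between the real row length and BLOCK_SIZE are exactly the lanes that the mask in tl.load and tl.store switches off.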
The key optimization:
# One pass: read x, compute variance, normalize, multiply by weight, write y
x_fp32 = x.to(tl.float32)
variance = tl.sum(x_fp32 * x_fp32, axis=0) / hidden_size
rnorm = 1.0 / tl.sqrt(variance + eps)
y = (x_fp32 * rnorm).to(x.dtype) * w
tl.store(Y_ptr + row_start + offsets, y, mask=mask)
The entire RMSNorm computation — square, mean, rsqrt, normalize, scale by weight — happens in registers without intermediate writes to GPU memory. Compare this to PyTorch's 5 separate kernel launches, each with a full memory round-trip.
The backward kernel follows the same pattern: load everything needed for one row, compute both grad_x and grad_w in registers, write once. The mathematical derivation is documented in the kernel source comments.
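For reference, here is a standard derivation of the gradients the backward kernel has to produce, written for a single row of length H with upstream gradient g = ∂L/∂y. It is our own derivation rather than a copy of the kernel's comments, and it ignores the intermediate BF16 cast.

```latex
r = \Big(\frac{1}{H}\sum_{j=1}^{H} x_j^2 + \varepsilon\Big)^{-1/2},
\qquad y_i = w_i\, x_i\, r

\frac{\partial r}{\partial x_k} = -\frac{r^3\, x_k}{H}

\frac{\partial L}{\partial x_k}
  = \sum_{i=1}^{H} g_i w_i \Big(\delta_{ik}\, r + x_i \frac{\partial r}{\partial x_k}\Big)
  = r\, g_k w_k - \frac{r^3\, x_k}{H} \sum_{i=1}^{H} g_i w_i x_i

\frac{\partial L}{\partial w_i} = \sum_{\text{rows}} g_i\, x_i\, r
```

The input gradient needs only x, w, g, and the saved per-row r, so it can be computed one row at a time; the weight gradient is a sum over every row and therefore needs a reduction across programs.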
The autograd wrapper (TritonRMSNormFunction) connects the kernels to PyTorch's automatic differentiation:
- forward() calls the forward kernel and saves x, weight, and rnorm for later.
- backward() receives the upstream gradient, calls the backward kernel, and returns gradients for x and weight.

NOTE
Triton vs. CUDA C++: In the Custom CUDA Kernel Development playbook, we wrote CUDA C++ with explicit thread indexing (blockIdx.x, threadIdx.x), manual float4 vectorization, nvcc compilation, and ctypes bindings. Triton abstracts all of that: you write Python-like code, and the compiler handles vectorization, memory coalescing, and PTX generation automatically. The tradeoff is less fine-grained hardware control, but for simple memory-bound operations like RMSNorm, Triton can match hand-tuned CUDA performance.
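The forward/backward contract of the autograd wrapper can be prototyped in pure PyTorch before any Triton is written. The class below is a hypothetical stand-in for TritonRMSNormFunction: forward saves x, weight, and rnorm, backward applies the analytic RMSNorm gradients, and the whole thing is checked with torch.autograd.gradcheck in float64. In the real wrapper, the two method bodies would launch the Triton kernels instead.

```python
import torch


class RMSNormReferenceFunction(torch.autograd.Function):
    """Hypothetical pure-PyTorch stand-in for TritonRMSNormFunction."""

    @staticmethod
    def forward(ctx, x, weight, eps):
        rnorm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)  # shape [..., 1]
        ctx.save_for_backward(x, weight, rnorm)
        return x * rnorm * weight

    @staticmethod
    def backward(ctx, grad_y):
        x, weight, rnorm = ctx.saved_tensors
        hidden = x.shape[-1]
        gw = grad_y * weight
        # dL/dx = r * (g * w) - r^3 / H * x * sum(g * w * x)
        grad_x = rnorm * gw - (rnorm ** 3 / hidden) * x * (gw * x).sum(-1, keepdim=True)
        # dL/dw sums g * x * r over every token position
        grad_w = (grad_y * x * rnorm).reshape(-1, hidden).sum(0)
        return grad_x, grad_w, None  # eps gets no gradient


if __name__ == "__main__":
    x = torch.randn(4, 16, dtype=torch.float64, requires_grad=True)
    w = torch.randn(16, dtype=torch.float64, requires_grad=True)
    ok = torch.autograd.gradcheck(
        lambda a, b: RMSNormReferenceFunction.apply(a, b, 1e-6), (x, w))
    print("gradcheck passed:", ok)
```

gradcheck compares the hand-written backward against finite differences, which makes it a cheap way to validate the math before porting it to a kernel.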
Before measuring performance, verify the kernel produces the same results as PyTorch's implementation. Even small numerical errors can cascade through a 32-layer transformer and produce garbage gradients.
Run the correctness tests:
python rmsnorm_test.py
Expected output:
RMSNorm Correctness Tests
============================================================
Test 1: Float32
FP32 Forward — max diff: 9.54e-07 PASSED
FP32 Backward (dx) — max diff: 1.43e-06 PASSED
FP32 Backward (dw) — max diff: 2.29e-05 PASSED
Test 2: BFloat16 (relaxed tolerance)
BF16 Forward — max diff: 1.56e-02 PASSED
BF16 Backward (dx) — max diff: 1.56e-02 PASSED
BF16 Backward (dw) — max diff: 5.00e-01 PASSED
============================================================
All RMSNorm correctness tests PASSED
The tests compare the custom kernel against PyTorch's reference LlamaRMSNorm at shapes matching LLaMA 3.1 8B (batch=4, seq_len=512, hidden_size=4096), testing both the forward output and the backward gradients for x and weight.
WARNING
BF16 has only 7 bits of mantissa (vs. 23 for FP32). Per-element differences of ~0.01-0.02 are normal for forward and grad_x. The weight gradient (dw) shows larger absolute differences (up to ~0.5) because it sums per-element contributions across all 2,048 token positions — different summation order between our FP32-accumulated kernel and PyTorch's autograd produces BF16 rounding differences that accumulate. The test uses relaxed tolerance for dw to account for this.
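A quick way to see what "only 7 mantissa bits" means in practice is to round values to BF16 by hand. This pure-Python sketch (round-to-nearest-even on the float32 bit pattern, NaN/Inf handling omitted) shows that adjacent BF16 values between 2 and 4 are 2^-6 = 0.015625 apart, so the ~1.56e-02 maximum differences in the BF16 results are consistent with a one-step rounding difference. The helper names are ours.

```python
import math
import struct


def to_bf16(x: float) -> float:
    """Round a float to the nearest bfloat16 value (round-to-nearest-even).

    Sketch only: NaN/Inf handling is omitted.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]   # float32 bit pattern
    bits += 0x7FFF + ((bits >> 16) & 1)                   # round to nearest, ties to even
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def bf16_spacing(x: float) -> float:
    """Gap between adjacent bfloat16 values near x (8 significant bits)."""
    return 2.0 ** (math.floor(math.log2(abs(x))) - 7)


if __name__ == "__main__":
    print(bf16_spacing(1.0), bf16_spacing(3.0))  # 0.0078125 0.015625
    print(to_bf16(3.01))                         # 3.015625
```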
The FP32 test uses tolerance atol=1e-4 and the BF16 test uses atol=1e-2 for per-element values, with a more relaxed threshold for the accumulated weight gradient. Both forward and backward must pass — many kernel bugs only manifest in the backward pass.
Now let's measure the performance improvement. Run the RMSNorm benchmark:
python benchmark_kernels.py --kernel rmsnorm
Example output:
======================================================================
RMSNorm Benchmark — Custom Triton vs. PyTorch Reference
======================================================================
GPU: NVIDIA GB300
Tokens Custom (us) PyTorch (us) Custom (GB/s) PyTorch (GB/s) Speedup
-------- ------------- -------------- --------------- ---------------- ---------
256 313.5 479.6 40 26 1.53x
1,024 313.6 495.8 161 102 1.58x
4,096 319.4 576.5 630 349 1.80x
16,384 298.9 2,041.7 2,694 394 6.83x
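How to read these results: Speedup is PyTorch time divided by custom-kernel time. At small token counts both versions take a few hundred microseconds per call, and the custom kernel is about 1.5x faster. As the token count grows, PyTorch's time rises sharply (2,041.7 us at 16,384 tokens) while the custom kernel stays near 300 us, giving a 6.83x speedup and roughly 2.7 TB/s of effective bandwidth versus about 0.4 TB/s for PyTorch.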
Now re-profile the full training step with the custom RMSNorm to confirm the bottleneck is resolved:
python profile_baseline.py --use-custom-rmsnorm
Compare the profiler output to Step 3. The aten::pow, aten::mean, and aten::rsqrt calls from RMSNorm should be gone, replaced by fewer, faster Triton kernel calls. The remaining top operations should be matrix multiplications and FlashAttention — operations already handled by highly optimized libraries.
Now for the more complex kernel. Open cross_entropy_kernel.py:
cat cross_entropy_kernel.py
This implements the online softmax algorithm from Step 4 as a Triton kernel. The structure mirrors the RMSNorm kernel (forward kernel, backward kernel, autograd function, nn.Module), but the forward kernel is more complex because it loops over the vocabulary in chunks.
The forward kernel, annotated:
@triton.jit
def _cross_entropy_fwd_kernel(Logits_ptr, Targets_ptr, Losses_ptr, Max_ptr, Denom_ptr,
vocab_size, stride_logits, BLOCK_SIZE: tl.constexpr):
row_idx = tl.program_id(0)
...
m = float("-inf") # Running maximum logit
d = 0.0 # Running sum of exp(logit_i - m)
target_logit = 0.0 # Logit at the target index
for start in range(0, vocab_size, BLOCK_SIZE):
...
Key differences from the RMSNorm kernel:
Loop over vocabulary chunks. The RMSNorm kernel loads the entire row at once (4,096 elements fit in registers). The cross-entropy kernel can't do that, because 128,256 vocabulary entries are too many. Instead, it processes BLOCK_SIZE elements at a time (e.g., 4,096 per iteration, 32 iterations total, with the last iteration partially masked).
Running state across iterations. The kernel maintains m (running max) and d (running sum-of-exp) across loop iterations. The update rule handles the rescaling when a new maximum is found:
new_m = tl.maximum(m, chunk_max)
d = d * tl.exp(m - new_m) + tl.sum(tl.exp(logits_chunk - new_m), axis=0)
m = new_m
The d * tl.exp(m - new_m) term rescales the previous sum to account for a potentially larger maximum. This is the core of the online softmax algorithm.
No intermediate tensor allocation. The standard approach would allocate a [num_tokens, 128256] tensor for the softmax output. This kernel only stores three scalars per row (loss, m, d) plus the target logit.
The backward kernel also loops over the vocabulary in chunks. For each chunk, it computes softmax(logit) = exp(logit - m) / d using the saved m and d values, subtracts 1 at the target position, and writes the gradient. Like the forward kernel, it never materializes the full softmax vector.
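The forward and backward recurrences can be cross-checked against PyTorch on the CPU before touching Triton. The function below is a hypothetical vectorized sketch, not the playbook's kernel: it runs the online-softmax loop over vocabulary chunks for all rows at once, then recomputes softmax chunk by chunk from the saved m and d to form the gradient of the summed loss (softmax minus one-hot).

```python
import torch
import torch.nn.functional as F


def chunked_cross_entropy(logits: torch.Tensor, targets: torch.Tensor, block: int = 4096):
    """Per-row CE loss and d(loss.sum())/d(logits) via online softmax (sketch)."""
    n, vocab = logits.shape
    x = logits.float()
    m = torch.full((n,), float("-inf"), device=x.device)  # running max per row
    d = torch.zeros(n, device=x.device)                   # running sum of exp(x - m)
    for start in range(0, vocab, block):
        chunk = x[:, start:start + block]
        new_m = torch.maximum(m, chunk.max(dim=1).values)
        d = d * torch.exp(m - new_m) + torch.exp(chunk - new_m[:, None]).sum(dim=1)
        m = new_m
    target_logit = x.gather(1, targets[:, None]).squeeze(1)
    loss = torch.log(d) + m - target_logit

    # Backward: softmax is recomputed per chunk from the saved m and d, so no
    # softmax tensor has to be kept between the forward and backward passes.
    grad = torch.empty_like(x)
    for start in range(0, vocab, block):
        grad[:, start:start + block] = torch.exp(x[:, start:start + block] - m[:, None]) / d[:, None]
    grad[torch.arange(n, device=x.device), targets] -= 1.0  # subtract 1 at the target
    return loss, grad


if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 128_256, requires_grad=True)
    targets = torch.randint(0, 128_256, (8,))
    loss, grad = chunked_cross_entropy(logits.detach(), targets)
    F.cross_entropy(logits, targets, reduction="sum").backward()
    print((grad - logits.grad).abs().max())
```

For a mean-reduced loss, the gradient would additionally be scaled by the upstream gradient divided by the number of tokens.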
TIP
This kernel is inspired by the Liger-Kernel project from LinkedIn. Liger-Kernel also implements a more advanced variant called Fused Linear Cross-Entropy that fuses the final linear projection (hidden_states @ lm_head_weight) with the cross-entropy loss, computing logits chunk-by-chunk and never materializing them at all. This is even more memory-efficient but significantly more complex (it requires tiled matrix multiplication within the kernel). See the Next Steps section for pointers.
Run the correctness and memory tests:
python cross_entropy_test.py
Expected output:
Cross-Entropy Correctness Tests
============================================================
Test 1: Float32
FP32 Loss — ref: 12.331120 custom: 12.331120 diff: 0.00e+00 PASSED
FP32 Gradient — max diff: 9.09e-13 PASSED
Test 2: BFloat16 (relaxed tolerance)
BF16 Loss — ref: 12.250000 custom: 12.247243 diff: 2.76e-03 PASSED
BF16 Gradient — max diff: 2.98e-08 PASSED
Memory Comparison
------------------------------------------------------------
Standard PyTorch CE — peak memory: 504.0 MB
Fused Triton CE — peak memory: 252.0 MB
Memory reduction: 2.0x
============================================================
All cross-entropy tests PASSED
The memory comparison shows that standard PyTorch cross-entropy allocates ~500 MB (for the softmax output and other intermediates), while the fused kernel uses ~250 MB. The 2x reduction measured here understates the real benefit: in the benchmark (Step 10), where memory is measured more precisely per-operation, the reduction is ~6x. The larger benefit appears because the benchmark isolates just the cross-entropy overhead, while this test includes the base logit tensor allocation in both measurements.
NOTE
Cross-entropy involves log(sum(exp(...))), which is numerically sensitive. The online softmax algorithm maintains stability through the running-max trick — subtracting the maximum logit before exponentiating prevents overflow. The loss values should match PyTorch within 1e-5 in FP32 or 1e-2 in BF16.
Run the cross-entropy benchmark:
python benchmark_kernels.py --kernel cross_entropy
Example output:
======================================================================
Cross-Entropy Benchmark — Custom Triton (online softmax) vs. PyTorch
======================================================================
GPU: NVIDIA GB300
Vocabulary size: 128,256 (LLaMA 3.1)
Tokens Custom (us) PyTorch (us) Speedup Custom Mem (MB) PyTorch Mem (MB) Mem Reduction
-------- ------------- -------------- --------- ----------------- ------------------ ---------------
128 311 220 0.71x 32 188 5.9x
256 300 338 1.12x 63 378 6.0x
512 306 676 2.21x 126 752 6.0x
1,024 315 1,277 4.06x 251 1,506 6.0x
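How to read these results: At 128 tokens the custom kernel is slower than PyTorch (0.71x), but from 256 tokens upward it is faster, reaching 4.06x at 1,024 tokens. Memory use is about 6x lower at every size, because the custom kernel never allocates the [tokens, 128256] softmax intermediates.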
Now re-profile with both custom kernels active:
python profile_baseline.py --use-custom-rmsnorm --use-custom-ce
The profiler output should now show matrix multiplications and FlashAttention as the dominant operations. The RMSNorm and cross-entropy bottlenecks from Step 3 have been eliminated. The remaining operations are already handled by cuBLAS and FlashAttention — the most highly optimized GPU libraries available.
Let's put it all together: run a real fine-tuning loop and measure the end-to-end impact.
First, run the baseline (vanilla PyTorch):
python finetune_baseline.py
Then run the optimized version with both custom kernels:
python finetune_optimized.py
Example comparison:
======================================================================
Baseline Results
======================================================================
Average time per step: 1.842 s
Average throughput: 278 tokens/sec
Peak GPU memory: 112.4 GB
======================================================================
Optimized Results
======================================================================
Average time per step: 1.614 s
Average throughput: 317 tokens/sec
Peak GPU memory: 78.6 GB
How the custom kernels are integrated:
The finetune_optimized.py script uses the "surgical replacement" pattern to swap in custom kernels without modifying the model source code:
# Walk the model tree and collect every LlamaRMSNorm for replacement.
# We collect first, then apply — modifying the tree during iteration is unsafe.
replacements = []
for name, module in model.named_modules():
if type(module).__name__ == "LlamaRMSNorm":
parts = name.split(".")
parent = model.get_submodule(".".join(parts[:-1])) if len(parts) > 1 else model
replacements.append((parent, parts[-1], module))
for parent, attr_name, old_module in replacements:
setattr(parent, attr_name, TritonRMSNorm.from_llama_rmsnorm(old_module))
# Use custom cross-entropy instead of the model's built-in loss
outputs = model(input_ids=input_ids) # Forward without computing loss
logits = outputs.logits[:, :-1, :].contiguous()
loss = custom_ce(logits, labels[:, 1:].contiguous())
This pattern — find modules by type, create optimized replacements, swap them in — is widely used in production inference and training optimization.
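A quick way to trust a surgical replacement is to run it on a toy model and confirm that the outputs don't change. The sketch below uses hypothetical ToyRMSNorm and FastRMSNorm classes in place of LlamaRMSNorm and TritonRMSNorm, plus a generic replace_modules helper that follows the same collect-then-swap pattern. In this sketch the replacement shares the original weight Parameter instead of copying it, so no trained values are lost and no memory is duplicated.

```python
import torch
from torch import nn


class ToyRMSNorm(nn.Module):
    """Hypothetical stand-in for LlamaRMSNorm."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight


class FastRMSNorm(ToyRMSNorm):
    """Hypothetical optimized replacement (a real one would call a Triton kernel)."""

    @classmethod
    def from_module(cls, old: ToyRMSNorm) -> "FastRMSNorm":
        new = cls(old.weight.shape[0], old.eps)
        new.weight = old.weight  # share the existing Parameter instead of copying it
        return new


def replace_modules(model: nn.Module, type_name: str, factory) -> int:
    """Swap every submodule whose class name is `type_name`; return the count."""
    targets = []
    for name, module in model.named_modules():
        if name and type(module).__name__ == type_name:  # skip the root module
            parent_name, _, attr = name.rpartition(".")
            targets.append((model.get_submodule(parent_name), attr, module))
    for parent, attr, old in targets:  # mutate only after iteration has finished
        setattr(parent, attr, factory(old))
    return len(targets)


if __name__ == "__main__":
    model = nn.Sequential(nn.Linear(16, 16), ToyRMSNorm(16),
                          nn.Sequential(nn.Linear(16, 16), ToyRMSNorm(16)))
    x = torch.randn(4, 16)
    before = model(x)
    count = replace_modules(model, "ToyRMSNorm", FastRMSNorm.from_module)
    print(count, torch.allclose(before, model(x)))  # 2 True
```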
NOTE
Amdahl's law in action. An 8x faster RMSNorm does not make training 8x faster. If RMSNorm accounts for 10% of total step time, making it 8x faster saves 10% x (1 - 1/8) = 8.75% of total time. The cross-entropy memory reduction has an outsized impact because it frees GPU memory for larger batch sizes, which improves GPU utilization across all operations, including the matrix multiplications and attention that dominate the compute profile.
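The arithmetic in the note is Amdahl's law: if a fraction f of the step is sped up by a factor s, the new step time is (1 - f) + f/s of the old one. A tiny helper (the name is ours) makes it easy to plug in fractions read from your own profiler output.

```python
def amdahl(fraction: float, speedup: float) -> tuple:
    """Return (overall speedup, fraction of step time saved)."""
    new_time = (1.0 - fraction) + fraction / speedup
    return 1.0 / new_time, 1.0 - new_time


if __name__ == "__main__":
    overall, saved = amdahl(0.10, 8.0)
    print(f"overall {overall:.3f}x, saved {saved:.2%}")  # overall 1.096x, saved 8.75%
```

For comparison, the example fine-tuning run above went from 1.842 s to 1.614 s per step, an overall speedup of about 1.14x.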
When you're finished, exit the container:
exit
Since we used --rm, the container is automatically removed. Your source code and profiler traces are preserved in the assets/ directory on your host machine. Model weights are cached in ~/.cache/huggingface/ on the host (via the volume mount).
To remove the container image:
WARNING
This deletes the built Docker image. You'll need to rebuild it if you want to use it again.
docker rmi kernel-dev-ft
To remove downloaded model weights cached by Hugging Face:
rm -rf ~/.cache/huggingface/hub/models--meta-llama--Llama-3.1-8B
Next steps:
You've profiled a real training workload, identified the bottlenecks, written custom Triton kernels to address them, and measured end-to-end improvements. Here's where to go next:
- lm_head linear projection with the cross-entropy, computing logits chunk-by-chunk and never materializing the full [B*T, V] tensor at all. See Liger-Kernel's FusedLinearCrossEntropy for a production implementation.
- torch.autograd.Function pattern used in this playbook.
- pip install liger-kernel and apply_liger_kernel_to_llama(). Compare its throughput against your hand-written kernels.
- --batch-size 2 or --batch-size 4 to see how GPU utilization improves when more compute work is available per step.