Copyright © 2026 NVIDIA Corporation
Use torch.profiler to find training bottlenecks, then write custom Triton kernels to optimize LLaMA 8B fine-tuning
DGX Station puts a full Blackwell GPU on your desk, which makes it an ideal environment for profiling and optimizing GPU kernels used during model training. This playbook walks through a real optimization workflow: profiling a LLaMA 3.1 8B fine-tuning run to identify bottlenecks, then writing custom Triton kernels that eliminate those bottlenecks — specifically a fused RMSNorm and a fused cross-entropy loss using online softmax.
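For orientation, the math that an RMSNorm layer computes can be written in a few lines of plain Python. This is only a reference sketch of the formula, not the Triton kernel from the playbook; the function name `rmsnorm` and the default `eps` value are illustrative assumptions.

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    """Reference RMSNorm for a single row (hypothetical helper, not the playbook kernel).

    y_i = x_i / sqrt(mean(x^2) + eps) * weight_i
    """
    # Mean of squares across the hidden dimension.
    mean_sq = sum(v * v for v in x) / len(x)
    # Reciprocal root-mean-square, with eps for numerical stability.
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    # Normalize, then apply the learned per-channel scale.
    return [v * inv_rms * w for v, w in zip(x, weight)]

row = [3.0, 4.0]
out = rmsnorm(row, [1.0, 1.0], eps=0.0)
# With unit weights and eps=0, the output has a mean square of 1.
print(sum(v * v for v in out) / len(out))
```

An unfused implementation typically runs the square, mean, rsqrt, and multiply steps as separate kernels, each reading and writing the activation tensor; a fused kernel performs them in one pass over the row.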
For inference workloads, tools like torch.compile and serving frameworks (vLLM, TensorRT-LLM) already ship highly optimized fused kernels. But training workloads are different. Backward passes double the kernel count, large vocabularies create massive intermediate tensors during loss computation, and torch.compile does not restructure algorithms to avoid these allocations. Projects like Liger-Kernel and Unsloth demonstrate that custom training kernels deliver real results: 20-60% memory reduction and 10-30% throughput improvement.
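To make the "massive intermediate tensors" point concrete, here is a rough back-of-the-envelope calculation of the size of a fully materialized logit tensor. The vocabulary size is that of the LLaMA 3.1 tokenizer; the batch size and sequence length are assumptions chosen purely for illustration, and `logits_bytes` is a hypothetical helper.

```python
def logits_bytes(num_tokens, vocab_size, bytes_per_element):
    """Size in bytes of a dense [num_tokens, vocab_size] logit tensor."""
    return num_tokens * vocab_size * bytes_per_element

VOCAB_SIZE = 128_256       # LLaMA 3.1 vocabulary size
num_tokens = 4 * 4096      # assumed: batch size 4, sequence length 4096

bf16 = logits_bytes(num_tokens, VOCAB_SIZE, 2)  # BF16 logits
fp32 = logits_bytes(num_tokens, VOCAB_SIZE, 4)  # FP32 upcast for a stable softmax

print(f"BF16 logits: {bf16 / 2**30:.2f} GiB")   # 3.91 GiB
print(f"FP32 logits: {fp32 / 2**30:.2f} GiB")   # 7.83 GiB
```

Under these assumptions the logits alone take several GiB, before counting the softmax output or the gradient with respect to the logits, which are the same shape.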
This playbook uses Triton instead of raw CUDA C++. Triton is a Python-native GPU programming language that JIT-compiles to optimized GPU code, with no nvcc compiler, no C++ build systems, and no manual thread indexing. It is widely used for custom training kernels: Liger-Kernel and Unsloth are written in Triton, and Triton implementations of FlashAttention exist alongside the original CUDA version.
No prior Triton, CUDA, or GPU programming experience is required. The instructions explain each concept as it comes up.
You will profile a LLaMA 3.1 8B fine-tuning workload, identify the key performance bottlenecks, and write custom Triton kernels that address them.
Use torch.profiler and interpret the results to identify two targets: RMSNorm (memory-bandwidth-bound) and cross-entropy loss (memory-capacity-bound).
Hardware:
Software:
docker run --rm --gpus all nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu24.04 nvidia-smi
All required assets are in the playbook directory nvidia/station-kernel-dev-ft/assets (see the dgx-spark-playbooks repository).
assets/Dockerfile: Development container based on NVIDIA's PyTorch NGC image with Triton, transformers, and profiling dependencies.
assets/requirements.txt: Python dependencies installed inside the container.
assets/profile_baseline.py: Profiling script that captures a torch.profiler trace of a LLaMA 3.1 8B training step and prints a breakdown of GPU time by operation. Supports flags to enable custom kernels for re-profiling.
assets/rmsnorm_kernel.py: Fused RMSNorm Triton kernel with forward and backward passes, wrapped as a drop-in torch.nn.Module replacement. Heavily commented with explanations of each Triton concept.
assets/rmsnorm_test.py: Correctness tests comparing the custom RMSNorm against PyTorch's reference implementation (forward and backward, FP32 and BF16).
assets/cross_entropy_kernel.py: Fused cross-entropy Triton kernel using online softmax, with forward and backward passes. Processes the vocabulary in chunks to avoid materializing the full logit tensor.
assets/cross_entropy_test.py: Correctness tests and memory usage comparison against torch.nn.CrossEntropyLoss.
assets/benchmark_kernels.py: Benchmarking script that measures latency, throughput, bandwidth utilization, and peak memory for both custom kernels.
assets/finetune_baseline.py: Minimal LLaMA 3.1 8B fine-tuning script using vanilla PyTorch, reporting tokens/sec and peak memory.
assets/finetune_optimized.py: Identical fine-tuning script with both custom kernels monkey-patched in for direct comparison.
assets/ directory; everything else is discarded.
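The chunked, online-softmax idea behind the cross-entropy kernel can be illustrated in plain Python. This is a minimal sketch of the streaming reduction only, not the contents of cross_entropy_kernel.py: the function names and `chunk_size` are hypothetical, and for simplicity it reads from an already-built list of logits, whereas the point of a fused kernel is to avoid holding the full tensor.

```python
import math

def cross_entropy_online(logits, target, chunk_size=4):
    """Cross-entropy for one token, scanning the vocabulary in chunks.

    Keeps only a running max and a running sum of exponentials, so the
    full softmax vector is never stored.
    """
    running_max = float("-inf")
    running_sum = 0.0
    for start in range(0, len(logits), chunk_size):
        chunk = logits[start:start + chunk_size]
        new_max = max(running_max, max(chunk))
        # Rescale the previous sum to the new max, then add this chunk.
        running_sum = running_sum * math.exp(running_max - new_max)
        running_sum += sum(math.exp(v - new_max) for v in chunk)
        running_max = new_max
    log_sum_exp = running_max + math.log(running_sum)
    # loss = -log softmax(logits)[target]
    return log_sum_exp - logits[target]

def cross_entropy_reference(logits, target):
    """Single-pass reference over the whole vocabulary, for comparison."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits)) - logits[target]

logits = [2.0, -1.0, 0.5, 3.0, 10.0, -4.0, 1.5, 0.0, 7.0]
print(cross_entropy_online(logits, target=3))
print(cross_entropy_reference(logits, target=3))
```

Both functions should agree to within floating-point tolerance for any chunk size, which is the same kind of check a correctness test against a reference loss performs.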