Pretrain Llama 3.1 8B with NVFP4 mixed precision on DGX Station using Megatron Bridge
NVFP4 is a 4-bit floating-point format natively supported by NVIDIA Blackwell Tensor Cores. When applied during pretraining, NVFP4 reduces memory bandwidth and compute cost for matrix multiplications while preserving model quality through mixed-precision accumulation in higher precision (BF16/FP32).
Megatron-Bridge is NVIDIA's library for large-scale distributed training built on top of Megatron-Core.
It provides composable recipe configs for models, optimizers, and mixed-precision strategies — including the first-class bf16_with_nvfp4_mixed recipe used in this playbook.
Combining the two lets you pretrain LLMs at lower memory cost and higher throughput compared to BF16-only training, with minimal accuracy trade-off.
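To make the "4-bit floating-point" idea concrete, the sketch below simulates block-scaled quantization in plain Python. It assumes the commonly described NVFP4 layout: E2M1 element values (magnitudes 0, 0.5, 1, 1.5, 2, 3, 4, 6) with one shared scale per 16-element block. As a simplification, the scale is kept in full precision here, whereas hardware stores it in a compact format. The function names are hypothetical and not part of Megatron-Bridge.

```python
import math

# Magnitudes representable by a 4-bit E2M1 value (sign is stored separately).
E2M1_MAGNITUDES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)
BLOCK_SIZE = 16  # assumption: elements are grouped into 16-value blocks sharing a scale


def quantize_block(block):
    """Quantize one block to the E2M1 grid and return (scale, dequantized values)."""
    amax = max(abs(x) for x in block)
    if amax == 0.0:
        return 1.0, [0.0] * len(block)
    # Choose the scale so the largest magnitude in the block maps onto 6.0.
    scale = amax / E2M1_MAGNITUDES[-1]
    dequantized = []
    for x in block:
        # Snap the scaled magnitude to the nearest representable grid point.
        mag = min(E2M1_MAGNITUDES, key=lambda m: abs(abs(x) / scale - m))
        dequantized.append(math.copysign(mag, x) * scale)
    return scale, dequantized


def fake_quantize(values):
    """Apply block-wise quantize/dequantize to a flat list of floats."""
    out = []
    for start in range(0, len(values), BLOCK_SIZE):
        _, deq = quantize_block(values[start:start + BLOCK_SIZE])
        out.extend(deq)
    return out


print(fake_quantize([6.0, 3.0, -1.0, 0.2]))  # [6.0, 3.0, -1.0, 0.0]
```

The small value 0.2 collapses to zero because the block's scale is set by its largest element. That is the kind of error that keeping selected layers and the accumulation in higher precision is meant to contain.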
Key benefits:
- bf16_mixed, bf16_with_fp8_current_scaling_mixed, and bf16_with_nvfp4_mixed with a single line
- first_last_layers_bf16)

Pretrain a Llama 3.1 8B model using Megatron-Bridge with NVFP4 mixed precision on NVIDIA DGX Station.
You'll run a short training loop with mock data to verify the full pipeline end to end, compare it against a plain BF16 baseline via the --disable-fp4 flag, and then learn how to point it at real data if required.
Run settings:
- llama3_8b_pretrain_config())
- MockGPTDataset: synthetic random token IDs, no real corpus)
- nvcr.io/nvidia/nemo:26.04 container
- nvidia-smi --query-compute-apps=used_memory sampled every 2 s during the run

| Precision | Recipe | Avg step time | Throughput (Model TFLOP/s/GPU) | Peak VRAM |
|---|---|---|---|---|
| BF16 baseline | bf16_mixed() | 9.05 s | ~1399 | 221.6 GB |
| NVFP4 (last-4 BF16) | bf16_with_nvfp4_mixed() + first_last_layers_bf16=True, num_layers_at_end_in_bf16=4 | 5.39 s | ~2347 | 207.8 GB |
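
The Peak VRAM column comes from sampling nvidia-smi every 2 s during the run. A minimal sketch of such a sampler follows. The helper names are hypothetical, and summing the per-process values is an assumption about how the peak was aggregated. With `--format=csv,noheader,nounits`, nvidia-smi prints one value per compute process in MiB.

```python
import subprocess
import time

# Same query as in the run settings, with machine-readable output.
QUERY = [
    "nvidia-smi",
    "--query-compute-apps=used_memory",
    "--format=csv,noheader,nounits",
]


def parse_used_mib(csv_text):
    """Sum per-process used_memory values (MiB), skipping non-numeric rows."""
    total = 0
    for line in csv_text.splitlines():
        value = line.strip()
        if value.isdigit():  # rows such as "[N/A]" or blank lines are ignored
            total += int(value)
    return total


def sample_peak_mib(duration_s, interval_s=2.0):
    """Poll nvidia-smi for duration_s seconds and return the highest total seen."""
    peak = 0
    deadline = time.monotonic() + duration_s
    while time.monotonic() < deadline:
        result = subprocess.run(QUERY, capture_output=True, text=True, check=False)
        peak = max(peak, parse_used_mib(result.stdout))
        time.sleep(interval_s)
    return peak
```

Run `sample_peak_mib(...)` in a second terminal while training is in progress, with a duration that covers the whole run. Divide the result by 1024 to get GiB.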
NVFP4 is 1.68× faster than BF16 (≈68% higher throughput) with ≈13.8 GB (≈6%) less peak VRAM — the regime NVFP4 was designed for, where matmul FLOPs dominate each step and quantization overhead is amortized over wide linear projections.
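
The headline ratios can be re-derived from the table. The short check below uses only the numbers reported above.

```python
# Values copied from the results table.
bf16_step_s, nvfp4_step_s = 9.05, 5.39
bf16_tflops, nvfp4_tflops = 1399, 2347
bf16_vram_gb, nvfp4_vram_gb = 221.6, 207.8

speedup = bf16_step_s / nvfp4_step_s            # step-time ratio
throughput_gain = nvfp4_tflops / bf16_tflops - 1  # relative TFLOP/s increase
vram_saved_gb = bf16_vram_gb - nvfp4_vram_gb
vram_saved_frac = vram_saved_gb / bf16_vram_gb

print(f"speedup:         {speedup:.2f}x")           # 1.68x
print(f"throughput gain: {throughput_gain:.0%}")    # 68%
print(f"VRAM saved:      {vram_saved_gb:.1f} GB ({vram_saved_frac:.1%})")  # 13.8 GB (6.2%)
```

The step-time ratio and the throughput ratio agree to within rounding, as expected when both runs process the same number of tokens per step.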
- torchrun)

Verify your setup:
# Check GPU availability and architecture
nvidia-smi
# Verify Python and torch
python3 -c "import torch; print(torch.cuda.get_device_name(0))"
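
For a slightly more informative check than the one-liner above, the sketch below also reports the GPU's compute capability. The `major >= 10` threshold for "Blackwell or newer" is an assumption based on published compute capabilities (Hopper is 9.0), and the helper names are hypothetical.

```python
def is_blackwell_or_newer(capability):
    """Return True if a (major, minor) compute capability is 10.x or higher.

    Assumption: Blackwell-generation GPUs report major version 10 or above.
    """
    major, _minor = capability
    return major >= 10


def describe_gpu():
    """Return a one-line summary of GPU 0, or the reason none is usable."""
    try:
        import torch
    except ImportError:
        return "PyTorch is not installed"
    if not torch.cuda.is_available():
        return "No CUDA device visible"
    name = torch.cuda.get_device_name(0)
    capability = torch.cuda.get_device_capability(0)
    verdict = "OK for NVFP4" if is_blackwell_or_newer(capability) else "pre-Blackwell"
    return f"{name} (sm_{capability[0]}{capability[1]}): {verdict}"


print(describe_gpu())
```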
- --train-iters 50); longer for real data
- eval_iters=0); real data requires a preprocessed Megatron-format dataset
- torchrun process and remove any checkpoint directories