Confirm your NVIDIA Spark system meets the requirements and has GPU access configured.
# Verify GPU access
nvidia-smi
# Verify ARM64 architecture (expected output: aarch64)
uname -m
# Check Docker GPU support
docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi
If you see a permission denied error (for example, "permission denied while trying to connect to the Docker daemon socket"), add your user to the docker group so that you don't need to run the command with sudo.
sudo usermod -aG docker $USER
newgrp docker
git clone https://github.com/NVIDIA/dgx-spark-playbooks
WARNING
The docker build command below downloads a base image and builds a container locally to support this environment.
cd dgx-spark-playbooks/nvidia/jax/assets
docker build -t jax-on-spark .
Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.
docker run --gpus all --rm -it \
  --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
  -p 8080:8080 \
  jax-on-spark
Connect to the marimo notebook server to begin the JAX tutorial.
# Access via web browser
# Navigate to: http://localhost:8080
The interface will load a table-of-contents display and brief introduction to marimo.
Work through the introductory material to understand how the JAX programming model differs from NumPy.
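As a rough preview of the kind of differences involved, the sketch below contrasts three common ones: immutable arrays, explicit random keys, and compiled pure functions. It is not taken from the notebooks; the variable and function names (such as scaled_sum) are illustrative assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np

# NumPy arrays are mutable and can be updated in place.
np_arr = np.zeros(3)
np_arr[0] = 1.0

# JAX arrays are immutable; .at[...].set(...) returns an updated copy.
jax_arr = jnp.zeros(3)
jax_updated = jax_arr.at[0].set(1.0)

# JAX has no global random state; randomness is driven by explicit keys.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
sample = jax.random.uniform(subkey, (3,))

# jax.jit traces a pure function and compiles it with XLA.
@jax.jit
def scaled_sum(x):  # hypothetical example function
    return jnp.sum(x * 2.0)

result = scaled_sum(jax_updated)
```

The original jax_arr is left unchanged by the update, which is the main habit to adjust when porting in-place NumPy code.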
Navigate to and complete the JAX introduction notebook, which covers:
Complete the NumPy-based self-organizing map (SOM) implementation to establish a performance baseline.
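For orientation, a minimal online SOM training loop in NumPy might look like the sketch below. This is an assumption about the general shape of such a baseline, not the notebook's implementation; the function name train_som_numpy, the linear decay schedules, and all default parameters are hypothetical.

```python
import numpy as np

def train_som_numpy(data, grid_h=8, grid_w=8, epochs=5,
                    lr0=0.5, sigma0=3.0, seed=0):
    """Online SOM training: one weight update per training sample."""
    rng = np.random.default_rng(seed)
    dim = data.shape[1]
    weights = rng.random((grid_h, grid_w, dim))
    # Grid coordinates of each node, shape (grid_h, grid_w, 2).
    coords = np.stack(
        np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij"),
        axis=-1,
    )
    total_steps = epochs * len(data)
    step = 0
    for _ in range(epochs):
        for x in rng.permutation(data):
            # Linearly decay the learning rate and neighborhood radius (assumed schedule).
            frac = step / total_steps
            lr = lr0 * (1.0 - frac)
            sigma = max(sigma0 * (1.0 - frac), 0.5)
            # Best matching unit: the node whose weight vector is closest to x.
            dists = np.sum((weights - x) ** 2, axis=-1)
            bmu = np.unravel_index(np.argmin(dists), dists.shape)
            # Gaussian neighborhood centered on the BMU in grid space.
            grid_d2 = np.sum((coords - np.array(bmu)) ** 2, axis=-1)
            h = np.exp(-grid_d2 / (2.0 * sigma ** 2))
            # Pull every node toward x, weighted by its neighborhood value.
            weights += lr * h[..., None] * (x - weights)
            step += 1
    return weights

# Train on random RGB colors, each channel in [0, 1).
colors = np.random.default_rng(1).random((200, 3))
som = train_som_numpy(colors)
```

The per-sample Python loop is what makes this form slow, and it is the part that a compiled implementation would typically target first.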
Work through the NumPy SOM notebook to:
Progress through the iteratively refined JAX implementations to see performance improvements.
Complete the JAX SOM notebook sections:
The notebooks show you how to measure the performance of each SOM training implementation. You'll see that the JAX implementations improve on the NumPy baseline, some of them by a wide margin.
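When timing JAX code yourself, two details matter: the first call to a jitted function includes compilation, and JAX dispatches work asynchronously, so the clock should stop only after the result is ready. The sketch below illustrates both on a pairwise-distance kernel. It is an assumed example rather than the notebooks' benchmarking code, and the helper names (time_call, pairwise_sq_dists_np, pairwise_sq_dists_jax) are hypothetical.

```python
import time

import jax
import jax.numpy as jnp
import numpy as np

def pairwise_sq_dists_np(x, w):
    """Squared Euclidean distance between every row of x and every row of w."""
    return np.sum((x[:, None, :] - w[None, :, :]) ** 2, axis=-1)

@jax.jit
def pairwise_sq_dists_jax(x, w):
    return jnp.sum((x[:, None, :] - w[None, :, :]) ** 2, axis=-1)

def time_call(fn, *args, repeats=10):
    """Return the best wall-clock time in seconds over `repeats` calls."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        out = fn(*args)
        # Wait for asynchronous JAX work to finish before stopping the clock.
        jax.block_until_ready(out)
        best = min(best, time.perf_counter() - start)
    return best

x_np = np.random.default_rng(0).random((512, 3)).astype(np.float32)
w_np = np.random.default_rng(1).random((256, 3)).astype(np.float32)
x_jax, w_jax = jnp.asarray(x_np), jnp.asarray(w_np)

# Warm-up call so compilation time is excluded from the measurement.
pairwise_sq_dists_jax(x_jax, w_jax).block_until_ready()

t_np = time_call(pairwise_sq_dists_np, x_np, w_np)
t_jax = time_call(pairwise_sq_dists_jax, x_jax, w_jax)
print(f"NumPy: {t_np * 1e3:.3f} ms, JAX: {t_jax * 1e3:.3f} ms")
```

On inputs this small the two timings may be close or even favor NumPy; the gap generally depends on problem size and on how much of the loop is compiled.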
Visually inspect the SOM training output on random color data to confirm algorithm correctness.
Apply JAX optimization techniques to your own NumPy-based machine learning code.
# Example: Profile your existing NumPy code
python -m cProfile your_numpy_script.py
# Then adapt to JAX and compare performance
Try adapting your favorite NumPy algorithms to JAX and measure performance improvements on the Blackwell GPU architecture.
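One common adaptation pattern is to write the per-item computation as a pure function and let jax.vmap batch it in place of a Python loop. The sketch below shows this for a nearest-center lookup; it is an illustrative assumption, not code from the playbook, and the function names are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

def nearest_index_loop(points, centers):
    """NumPy baseline: Python loop over points."""
    out = np.empty(len(points), dtype=np.int32)
    for i, p in enumerate(points):
        out[i] = np.argmin(np.sum((centers - p) ** 2, axis=-1))
    return out

def nearest_index_single(point, centers):
    """Index of the center closest to a single point."""
    return jnp.argmin(jnp.sum((centers - point) ** 2, axis=-1))

# vmap maps over axis 0 of the points and reuses the same centers for each;
# jit then compiles the batched function.
nearest_index_batched = jax.jit(
    jax.vmap(nearest_index_single, in_axes=(0, None))
)

rng = np.random.default_rng(0)
points = rng.random((100, 3)).astype(np.float32)
centers = rng.random((16, 3)).astype(np.float32)

loop_result = nearest_index_loop(points, centers)
jax_result = np.asarray(nearest_index_batched(points, centers))
```

Checking that the two results agree before timing anything helps confirm that the port preserved the algorithm's behavior.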