Optimized JAX

2 HRS

Optimize JAX to run on Spark

Verify system prerequisites

Confirm your NVIDIA Spark system meets the requirements and has GPU access configured.

# Verify GPU access
nvidia-smi

# Verify ARM64 architecture (expected output: aarch64)
uname -m

# Check Docker GPU support
docker run --gpus all --rm nvcr.io/nvidia/cuda:13.0.1-runtime-ubuntu24.04 nvidia-smi

If you see a permission denied error (for example, "permission denied while trying to connect to the Docker daemon socket"), add your user to the docker group so that you can run Docker commands without sudo.

sudo usermod -aG docker $USER
newgrp docker

Clone the playbook repository

git clone https://github.com/NVIDIA/dgx-spark-playbooks

Build the Docker image

WARNING

This command downloads a base image and builds a container image locally to support this environment.

# Run from the directory in the cloned repository that contains the jax/ folder
cd jax/assets
docker build -t jax-on-spark .

Launch Docker container

Run the JAX development environment in a Docker container with GPU support and port forwarding for marimo access.

docker run --gpus all --rm -it \
    --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 \
    -p 8080:8080 \
    jax-on-spark

Access the marimo interface

Connect to the marimo notebook server to begin the JAX tutorial.

# Access via web browser
# Navigate to: http://localhost:8080

The interface will load a table-of-contents display and brief introduction to marimo.
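Once the interface is up, a quick check in a notebook cell can confirm that JAX sees the GPU from inside the container. This is a minimal sketch rather than part of the playbook's notebooks; it assumes only that the container image has `jax` installed.

```python
import jax

# List the devices JAX can use. In a correctly configured container this
# should include a GPU device rather than only the CPU.
devices = jax.devices()
print("Default backend:", jax.default_backend())
for d in devices:
    print(d.platform, d.device_kind)
```

If only a CPU device is listed, recheck that the container was started with `--gpus all`.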

Complete the JAX introduction tutorial

Work through the introductory material to understand how the JAX programming model differs from NumPy.

Navigate to and complete the JAX introduction notebook, which covers:

  • JAX programming model fundamentals
  • Key differences from NumPy
  • Performance evaluation techniques
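As a preview of the kind of differences the introduction covers, the sketch below contrasts a few NumPy habits with their JAX counterparts. It is illustrative only and is not taken from the notebook; the function name `scaled_norm` is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np

# 1. JAX arrays are immutable: in-place assignment becomes .at[].set()
x_np = np.zeros(3)
x_np[0] = 1.0                      # NumPy mutates the array in place

x_jax = jnp.zeros(3)
y_jax = x_jax.at[0].set(1.0)       # returns a new array; x_jax is unchanged

# 2. Random numbers use explicit keys instead of hidden global state
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
samples = jax.random.uniform(subkey, (3,))

# 3. jax.jit compiles a pure function with XLA on its first call
@jax.jit
def scaled_norm(v, scale):
    # Hypothetical example function: scale times the Euclidean norm of v
    return scale * jnp.sqrt(jnp.sum(v ** 2))

result = scaled_norm(y_jax, 2.0)
```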

Implement NumPy baseline

Complete the NumPy-based self-organizing map (SOM) implementation to establish a performance baseline.

Work through the NumPy SOM notebook to:

  • Understand the SOM training algorithm
  • Implement the algorithm using familiar NumPy operations
  • Record performance metrics for comparison
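For orientation, a generic online SOM training loop in NumPy might look like the sketch below. It is not the notebook's implementation: the function name `train_som_numpy`, the grid size, the linear decay schedules, and the Gaussian neighborhood are all assumptions chosen for brevity.

```python
import numpy as np

def train_som_numpy(data, grid_h=8, grid_w=8, n_steps=500,
                    lr0=0.5, sigma0=3.0, seed=0):
    """Online SOM training; returns weights of shape (grid_h, grid_w, dim)."""
    rng = np.random.default_rng(seed)
    dim = data.shape[1]
    weights = rng.random((grid_h, grid_w, dim))
    # Grid coordinates of every node, shape (grid_h, grid_w, 2)
    rows, cols = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij")
    coords = np.stack([rows, cols], axis=-1)

    for t in range(n_steps):
        frac = t / n_steps
        lr = lr0 * (1.0 - frac)                  # linearly decaying learning rate
        sigma = sigma0 * (1.0 - frac) + 1e-3     # shrinking neighborhood radius
        x = data[rng.integers(len(data))]        # pick one random sample

        # Best matching unit (BMU): node whose weight vector is closest to x
        dists = np.sum((weights - x) ** 2, axis=-1)
        bmu = np.unravel_index(np.argmin(dists), dists.shape)

        # Gaussian neighborhood around the BMU, measured on the grid
        grid_d2 = np.sum((coords - np.array(bmu)) ** 2, axis=-1)
        h = np.exp(-grid_d2 / (2.0 * sigma ** 2))

        # Pull each node toward x, weighted by its neighborhood value
        weights += lr * h[..., None] * (x - weights)
    return weights

colors = np.random.default_rng(1).random((1000, 3))   # random RGB data
som = train_som_numpy(colors)
```

The in-place `weights += ...` update and the Python-level loop are the parts that change most when the algorithm is ported to JAX.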

Optimize with JAX implementations

Work through the successively refined JAX implementations to see how each one improves performance.

Complete the JAX SOM notebook sections:

  • Basic JAX port of NumPy implementation
  • Performance-optimized JAX version
  • GPU-accelerated parallel JAX implementation
  • Performance comparison across all versions
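To illustrate the general shape of such a port, the sketch below expresses a single SOM update as a pure jitted function and runs the training loop with `jax.lax.scan`. It is an assumption-laden example, not the notebook's code: the names `make_coords`, `som_step`, and `train_som_jax`, the decay schedules, and the choice of `scan` are all illustrative.

```python
import jax
import jax.numpy as jnp

def make_coords(grid_h, grid_w):
    # Grid coordinates of every node, shape (grid_h, grid_w, 2)
    rows, cols = jnp.meshgrid(jnp.arange(grid_h), jnp.arange(grid_w), indexing="ij")
    return jnp.stack([rows, cols], axis=-1).astype(jnp.float32)

@jax.jit
def som_step(weights, coords, x, lr, sigma):
    """One SOM update as a pure function: returns new weights, mutates nothing."""
    dists = jnp.sum((weights - x) ** 2, axis=-1)
    bmu = coords.reshape(-1, 2)[jnp.argmin(dists)]      # grid position of the BMU
    grid_d2 = jnp.sum((coords - bmu) ** 2, axis=-1)
    h = jnp.exp(-grid_d2 / (2.0 * sigma ** 2))          # Gaussian neighborhood
    return weights + lr * h[..., None] * (x - weights)

def train_som_jax(data, grid_h=8, grid_w=8, n_steps=500,
                  lr0=0.5, sigma0=3.0, seed=0):
    data = jnp.asarray(data)
    key_w, key_idx = jax.random.split(jax.random.PRNGKey(seed))
    weights = jax.random.uniform(key_w, (grid_h, grid_w, data.shape[1]))
    coords = make_coords(grid_h, grid_w)
    # Pre-draw the sample indices so the loop body stays pure
    idx = jax.random.randint(key_idx, (n_steps,), 0, data.shape[0])
    fracs = jnp.arange(n_steps) / n_steps

    def body(w, inputs):
        i, frac = inputs
        lr = lr0 * (1.0 - frac)                 # assumed linear decay
        sigma = sigma0 * (1.0 - frac) + 1e-3    # assumed shrinking radius
        return som_step(w, coords, data[i], lr, sigma), None

    # lax.scan keeps the whole training loop inside compiled code
    weights, _ = jax.lax.scan(body, weights, (idx, fracs))
    return weights

colors = jax.random.uniform(jax.random.PRNGKey(1), (1000, 3))
som = train_som_jax(colors)
```

Replacing the Python `for` loop with `scan` avoids per-step dispatch overhead, which is one common source of speedup when moving from a direct port to an optimized version.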

Validate performance gains

The notebooks show you how to measure the performance of each SOM training implementation. You'll see that the JAX implementations outperform the NumPy baseline, in some cases by a wide margin.
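When timing JAX code yourself, two details matter: the first call includes compilation, and JAX dispatches work asynchronously. The helper below is a minimal sketch of one way to account for both; `time_fn` and `pairwise_sq_dists` are hypothetical names, and the notebooks may measure performance differently.

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, repeats=5):
    """Return the best wall-clock time in seconds over `repeats` runs."""
    jax.block_until_ready(fn(*args))          # warm-up call triggers compilation
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        jax.block_until_ready(fn(*args))      # wait for asynchronous work to finish
        best = min(best, time.perf_counter() - start)
    return best

@jax.jit
def pairwise_sq_dists(a, b):
    # Squared Euclidean distance between every row of a and every row of b
    return jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)

a = jnp.ones((256, 3))
b = jnp.zeros((64, 3))
elapsed = time_fn(pairwise_sq_dists, a, b)
print(f"best of 5: {elapsed * 1e3:.3f} ms")
```

Without the `block_until_ready` call, the timer may stop before the computation has finished and report misleadingly small numbers.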

Visually inspect the SOM training output on random color data to confirm algorithm correctness.

Next steps

Apply JAX optimization techniques to your own NumPy-based machine learning code.

# Example: Profile your existing NumPy code
python -m cProfile your_numpy_script.py

# Then adapt to JAX and compare performance

Try adapting your favorite NumPy algorithms to JAX and measure the performance improvements on the Blackwell GPU architecture.
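A reasonable first step when adapting existing code is to port one function, keep the NumPy version as a reference, and confirm that both produce the same results before comparing speed. The sketch below uses a softmax purely as a stand-in for your own function; the names `softmax_np` and `softmax_jax` are hypothetical.

```python
import numpy as np
import jax
import jax.numpy as jnp

def softmax_np(x):
    # Reference NumPy implementation
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

@jax.jit
def softmax_jax(x):
    # Same logic with jax.numpy, compiled by jax.jit
    z = x - x.max(axis=-1, keepdims=True)
    e = jnp.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

x = np.random.default_rng(0).normal(size=(128, 10)).astype(np.float32)
out_np = softmax_np(x)
out_jax = np.asarray(softmax_jax(jnp.asarray(x)))

# Check numerical agreement before measuring any speedup
print("results match:", np.allclose(out_np, out_jax, atol=1e-5))
```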