Description
On RTX 5090 (Blackwell, sm_120) with driver 610.43.02, loading ~20 distinct compiled CUDA programs (via JAX/XLA JIT compilation) into a single CUDA context triggers Xid 31 MMU FAULT_PDE, corrupting the GPU virtual address space and killing the process with CUDA_ERROR_ILLEGAL_ADDRESS.
The same pattern works on RTX 3090 (sm_86, Ampere) and V100 (sm_70, Volta) with hundreds of compilations. The issue is specific to Blackwell.
Environment
| Component | Version |
|---|---|
| GPU | 3× NVIDIA GeForce RTX 5090 32 GB |
| Driver | 610.43.02 (open kernel) |
| CUDA | 13.3 |
| OS | Ubuntu 26.04 LTS, kernel 7.0.0-22-generic |
| CPU | AMD Ryzen 9 9950X |
| Framework | JAX 0.10.1 / XLA (jax[cuda13]) |
Reproduced on all three identical cards, so the fault is not tied to one particular unit.
Kernel log
NVRM: Xid (PCI:0000:01:00): 31, pid=<varies>, name=python3, channel 0x00000002,
intr 00000000. MMU Fault: ENGINE GRAPHICS GPC0 GPCCLIENT_T1_0 faulted
@ 0x0_00019000. Fault is of type FAULT_PDE ACCESS_TYPE_VIRT_READ
Always GPC0, GPCCLIENT_T1_{0,1,6}, low virtual addresses (0x3000–0x19000), FAULT_PDE.
Root cause
Our JAX application created a new jax.jit(jax.value_and_grad(closure)) per optimization run (~20 times in a loop). Each closure captured a different Python scalar, producing a distinct HLO graph → distinct XLA compilation → distinct CUDA program loaded onto the GPU.
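The pattern described above can be sketched as follows. This is a minimal illustration, not our actual solver: `make_loss`, the quadratic loss, and the `weight` values are hypothetical stand-ins, and a program this small is not expected to trigger the fault on its own. It only shows why every iteration ends up with its own compilation.

```python
import jax
import jax.numpy as jnp


def make_loss(weight):
    # `weight` is a plain Python scalar captured by the closure, so it is
    # baked into the traced computation as a constant.
    def loss(params):
        return jnp.sum((params * weight - 1.0) ** 2)

    return loss


params = jnp.ones(4)
values = []
for run in range(20):
    weight = 0.1 * (run + 1)
    # A fresh jax.jit wrapper around a fresh closure: JAX cannot reuse the
    # previous compilation, so each run compiles a new program.
    step = jax.jit(jax.value_and_grad(make_loss(weight)))
    value, grad = step(params)
    values.append(float(value))

print(len(values), "runs completed")
```

Because `jax.jit` is applied to a new function object on every pass and the captured constant differs, none of the compiled programs are shared between runs.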
After ~20 compilations, the process crashes with Xid 31. This was isolated through bisection:
| Test | Distinct compilations | Result on RTX 5090 |
|---|---|---|
| Single JIT function reused, 31 min sustained, 97% GPU | 1 | PASS |
| New JIT closure per iteration (different captured scalar) | ~20 | Xid 31 crash |
| New JIT closure + jax.clear_caches() after each | 1 (cleared) | PASS |
| Refactored: cached JIT function, varying data as args | 1 (reused) | PASS |
The crash correlates strictly with accumulated GPU code cache entries, not with compute duration, VRAM pressure, GPU utilization, or any specific CUDA library (cuFFT, cuBLAS, etc.).
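For reference, the "new JIT closure + jax.clear_caches()" test can be sketched roughly like this. The loss function and loop are hypothetical placeholders for our solver code; the only real API relied on is `jax.clear_caches()`. Whether clearing JAX's caches also releases the loaded GPU programs is an inference from the bisection result, not something we verified inside the driver.

```python
import jax
import jax.numpy as jnp


def make_loss(weight):
    # Hypothetical stand-in for the real objective; `weight` is captured
    # as a Python scalar and becomes a compile-time constant.
    def loss(params):
        return jnp.sum((params * weight - 1.0) ** 2)

    return loss


params = jnp.ones(4)
values = []
for run in range(20):
    step = jax.jit(jax.value_and_grad(make_loss(0.1 * (run + 1))))
    value, grad = step(params)
    values.append(float(value))

    # Drop our reference to the jitted function, then clear JAX's
    # compilation caches so the compiled program from this run does not
    # stay alive in the process.
    del step
    jax.clear_caches()

print(len(values), "runs completed")
```

Each iteration still compiles once, but at most one compiled program from this loop is held at a time.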
Why this matters
Compiling ~20 CUDA programs in a single process is not unusual. This happens in:
- Hyperparameter sweeps (each trial JIT-compiles a different loss function)
- Multi-task / multi-objective optimization
- Dynamic neural architectures (different shapes → different compilations)
- Any framework that JIT-compiles per-configuration (JAX, Triton, torch.compile)
The driver should handle GPU code cache growth by evicting old entries, not by corrupting page tables.
Workarounds
- Reuse JIT compilations (restructure code to pass varying data as arguments)
- Call jax.clear_caches() periodically to free GPU code cache
- Restart the process every ~15 compilations
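A minimal sketch of the first workaround, assuming the varying quantity can be passed as an ordinary argument. The `loss` function, the `weight` values, and the `trace_count` counter are hypothetical and exist only for illustration; the counter relies on the fact that Python side effects inside a jitted function run only while JAX traces it.

```python
import jax
import jax.numpy as jnp

trace_count = {"n": 0}


def loss(params, weight):
    # Python side effect: executes only when JAX traces the function,
    # not on cached calls, so it counts traces rather than executions.
    trace_count["n"] += 1
    return jnp.sum((params * weight - 1.0) ** 2)


# Built once, outside the loop. value_and_grad differentiates with
# respect to the first argument (params) by default.
step = jax.jit(jax.value_and_grad(loss))

params = jnp.ones(4)
values = []
for run in range(20):
    weight = 0.1 * (run + 1)
    # `weight` is now a runtime argument, so the same compiled program
    # serves every run.
    value, grad = step(params, weight)
    values.append(float(value))

print("traces:", trace_count["n"])
```

Since the argument shapes and dtypes are identical on every call, the jitted function should be traced and compiled only on the first call and reused afterwards.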
Possibly related
Reproducer note
We were unable to produce a self-contained standalone reproducer. The XLA programs from our spectral PDE solver are large enough (dozens of CUDA kernels per compilation) to exhaust the cache in ~20 compilations, but simplified approximations produce smaller programs that don't trigger it. The full codebase is available on request.
The bisection evidence above (same code passes with 1 compilation, crashes with 20, passes again with cache clearing) is the strongest evidence that the root cause is cache accumulation.