Xid 31 MMU FAULT_PDE on RTX 5090 (sm_120) after ~20 distinct CUDA program compilations in a single context #1208

@riichard

Description

On RTX 5090 (Blackwell, sm_120) with driver 610.43.02, loading ~20 distinct compiled CUDA programs (via JAX/XLA JIT compilation) into a single CUDA context triggers Xid 31 MMU FAULT_PDE, corrupting the GPU virtual address space and killing the process with CUDA_ERROR_ILLEGAL_ADDRESS.

The same pattern works on RTX 3090 (sm_86, Ampere) and V100 (sm_70, Volta) with hundreds of compilations. The issue is specific to Blackwell.

Environment

| Component | Version |
|-----------|---------|
| GPU | 3× NVIDIA GeForce RTX 5090 32 GB |
| Driver | 610.43.02 (open kernel) |
| CUDA | 13.3 |
| OS | Ubuntu 26.04 LTS, kernel 7.0.0-22-generic |
| CPU | AMD Ryzen 9 9950X |
| Framework | JAX 0.10.1 / XLA (jax[cuda13]) |

Reproduced on all three identical cards, so the fault is not tied to one defective unit.

Kernel log

NVRM: Xid (PCI:0000:01:00): 31, pid=<varies>, name=python3, channel 0x00000002,
  intr 00000000. MMU Fault: ENGINE GRAPHICS GPC0 GPCCLIENT_T1_0 faulted
  @ 0x0_00019000. Fault is of type FAULT_PDE ACCESS_TYPE_VIRT_READ

Every occurrence reports GPC0, a GPCCLIENT_T1_{0,1,6} client, a low virtual address (0x3000–0x19000), and FAULT_PDE.

Root cause

Our JAX application created a new jax.jit(jax.value_and_grad(closure)) per optimization run (~20 times in a loop). Each closure captured a different Python scalar, producing a distinct HLO graph → distinct XLA compilation → distinct CUDA program loaded onto the GPU.
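The difference between the two patterns can be sketched as follows. This is a minimal illustration, not our solver: the quadratic loss, `make_closure_loss`, `loss_with_arg`, and the `trace_count` dictionary are hypothetical stand-ins. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so the counters approximate how many distinct programs get built. It runs on CPU as well and is not expected to trigger the Xid by itself.

```python
import jax
import jax.numpy as jnp

# Counters incremented only during tracing (hypothetical helper).
trace_count = {"closure": 0, "args": 0}


def make_closure_loss(scale):
    # `scale` is a captured Python scalar, so it is baked into the
    # traced program as a constant and each value gives a new program.
    def loss(x):
        trace_count["closure"] += 1
        return jnp.sum((scale * x) ** 2)
    return loss


def loss_with_arg(x, scale):
    # `scale` is a runtime argument, so one traced program serves all values.
    trace_count["args"] += 1
    return jnp.sum((scale * x) ** 2)


x = jnp.arange(4.0)
scales = [0.5 + 0.1 * i for i in range(20)]

# Pattern from the report: a fresh jit wrapper per optimization run.
for s in scales:
    step = jax.jit(jax.value_and_grad(make_closure_loss(s)))
    value, grad = step(x)

# Refactored pattern: one jit wrapper, varying data passed as an argument.
step_shared = jax.jit(jax.value_and_grad(loss_with_arg))
for s in scales:
    value, grad = step_shared(x, s)

print("closure traces:", trace_count["closure"])
print("argument traces:", trace_count["args"])
```

The closure version is retraced on every iteration, while the argument version is traced for the first call only and then reused.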

After ~20 compilations, the process crashes with Xid 31. This was isolated through bisection:

| Test | Distinct compilations | Result on RTX 5090 |
|------|-----------------------|--------------------|
| Single JIT function reused, 31 min sustained, 97% GPU | 1 | PASS |
| New JIT closure per iteration (different captured scalar) | ~20 | Xid 31 crash |
| New JIT closure + jax.clear_caches() after each | 1 (cleared) | PASS |
| Refactored: cached JIT function, varying data as args | 1 (reused) | PASS |

The crash correlates strictly with accumulated GPU code cache entries, not with compute duration, VRAM pressure, GPU utilization, or any specific CUDA library (cuFFT, cuBLAS, etc.).

Why this matters

Compiling ~20 CUDA programs in a single process is not unusual. This happens in:

  • Hyperparameter sweeps (each trial JIT-compiles a different loss function)
  • Multi-task / multi-objective optimization
  • Dynamic neural architectures (different shapes → different compilations)
  • Any framework that JIT-compiles per-configuration (JAX, Triton, torch.compile)

The driver should handle GPU code cache growth by evicting old entries, not by corrupting page tables.

Workarounds

  1. Reuse JIT compilations (restructure code to pass varying data as arguments)
  2. Call jax.clear_caches() periodically to free GPU code cache
  3. Restart the process every ~15 compilations
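For code that cannot easily be restructured, workaround 2 might look like the sketch below. `run_sweep`, `make_loss`, and the `clear_every` budget of 5 are hypothetical names and values chosen for illustration; the only JAX calls used are `jax.jit`, `jax.value_and_grad`, and `jax.clear_caches()`. Whether clearing the caches releases the GPU-side programs on a given driver is an assumption based on the bisection results in this report, not something the sketch verifies.

```python
import jax
import jax.numpy as jnp


def make_loss(scale):
    # Hypothetical stand-in for a per-run closure over a Python scalar.
    def loss(x):
        return jnp.sum((scale * x) ** 2)
    return loss


def run_sweep(scales, x, clear_every=5):
    """Run one jitted step per scale, clearing JAX caches periodically."""
    results = []
    for i, scale in enumerate(scales, start=1):
        step = jax.jit(jax.value_and_grad(make_loss(scale)))
        value, _grad = step(x)
        # Copy the result to the host before dropping the compiled function.
        results.append(float(value))
        del step
        if i % clear_every == 0:
            # Drop cached compilations so they do not accumulate.
            jax.clear_caches()
    return results


if __name__ == "__main__":
    print(run_sweep([0.1 * k for k in range(1, 21)], jnp.ones(8)))
```

Clearing after every run matches the configuration that passed in the table above; a larger `clear_every` trades recompilation cost against the number of programs held at once.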

Reproducer note

We were unable to produce a self-contained reproducer. The XLA programs from our spectral PDE solver are large enough (dozens of CUDA kernels per compilation) to exhaust the cache in ~20 compilations, but simplified approximations produce smaller programs that don't trigger it. The full codebase is available on request.

The bisection evidence above (same code passes with 1 compilation, crashes with ~20, passes again with cache clearing) is the strongest evidence that the root cause is cache accumulation.
