Skip to content

Make drjax fully compatible with JAX Explicit Sharding.#39

Open
copybara-service[bot] wants to merge 1 commit into
mainfrom
cl/914948662
Open

Make drjax fully compatible with JAX Explicit Sharding.#39
copybara-service[bot] wants to merge 1 commit into
mainfrom
cl/914948662

Conversation

@copybara-service
Copy link
Copy Markdown

@copybara-service copybara-service Bot commented May 15, 2026

Make drjax fully compatible with JAX Explicit Sharding.

JAX has introduced Explicit Sharding, which strictly enforces mesh axes checks and disables older implicit fallbacks. This CL updates DrJAX to be compatible with the explicit model under both eager and compiled (JIT) execution, and overhauls the test suite to ensure robust, clean, and idiomatic eager/JIT test coverage.

Implementation

  • Explicit Axis Verification in primitives.py -
    We raise a clear ValueError early during abstract evaluation if a required placement axis is missing in an Explicit mesh. This prevents downstream compilation errors while allowing replication fallback under Auto meshes.

  • Robust Eager Fallback in impls.py -
    JAX eager mode fails with AttributeError (on SingleDeviceSharding) or ValueError (on device mismatch) when broadcasting directly to a mesh via jnp.broadcast_to. We resolve this with a robust two-step fallback where we broadcast first to an un-sharded temporary array, and then call jax.sharding.reshard to move it to the target mesh.

Testing

  • Flattened subTest Architecture -
    We extended test targets to run both JIT and eager modes. To avoid test class bloating or complex loops, we flattened tests into sequential self.subTest("eager") and self.subTest("jit") blocks, using local checker functions to eliminate code duplication.

  • JIT Mesh Context Fixes -
    We resolved JIT "mesh mismatch" errors by explicitly binding the active mesh context via with jax.set_mesh(mesh) during compiled execution in tests.

@copybara-service copybara-service Bot force-pushed the cl/914948662 branch 2 times, most recently from 35161d5 to 8de6026 Compare May 15, 2026 17:27
JAX has introduced Explicit Sharding, which strictly enforces mesh axes checks and disables older implicit fallbacks. This CL updates DrJAX to be compatible with the explicit model under both eager and compiled (JIT) execution, and overhauls the test suite to ensure robust, clean, and idiomatic eager/JIT test coverage.

### Implementation

*   Explicit Axis Verification in `primitives.py` -
    We raise a clear `ValueError` early during abstract evaluation if a required placement axis is missing in an `Explicit` mesh. This prevents downstream compilation errors while allowing replication fallback under `Auto` meshes.

*   Robust Eager Fallback in `impls.py` -
    JAX eager mode fails with `AttributeError` (on `SingleDeviceSharding`) or `ValueError` (on device mismatch) when broadcasting directly to a mesh via `jnp.broadcast_to`. We resolve this with a robust two-step fallback where we broadcast first to an un-sharded temporary array, and then call `jax.sharding.reshard` to move it to the target mesh.

### Testing

*   Flattened `subTest` Architecture -
    We extended test targets to run both JIT and eager modes. To avoid test class bloating or complex loops, we flattened tests into sequential `self.subTest("eager")` and `self.subTest("jit")` blocks, using local checker functions to eliminate code duplication.

*   JIT Mesh Context Fixes -
    We resolved JIT "mesh mismatch" errors by explicitly binding the active mesh context via `with jax.set_mesh(mesh)` during compiled execution in tests.

PiperOrigin-RevId: 914948662
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant