Make drjax fully compatible with JAX Explicit Sharding.#39
Open
copybara-service[bot] wants to merge 1 commit into
Open
Make drjax fully compatible with JAX Explicit Sharding.#39copybara-service[bot] wants to merge 1 commit into
copybara-service[bot] wants to merge 1 commit into
Conversation
35161d5 to
8de6026
Compare
JAX has introduced Explicit Sharding, which strictly enforces mesh axes checks and disables older implicit fallbacks. This CL updates DrJAX to be compatible with the explicit model under both eager and compiled (JIT) execution, and overhauls the test suite to ensure robust, clean, and idiomatic eager/JIT test coverage.
### Implementation
* Explicit Axis Verification in `primitives.py` -
We raise a clear `ValueError` early during abstract evaluation if a required placement axis is missing in an `Explicit` mesh. This prevents downstream compilation errors while allowing replication fallback under `Auto` meshes.
* Robust Eager Fallback in `impls.py` -
JAX eager mode fails with `AttributeError` (on `SingleDeviceSharding`) or `ValueError` (on device mismatch) when broadcasting directly to a mesh via `jnp.broadcast_to`. We resolve this with a robust two-step fallback where we broadcast first to an un-sharded temporary array, and then call `jax.sharding.reshard` to move it to the target mesh.
### Testing
* Flattened `subTest` Architecture -
We extended test targets to run both JIT and eager modes. To avoid test class bloating or complex loops, we flattened tests into sequential `self.subTest("eager")` and `self.subTest("jit")` blocks, using local checker functions to eliminate code duplication.
* JIT Mesh Context Fixes -
We resolved JIT "mesh mismatch" errors by explicitly binding the active mesh context via `with jax.set_mesh(mesh)` during compiled execution in tests.
PiperOrigin-RevId: 914948662
8de6026 to
43e84f5
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Make drjax fully compatible with JAX Explicit Sharding.
JAX has introduced Explicit Sharding, which strictly enforces mesh axes checks and disables older implicit fallbacks. This CL updates DrJAX to be compatible with the explicit model under both eager and compiled (JIT) execution, and overhauls the test suite to ensure robust, clean, and idiomatic eager/JIT test coverage.
Implementation
Explicit Axis Verification in
primitives.py-We raise a clear
ValueErrorearly during abstract evaluation if a required placement axis is missing in anExplicitmesh. This prevents downstream compilation errors while allowing replication fallback underAutomeshes.Robust Eager Fallback in
impls.py-JAX eager mode fails with
AttributeError(onSingleDeviceSharding) orValueError(on device mismatch) when broadcasting directly to a mesh viajnp.broadcast_to. We resolve this with a robust two-step fallback where we broadcast first to an un-sharded temporary array, and then calljax.sharding.reshardto move it to the target mesh.Testing
Flattened
subTestArchitecture -We extended test targets to run both JIT and eager modes. To avoid test class bloating or complex loops, we flattened tests into sequential
self.subTest("eager")andself.subTest("jit")blocks, using local checker functions to eliminate code duplication.JIT Mesh Context Fixes -
We resolved JIT "mesh mismatch" errors by explicitly binding the active mesh context via
with jax.set_mesh(mesh)during compiled execution in tests.