Fix MJX make_data/put_data int dtype mismatches under jax_enable_x64#3334
Open
discobot wants to merge 1 commit into
Open
Fix MJX make_data/put_data int dtype mismatches under jax_enable_x64#3334discobot wants to merge 1 commit into
discobot wants to merge 1 commit into
Conversation
Under jax_enable_x64, jit-compiled MJX functions produce Data whose contact.geom1/geom2/geom and ten_wrapadr/ten_wrapnum/wrap_obj fields carry the canonical JAX int dtype (int64), but put_data copied these arrays from MjData as int32 and make_data hardcoded np.int32 for the tendon wrap fields. As a result a step function AOT-compiled against put_data (or make_data) output rejected its own output, which broke mjx/viewer.py with x64 enabled. Cast the contact geom index fields in _put_contact and the tendon wrap index fields in _make_data_jax/_put_data_jax to the canonical int dtype, following the design intent documented in _make_data_contact_jax; under the default config this is a no-op. Fixes google-deepmind#2565.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #2565.
Two families of int fields are involved, both with the same mechanism. For the
contact.geom1/geom2/geomfields in the traceback:_make_data_contact_jaxdeliberately allocates them with the canonical JAX int dtype ("let jax pick
contact.geom int precision, for interop with jax_enable_x64"), but
_put_contactcopies the raw int32 arrays fromMjData.contact, andjax.device_putpreserves them when x64 is enabled. Separately,smooth.tendonemitsten_wrapadr/ten_wrapnum/wrap_objwith the canonical intdtype inside jit, while
_make_data_jaxhardcodesnp.int32for them andput_dataagain copies int32 fromMjData— this second mismatch hits anymodel with tendons (e.g.
model/humanoid/humanoid.xml) and breaksmake_datathe same way. Either one makes the viewer's AOT-compiled step reject its own
output on the second call.
This change casts both groups to the canonical int dtype (in
_put_contact,_make_data_jax, and_put_data_jax), extending the design intent alreadydocumented in
_make_data_contact_jax. Under the default config every cast isa no-op; under x64 a full pytree dtype audit on the humanoid model shows
make_data,put_data, and jittedDatanow agree on all 85 leaves, so theystay interchangeable for any AOT consumer without touching
viewer.py.(
solver_niteris left int32 on purpose: jitted solvers produce it fromweakly-typed Python ints, so it is int32 even under x64.)
Added
test_make_matches_put_x64next to the existingtest_make_matches_put:under
jax.enable_x64it asserts these dtypes follow the canonical intprecision and that
mjx.stepAOT-compiled againstput_dataoutput acceptsmake_dataoutput and its own output — the exact viewer failure mode (theexisting test model's fixed tendon exercises the wrap fields). It fails before
this change and passes after; the full
io_test.pystill passes under thedefault config.