Fix MJX make_data/put_data int dtype mismatches under jax_enable_x64#3334

Open
discobot wants to merge 1 commit into google-deepmind:main from discobot:fix/2565-put-data-contact-dtype

Conversation

@discobot

Fixes #2565.

Two families of int fields are involved, both with the same mechanism. For the
contact.geom1/geom2/geom fields in the traceback: _make_data_contact_jax
deliberately allocates them with the canonical JAX int dtype ("let jax pick
contact.geom int precision, for interop with jax_enable_x64"), but
_put_contact copies the raw int32 arrays from MjData.contact, and
jax.device_put preserves that int32 dtype when x64 is enabled. Separately,
smooth.tendon emits ten_wrapadr/ten_wrapnum/wrap_obj with the canonical int
dtype inside jit, while _make_data_jax hardcodes np.int32 for them and
put_data again copies int32 from MjData. This second mismatch hits any
model with tendons (e.g. model/humanoid/humanoid.xml) and breaks make_data
the same way. Either one makes the viewer's AOT-compiled step reject its own
output on the second call.
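
The mechanism can be reproduced without MuJoCo. The sketch below is a toy illustration, not MJX code: `host_geom` and `toy_step` are hypothetical stand-ins for an int32 array copied out of MjData and for a jitted function that rebuilds an index field with the default int dtype.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Enable 64-bit mode before any JAX arrays are created.
jax.config.update("jax_enable_x64", True)

# Hypothetical host-side index array, int32 like the arrays read from MjData.
host_geom = np.array([0, 1, 2], dtype=np.int32)

# With x64 enabled, device_put keeps the explicit int32 dtype.
put_geom = jax.device_put(host_geom)


def toy_step(geom):
    # Hypothetical stand-in for a jitted function that builds indices with
    # the default int dtype, which is int64 under x64.
    return jnp.arange(geom.shape[0]) + geom


# Compile ahead of time against the int32 input.
compiled = jax.jit(toy_step).lower(put_geom).compile()
out = compiled(put_geom)

try:
    # The output is int64, but the executable was compiled for int32.
    compiled(out)
    rejected = False
except (TypeError, ValueError):
    rejected = True

print(put_geom.dtype, out.dtype, rejected)
```

The first call succeeds because the input matches what the executable was compiled for; feeding the output back in fails because its dtype has changed.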

This change casts both groups to the canonical int dtype (in _put_contact,
_make_data_jax, and _put_data_jax), extending the design intent already
documented in _make_data_contact_jax. Under the default config every cast is
a no-op; under x64 a full pytree dtype audit on the humanoid model shows
make_data, put_data, and jitted Data now agree on all 85 leaves, so they
stay interchangeable for any AOT consumer without touching viewer.py.
(solver_niter is left int32 on purpose: jitted solvers produce it from
weakly-typed Python ints, so it is int32 even under x64.)
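
A minimal sketch of the cast and of a pytree dtype audit, assuming toy field dictionaries rather than a real mjx.Data. `to_canonical_int` and `dtype_mismatches` are hypothetical helpers written for illustration; they are not the functions changed in this PR.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def to_canonical_int(x):
    # dtype=int resolves to int32 by default and to int64 under x64,
    # so this cast leaves int32 inputs unchanged in the default config.
    return jnp.asarray(x, dtype=int)


def dtype_mismatches(a, b):
    """Returns (path, dtype_a, dtype_b) for every leaf whose dtype differs."""
    leaves_a = jax.tree_util.tree_leaves_with_path(a)
    leaves_b = jax.tree_util.tree_leaves(b)
    return [
        (jax.tree_util.keystr(path), x.dtype, y.dtype)
        for (path, x), y in zip(leaves_a, leaves_b)
        if x.dtype != y.dtype
    ]


# Hypothetical host-side fields, int32 as copied from MjData.
host = {
    "ten_wrapadr": np.array([0, 2], dtype=np.int32),
    "wrap_obj": np.array([[0, 1], [2, 3]], dtype=np.int32),
}

# Stand-in for what jitted code emits: default-int arrays of the same shape.
reference = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=int), host
)

raw = jax.device_put(host)  # keeps int32 under x64
fixed = jax.tree_util.tree_map(to_canonical_int, host)

before = dtype_mismatches(raw, reference)
after = dtype_mismatches(fixed, reference)
print(len(before), len(after))
```

Comparing leaf dtypes against a jitted reference, as above, is one way to confirm that no int field was missed.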

Added test_make_matches_put_x64 next to the existing test_make_matches_put.
Under jax.enable_x64 it asserts that these dtypes follow the canonical int
precision and that mjx.step, AOT-compiled against put_data output, accepts
both make_data output and its own output, which is the exact viewer failure
mode (the existing test model's fixed tendon exercises the wrap fields). It
fails before this change and passes after; the full io_test.py still passes
under the default config.
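
The shape of that interchangeability check can be sketched with toy functions. `toy_put`, `toy_make`, and `toy_step` below are hypothetical stand-ins for put_data, make_data, and mjx.step; the actual test in io_test.py operates on real models and is not reproduced here.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)


def toy_put(host):
    # Stand-in for put_data after the fix: cast host int32 to the default int.
    return {k: jnp.asarray(v, dtype=int) for k, v in host.items()}


def toy_make(n):
    # Stand-in for make_data after the fix: allocate with the default int.
    return {"wrap_obj": jnp.zeros(n, dtype=int)}


def toy_step(d):
    # Stand-in for a step that rebuilds the field inside jit.
    n = d["wrap_obj"].shape[0]
    return {"wrap_obj": d["wrap_obj"] + jnp.arange(n)}


put = toy_put({"wrap_obj": np.arange(4, dtype=np.int32)})
made = toy_make(4)

# Compile once against the put-style input, then reuse the executable.
compiled = jax.jit(toy_step).lower(put).compile()
out = compiled(put)
out_again = compiled(out)    # accepts its own output
out_made = compiled(made)    # accepts the make-style input

print(out_again["wrap_obj"].dtype, out_made["wrap_obj"].dtype)
```

All three inputs share one dtype, so a single compiled executable serves them; with int32 in `toy_put` or `toy_make`, the second or third call would be rejected.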

Under jax_enable_x64, jit-compiled MJX functions produce Data whose
contact.geom1/geom2/geom and ten_wrapadr/ten_wrapnum/wrap_obj fields carry the
canonical JAX int dtype (int64), but put_data copied these arrays from MjData
as int32 and make_data hardcoded np.int32 for the tendon wrap fields. As a
result a step function AOT-compiled against put_data (or make_data) output
rejected its own output, which broke mjx/viewer.py with x64 enabled. Cast the
contact geom index fields in _put_contact and the tendon wrap index fields in
_make_data_jax/_put_data_jax to the canonical int dtype, following the design
intent documented in _make_data_contact_jax; under the default config this is
a no-op. Fixes google-deepmind#2565.

Development

Successfully merging this pull request may close these issues.

[MJX] jax_enable_x64 not work in mujoco/mjx/viewer.py with put_data

1 participant