Add tests and fix for PRNG key handling in Orbax + ColocatedPython.#3179
Merged
Conversation
8b37034 to
23de1a0
Compare
Tests are added to `orbax/checkpoint/_src/multihost/dispatchers_test.py` to demonstrate a type mismatch issue: Orbax's restore process generates `result_specs` with a physical dtype (e.g., uint32) for PRNG keys, but the deserialization returns `PRNGKeyArray` objects with a PRNG key dtype. This mismatch can cause errors in Pathways' IFRT transport layer. This is now fixed in `jax_array_handlers.py` by creating the desired `result_specs` in `_get_abstract_arrays`. PiperOrigin-RevId: 913931497
23de1a0 to
075b32e
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Add tests and fix for PRNG key handling in Orbax + ColocatedPython.
Tests are added to
orbax/checkpoint/_src/multihost/dispatchers_test.pyto demonstrate a type mismatch issue: Orbax's restore process generates
result_specswith a physical dtype (e.g., uint32) for PRNG keys, butthe deserialization returns
PRNGKeyArrayobjects with a PRNG keydtype. This mismatch can cause errors in Pathways' IFRT transport
layer. This is now fixed in
jax_array_handlers.pyby creating the desiredresult_specsin_get_abstract_arrays.