Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 251 additions & 0 deletions checkpoint/orbax/checkpoint/_src/multihost/dispatchers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from unittest import mock

from absl.testing import absltest
Expand All @@ -22,6 +23,7 @@
import numpy as np
from orbax.checkpoint._src.metadata import sharding as sharding_metadata
from orbax.checkpoint._src.multihost import dispatchers
from orbax.checkpoint._src.serialization import jax_array_handlers
from orbax.checkpoint._src.serialization import type_handlers
from orbax.checkpoint._src.sharding_utils import make_single_device_sharding

Expand Down Expand Up @@ -307,6 +309,255 @@ def test_dispatch_no_input(self):
self.assertEqual(self.mock_device_put.call_count, 3)


class PrngKeyDispatchTest(parameterized.TestCase):
  """Tests that PRNG key arrays are handled correctly through the dispatch path.

  With colocated_python PRNG key support, result_specs can use PRNG key dtypes.
  The colocated_python framework handles the conversion between PRNG key arrays
  and their physical representation at the IFRT boundary transparently.

  When no array_metadata_store is provided, _get_abstract_arrays produces
  result_specs with whatever dtype the args have (physical or PRNG key).
  When array_metadata_store IS provided and contains PRNG key impl metadata,
  _get_abstract_arrays produces result_specs with PRNG key dtypes and
  logical shapes.
  """

  def _single_device_replicated_sharding(self):
    """Returns a fully replicated NamedSharding over the first JAX device."""
    mesh = jax.sharding.Mesh(np.array(jax.devices()[:1]), ('x',))
    return jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

  def _replicated_sharding_like(self, arr):
    """Returns a replicated NamedSharding on the same mesh as `arr`.

    Scalar keys need replicated sharding.

    Args:
      arr: An array whose `sharding.mesh` is reused for the new sharding.
    """
    return jax.sharding.NamedSharding(
        arr.sharding.mesh, jax.sharding.PartitionSpec()
    )

  def _make_patched_dispatcher(self):
    """Builds a ColocatedPythonDispatcher with colocated_python mocked out.

    Patches `jax.block_until_ready`, `jax.device_put`,
    `cp.colocated_python` and `cp.colocated_cpu_devices`. The patches are
    registered through `enter_context`, so they stay active for the rest of
    the calling test and are undone during test cleanup.

    Returns:
      A `dispatchers.ColocatedPythonDispatcher` created while the patches are
      active.
    """
    self.enter_context(mock.patch('jax.block_until_ready'))
    # device_put becomes the identity so no real transfer is attempted.
    self.enter_context(
        mock.patch(
            'jax.device_put', side_effect=lambda x, d, may_alias=False: x
        )
    )
    mock_cp = self.enter_context(
        mock.patch.object(cp, 'colocated_python', autospec=True)
    )
    mock_devices = self.enter_context(
        mock.patch.object(cp, 'colocated_cpu_devices', autospec=True)
    )

    mock_specialize = mock.MagicMock()

    def cp_decorator(f):
      # The undecorated wrapper must never run: only the function obtained
      # through `.specialize(...)` is allowed to invoke `f`.
      def unspecialized_wrapper(*_a, **_kw):
        raise RuntimeError('Unspecialized wrapper called.')

      def specialized_wrapper(*a, **kw):
        return f(*a, **kw)

      mock_specialize.return_value = specialized_wrapper
      unspecialized_wrapper.specialize = mock_specialize
      return unspecialized_wrapper

    mock_cp.side_effect = cp_decorator

    def colocated_cpu_devices_side_effect(arg):
      # A Mesh is rebuilt over the local devices with the same axis names;
      # any other argument is returned as a plain list.
      if isinstance(arg, jax.sharding.Mesh):
        return jax.sharding.Mesh(np.array(jax.devices()), arg.axis_names)
      return list(arg)

    mock_devices.side_effect = colocated_cpu_devices_side_effect

    return dispatchers.ColocatedPythonDispatcher()

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_result_specs_use_physical_dtype_without_metadata(self, key_impl):
    """Without array_metadata_store, result_specs use whatever dtype args have.

    When args have physical dtype (e.g. uint32), specs should also be physical.

    Args:
      key_impl: The PRNG implementation name.
    """
    key = jax.random.key(0, impl=key_impl)
    key_data = jax.random.key_data(key)

    sharding = self._single_device_replicated_sharding()

    # Args with physical dtype (no metadata store to detect PRNG key).
    args = [
        type_handlers.ArrayRestoreArgs(
            global_shape=key_data.shape,
            dtype=key_data.dtype,
            sharding=sharding,
        ),
    ]
    shardings = [sharding]

    result_specs = asyncio.run(
        jax_array_handlers._get_abstract_arrays(args, shardings)
    )

    # Without metadata, specs should use the physical dtype from args.
    self.assertEqual(result_specs[0].dtype, key_data.dtype)
    self.assertFalse(
        jax.dtypes.issubdtype(result_specs[0].dtype, jax.dtypes.prng_key)
    )

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_result_specs_preserve_prng_dtype_from_args(self, key_impl):
    """When args already have PRNG key dtype, specs preserve logical shape.

    This tests the case where _maybe_read_metadata_and_update_restore_args
    has already set the PRNG key dtype on the args. _get_abstract_arrays
    should detect this and NOT strip trailing dimensions (avoiding the
    double-conversion bug).

    Args:
      key_impl: The PRNG implementation name.
    """
    key = jax.random.key(0, impl=key_impl)

    sharding = self._single_device_replicated_sharding()

    # Args with PRNG key dtype and logical shape (already converted).
    args = [
        type_handlers.ArrayRestoreArgs(
            global_shape=key.shape,
            dtype=key.dtype,
            sharding=sharding,
        ),
    ]
    shardings = [sharding]

    result_specs = asyncio.run(
        jax_array_handlers._get_abstract_arrays(args, shardings)
    )

    # Specs should preserve PRNG dtype and logical shape.
    self.assertTrue(
        jax.dtypes.issubdtype(result_specs[0].dtype, jax.dtypes.prng_key),
        f'Expected prng_key dtype, got {result_specs[0].dtype}',
    )
    self.assertEqual(result_specs[0].shape, key.shape)

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_dispatch_result_dtype_matches_result_specs(self, key_impl):
    """Verifies that dispatch result dtypes match the provided result_specs.

    With PRNG key support in colocated_python (cl/907183304), result_specs
    can use PRNG key dtypes. The worker wraps physical data into PRNG keys,
    and colocated_python handles the transport transparently.

    Args:
      key_impl: The PRNG implementation name.
    """
    arr = _get_mock_dispatcher_array()
    key = jax.random.key(42, impl=key_impl)

    # Worker wraps physical data into PRNG keys and returns them.
    fn = mock.MagicMock(return_value=[key])

    dispatcher = self._make_patched_dispatcher()

    # result_specs use PRNG key dtype (matching worker output).
    prng_specs = [
        jax.ShapeDtypeStruct(
            shape=key.shape,
            dtype=key.dtype,
            sharding=self._replicated_sharding_like(arr),
        )
    ]

    result = dispatcher.dispatch(
        fn, input_arrays=arr, result_specs=prng_specs
    )

    # Result dtype must match specs — PRNG key dtype.
    self.assertTrue(
        jax.dtypes.issubdtype(result[0].dtype, jax.dtypes.prng_key),
        'Dispatch result should be a PRNGKeyArray when result_specs'
        f' use PRNG key dtype. Got dtype={result[0].dtype}.',
    )

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_dispatch_mixed_tree_dtypes_match_result_specs(self, key_impl):
    """Tests the common case: a mixed tree of regular arrays and PRNG keys.

    This mirrors the actual orbax restore path where _sync_deserialize_arrays
    returns a list containing both bfloat16 model arrays and PRNG key arrays.
    All result dtypes must match their corresponding result_specs.

    Args:
      key_impl: The PRNG implementation name.
    """
    arr = _get_mock_dispatcher_array()
    key = jax.random.key(42, impl=key_impl)

    # Worker returns wrapped PRNG key arrays (the desired behavior).
    fn = mock.MagicMock(return_value=[arr, key])

    dispatcher = self._make_patched_dispatcher()

    # Mixed result_specs: regular array + PRNG key (PRNG key dtype).
    result_specs = [
        jax.ShapeDtypeStruct(
            shape=arr.shape,
            dtype=arr.dtype,
            sharding=arr.sharding,
        ),
        jax.ShapeDtypeStruct(
            shape=key.shape,
            dtype=key.dtype,
            sharding=self._replicated_sharding_like(arr),
        ),
    ]

    result = dispatcher.dispatch(
        fn, input_arrays=arr, result_specs=result_specs
    )

    # All result dtypes must match their corresponding specs.
    self.assertEqual(
        result[0].dtype,
        result_specs[0].dtype,
        'Regular array dtype mismatch.',
    )
    # PRNG key element should be a PRNGKeyArray.
    self.assertTrue(
        jax.dtypes.issubdtype(result[1].dtype, jax.dtypes.prng_key),
        'PRNG key result should be a PRNGKeyArray when result_specs'
        f' use PRNG key dtype. Got dtype={result[1].dtype}.',
    )




if __name__ == '__main__':
Expand Down
Loading
Loading