class PrngKeyDispatchTest(parameterized.TestCase):
  """Tests PRNG key handling in result specs and the dispatch path.

  Covers two things:

  * `jax_array_handlers._get_abstract_arrays`: without an
    array_metadata_store the specs mirror the dtype and shape of the restore
    args; with a store whose metadata names a PRNG key impl, the specs use
    the PRNG key dtype and the logical (key) shape.
  * `dispatchers.ColocatedPythonDispatcher.dispatch` (with colocated_python
    mocked out): results keep the PRNG key dtype named in result_specs.
  """

  @staticmethod
  def _replicated_single_device_sharding() -> jax.sharding.NamedSharding:
    """Returns a fully replicated sharding over the first local device."""
    return jax.sharding.NamedSharding(
        jax.sharding.Mesh(np.array(jax.devices()[:1]), ('x',)),
        jax.sharding.PartitionSpec(),
    )

  def _make_mocked_dispatcher(self) -> dispatchers.ColocatedPythonDispatcher:
    """Returns a ColocatedPythonDispatcher with colocated_python mocked out.

    The patches are registered through `enter_context`, so they remain active
    until the test's cleanup runs. The mocked `colocated_python` decorator
    returns a wrapper that raises if called directly and whose `specialize`
    yields a pass-through to the wrapped function.
    """
    self.enter_context(mock.patch('jax.block_until_ready'))
    self.enter_context(
        mock.patch(
            'jax.device_put', side_effect=lambda x, d, may_alias=False: x
        )
    )
    mock_cp = self.enter_context(
        mock.patch.object(cp, 'colocated_python', autospec=True)
    )
    mock_devices = self.enter_context(
        mock.patch.object(cp, 'colocated_cpu_devices', autospec=True)
    )

    mock_specialize = mock.MagicMock()

    def cp_decorator(f):
      def unspecialized_wrapper(*_a, **_kw):
        raise RuntimeError('Unspecialized wrapper called.')

      def specialized_wrapper(*a, **kw):
        return f(*a, **kw)

      mock_specialize.return_value = specialized_wrapper
      unspecialized_wrapper.specialize = mock_specialize
      return unspecialized_wrapper

    mock_cp.side_effect = cp_decorator

    def colocated_cpu_devices_side_effect(arg):
      if isinstance(arg, jax.sharding.Mesh):
        return jax.sharding.Mesh(np.array(jax.devices()), arg.axis_names)
      return list(arg)

    mock_devices.side_effect = colocated_cpu_devices_side_effect

    return dispatchers.ColocatedPythonDispatcher()

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_result_specs_use_physical_dtype_without_metadata(self, key_impl):
    """Without array_metadata_store, result_specs use whatever dtype args have.

    When args have physical dtype (e.g. uint32), specs should also be physical.

    Args:
      key_impl: The PRNG implementation name.
    """
    key = jax.random.key(0, impl=key_impl)
    key_data = jax.random.key_data(key)
    sharding = self._replicated_single_device_sharding()

    # Args with physical dtype (no metadata store to detect PRNG key).
    args = [
        type_handlers.ArrayRestoreArgs(
            global_shape=key_data.shape,
            dtype=key_data.dtype,
            sharding=sharding,
        ),
    ]

    result_specs = asyncio.run(
        jax_array_handlers._get_abstract_arrays(args, [sharding])
    )

    # Without metadata, specs should use the physical dtype from args.
    self.assertEqual(result_specs[0].dtype, key_data.dtype)
    self.assertFalse(
        jax.dtypes.issubdtype(result_specs[0].dtype, jax.dtypes.prng_key)
    )

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_result_specs_preserve_prng_dtype_from_args(self, key_impl):
    """When args already have PRNG key dtype, specs keep dtype and shape.

    No array_metadata_store is passed here, so this checks the pass-through
    behavior only: a PRNG key dtype and logical shape supplied on the args
    come back unchanged in the specs.

    Args:
      key_impl: The PRNG implementation name.
    """
    key = jax.random.key(0, impl=key_impl)
    sharding = self._replicated_single_device_sharding()

    # Args with PRNG key dtype and logical shape (already converted).
    args = [
        type_handlers.ArrayRestoreArgs(
            global_shape=key.shape,
            dtype=key.dtype,
            sharding=sharding,
        ),
    ]

    result_specs = asyncio.run(
        jax_array_handlers._get_abstract_arrays(args, [sharding])
    )

    # Specs should preserve PRNG dtype and logical shape.
    self.assertTrue(
        jax.dtypes.issubdtype(result_specs[0].dtype, jax.dtypes.prng_key),
        f'Expected prng_key dtype, got {result_specs[0].dtype}',
    )
    self.assertEqual(result_specs[0].shape, key.shape)

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_result_specs_use_prng_dtype_with_metadata(self, key_impl):
    """With PRNG impl metadata, physical args yield logical PRNG key specs.

    The args carry the physical dtype and shape of the key data; the mocked
    store reports a RANDOM_KEY_IMPL entry for the parameter, so the spec must
    use the PRNG key dtype and drop the trailing key-data dimensions.

    Args:
      key_impl: The PRNG implementation name.
    """
    key = jax.random.key(0, impl=key_impl)
    key_data = jax.random.key_data(key)
    sharding = self._replicated_single_device_sharding()

    # `name` must be assigned after construction: as a constructor argument
    # it would name the mock itself rather than set the attribute.
    info = mock.MagicMock()
    info.name = 'rng'
    info.parent_dir = 'unused_checkpoint_dir'

    meta = mock.MagicMock()
    meta.param_name = 'rng'
    meta.ext_metadata = {
        jax_array_handlers.array_metadata_lib.RANDOM_KEY_IMPL: key_impl
    }

    store = mock.MagicMock()
    store.read = mock.AsyncMock(return_value=[meta])

    args = [
        type_handlers.ArrayRestoreArgs(
            global_shape=key_data.shape,
            dtype=key_data.dtype,
            sharding=sharding,
        ),
    ]

    result_specs = asyncio.run(
        jax_array_handlers._get_abstract_arrays(
            args, [sharding], store, [info]
        )
    )

    store.read.assert_awaited_once_with(
        checkpoint_dir='unused_checkpoint_dir'
    )
    self.assertEqual(result_specs[0].dtype, key.dtype)
    self.assertEqual(result_specs[0].shape, key.shape)

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_dispatch_result_dtype_matches_result_specs(self, key_impl):
    """Verifies that dispatch result dtypes match the provided result_specs.

    The worker function returns a PRNG key array and result_specs name the
    same PRNG key dtype; the dispatched result must still be a PRNG key.

    Args:
      key_impl: The PRNG implementation name.
    """
    arr = _get_mock_dispatcher_array()
    key = jax.random.key(42, impl=key_impl)

    # Worker wraps physical data into PRNG keys and returns them.
    fn = mock.MagicMock(return_value=[key])

    dispatcher = self._make_mocked_dispatcher()

    # result_specs use PRNG key dtype (matching worker output).
    # Scalar keys need replicated sharding.
    key_sharding = jax.sharding.NamedSharding(
        arr.sharding.mesh, jax.sharding.PartitionSpec()
    )
    prng_specs = [
        jax.ShapeDtypeStruct(
            shape=key.shape,
            dtype=key.dtype,
            sharding=key_sharding,
        )
    ]

    result = dispatcher.dispatch(
        fn, input_arrays=arr, result_specs=prng_specs
    )

    # Result dtype must match specs — PRNG key dtype.
    self.assertTrue(
        jax.dtypes.issubdtype(result[0].dtype, jax.dtypes.prng_key),
        'Dispatch result should be a PRNGKeyArray when result_specs'
        f' use PRNG key dtype. Got dtype={result[0].dtype}.',
    )

  @parameterized.parameters('threefry2x32', 'rbg')
  def test_dispatch_mixed_tree_dtypes_match_result_specs(self, key_impl):
    """Tests the common case: a mixed tree of regular arrays and PRNG keys.

    The worker returns a list holding both a regular array and a PRNG key
    array. All result dtypes must match their corresponding result_specs.

    Args:
      key_impl: The PRNG implementation name.
    """
    arr = _get_mock_dispatcher_array()
    key = jax.random.key(42, impl=key_impl)

    # Worker returns wrapped PRNG key arrays (the desired behavior).
    fn = mock.MagicMock(return_value=[arr, key])

    dispatcher = self._make_mocked_dispatcher()

    # Mixed result_specs: regular array + PRNG key (PRNG key dtype).
    # Scalar keys need replicated sharding.
    key_sharding = jax.sharding.NamedSharding(
        arr.sharding.mesh, jax.sharding.PartitionSpec()
    )
    result_specs = [
        jax.ShapeDtypeStruct(
            shape=arr.shape,
            dtype=arr.dtype,
            sharding=arr.sharding,
        ),
        jax.ShapeDtypeStruct(
            shape=key.shape,
            dtype=key.dtype,
            sharding=key_sharding,
        ),
    ]

    result = dispatcher.dispatch(
        fn, input_arrays=arr, result_specs=result_specs
    )

    # All result dtypes must match their corresponding specs.
    self.assertEqual(
        result[0].dtype,
        result_specs[0].dtype,
        'Regular array dtype mismatch.',
    )
    # PRNG key element should be a PRNGKeyArray.
    self.assertTrue(
        jax.dtypes.issubdtype(result[1].dtype, jax.dtypes.prng_key),
        'PRNG key result should be a PRNGKeyArray when result_specs'
        f' use PRNG key dtype. Got dtype={result[1].dtype}.',
    )
def _prng_key_dtype_and_trailing_ndim(impl: str) -> tuple[Any, int]:
  """Returns the key dtype for `impl` and the rank of one key's physical data."""
  prng_key_dtype = jax.random.key(0, impl=impl).dtype
  key_data_shape = jax.eval_shape(
      jax.random.key_data,
      jax.ShapeDtypeStruct(shape=(), dtype=prng_key_dtype),
  ).shape
  return prng_key_dtype, len(key_data_shape)


async def _get_abstract_arrays(
    args: Sequence[types.RestoreArgs],
    shardings: Sequence[jax.sharding.Sharding],
    array_metadata_store: array_metadata_store_lib.Store | None = None,
    infos: Sequence[types.ParamInfo] | None = None,
) -> Sequence[jax.ShapeDtypeStruct]:
  """Returns result specs for dispatchers.

  Builds one ShapeDtypeStruct per (arg, sharding) pair. By default the spec
  copies `arg.global_shape` and `arg.dtype`. When `array_metadata_store` and
  `infos` are given and the stored metadata for a parameter carries a
  RANDOM_KEY_IMPL entry, the spec instead uses that impl's PRNG key dtype and
  the logical (key) shape: if `arg.dtype` is not already a PRNG key dtype,
  the trailing key-data dimensions are stripped from `arg.global_shape`.

  Args:
    args: ArrayRestoreArgs for each parameter; `global_shape` and `dtype`
      must be set.
    shardings: Shardings for each parameter, parallel to `args`.
    array_metadata_store: Store to read PRNG key impl metadata from. Read
      once, using `infos[0].parent_dir` as the checkpoint directory.
    infos: ParamInfo for each parameter, parallel to `args`. Only consulted
      when `array_metadata_store` is provided.

  Returns:
    Sequence of ShapeDtypeStruct result specs for the dispatcher.

  Raises:
    ValueError: If a sharding is None, or if metadata was found and `infos`
      does not have the same length as `args`.
  """
  metadata_by_name: dict[str, Any] = {}
  if array_metadata_store is not None and infos:
    array_metadatas = await array_metadata_store.read(
        checkpoint_dir=infos[0].parent_dir,
    )
    if array_metadatas:
      if isinstance(array_metadatas, dict):
        # A dict result holds several metadata lists; only the first is used.
        target_list = next(iter(array_metadatas.values()))
      else:
        target_list = array_metadatas
      metadata_by_name = {meta.param_name: meta for meta in target_list}
    if metadata_by_name and len(infos) != len(args):
      # infos[i] is paired with args[i] below; fail with a clear message
      # rather than an IndexError (or a silent misalignment) mid-loop.
      raise ValueError(
          f'Expected one ParamInfo per restore arg, got {len(infos)} infos'
          f' for {len(args)} args.'
      )

  # PRNG impl name -> (key dtype, rank of key data). Cached so the key
  # construction and shape evaluation run once per impl, not per parameter.
  key_layouts: dict[str, tuple[Any, int]] = {}

  abstract_arrays: list[jax.ShapeDtypeStruct] = []
  for i, (arg, sharding) in enumerate(zip(args, shardings)):
    assert isinstance(arg, ArrayRestoreArgs)
    assert arg.global_shape is not None
    assert arg.dtype is not None
    if sharding is None:
      raise ValueError('Sharding of jax.Array cannot be None.')

    shape = arg.global_shape
    dtype = arg.dtype

    impl = None
    if metadata_by_name:
      meta = metadata_by_name.get(infos[i].name)
      ext_metadata = meta.ext_metadata if meta is not None else None
      if ext_metadata and isinstance(ext_metadata, dict):
        impl = ext_metadata.get(array_metadata_lib.RANDOM_KEY_IMPL)

    if impl is not None:
      if impl not in key_layouts:
        key_layouts[impl] = _prng_key_dtype_and_trailing_ndim(impl)
      prng_key_dtype, key_trailing_ndim = key_layouts[impl]

      # A PRNG key dtype on the arg means arg.global_shape is already the
      # logical shape. Otherwise the shape is physical and the trailing
      # key-data dims must be dropped to get the logical shape.
      if (
          not jax.dtypes.issubdtype(dtype, jax.dtypes.prng_key)
          and key_trailing_ndim
      ):
        shape = shape[:-key_trailing_ndim]
      dtype = prng_key_dtype

      # Fix rank mismatch between logical shape and physical sharding: drop
      # spec entries beyond the logical rank. The sharding is rebuilt only
      # when needed, and memory_kind is carried over so it is not lost.
      if isinstance(sharding, jax.sharding.NamedSharding) and len(
          sharding.spec
      ) > len(shape):
        sharding = jax.sharding.NamedSharding(
            sharding.mesh,
            jax.sharding.PartitionSpec(*sharding.spec[: len(shape)]),
            memory_kind=sharding.memory_kind,
        )

      logging.vlog(
          1,
          '_get_abstract_arrays: PRNG key parameter %s: impl=%s,'
          ' logical shape=%s, dtype=%s, sharding=%s',
          infos[i].name,
          impl,
          shape,
          dtype,
          sharding,
      )

    abstract_arrays.append(
        jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=sharding)
    )
  return abstract_arrays
Array.sharding: {v.sharding}' ) + logging.vlog( + 1, + 'serialize: param %s, dtype=%s, shape=%s, sharding=%s', + info.name, + v.dtype, + v.shape, + getattr(v, 'sharding', None), + ) if jax.dtypes.issubdtype(v.dtype, jax.dtypes.prng_key): # a JAX random key - arrays.append(jax.random.key_data(v)) + logging.vlog( + 1, + 'serialize: PRNG key param %s, logical shape=%s, sharding=%s', + info.name, + v.shape, + v.sharding, + ) + key_data = jax.random.key_data(v) + logging.vlog( + 1, + 'serialize: PRNG key param %s, physical shape=%s, sharding=%s', + info.name, + key_data.shape, + key_data.sharding, + ) + arrays.append(key_data) self._ext_metadata[info.name] = { array_metadata_lib.RANDOM_KEY_IMPL: str(jax.random.key_impl(v)) } @@ -1261,9 +1375,12 @@ async def deserialize( args = await self._maybe_read_metadata_and_update_restore_args( infos, args ) + result_specs = await _get_abstract_arrays( + args, shardings, self._array_metadata_store, infos + ) ret = self._dispatcher.dispatch( _sync_deserialize_arrays, - result_specs=_get_abstract_arrays(args, shardings), + result_specs=result_specs, func_kwargs={ 'infos': infos, 'args': args, @@ -1602,7 +1719,12 @@ async def deserialize( ret = self._dispatcher.dispatch( _single_replica_deserialize_on_worker, input_arrays=dummy_input_array, - result_specs=_get_abstract_arrays(args, single_replica_shardings), + result_specs=await _get_abstract_arrays( + args, + single_replica_shardings, + self._array_metadata_store, + infos, + ), func_kwargs={ 'infos': infos, 'args': args,