From 4b6b6ace56c8778c7391858cfd6a915e755fe516 Mon Sep 17 00:00:00 2001 From: Sujeeth Jinesh Date: Mon, 4 May 2026 22:03:46 -0700 Subject: [PATCH] No public description PiperOrigin-RevId: 910429402 --- .../_src/multihost/colocated_transport.py | 124 ++------------ .../multihost/colocated_transport_test.py | 57 +++---- .../colocated_controller.py | 105 ++++++++++-- .../colocated_controller_test.py | 157 ++++++++++++++++-- .../colocated_utils.py | 66 +++++++- .../colocated_utils_test.py | 67 ++++++++ .../initialization.py | 52 ++++-- .../initialization_test.py | 60 +++++-- .../replicator_checkpoint_manager.py | 5 +- .../sidecar_worker_checkpoint_manager.py | 14 +- .../sidecar_worker_checkpoint_manager_test.py | 43 ++++- 11 files changed, 536 insertions(+), 214 deletions(-) diff --git a/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py b/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py index 5c486d5b4..47efea9d5 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py @@ -19,7 +19,7 @@ import functools import re import types -from typing import Any, cast, Sequence +from typing import Any, Sequence, cast from absl import logging import jax @@ -32,17 +32,10 @@ PyTree = Any -_PATHWAYS_SERIALIZATION_PATCH_INSTALLED = False +_PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED = False _PJRT_IFRT_DEVICE_ID_RE = re.compile(r'PjRtIFRTDeviceId=(\d+)') -def _to_serializable_cpu_device(device: jax.Device) -> jax.Device: - """Normalizes a device to the CPU device used by colocated Python.""" - if device.platform == 'cpu': - return device - return cp.colocated_cpu_devices((device,))[0] - - def _device_platform(device: jax.Device) -> str: platform = getattr(device, 'platform', None) if platform is not None: @@ -108,19 +101,11 @@ def _all_backend_devices() -> tuple[jax.Device, ...]: def _get_cpu_device_map() -> Mapping[int, jax.Device]: """Builds a worker-side CPU lookup for JAX colocated deserialization. - JAX calls this while unreducing every colocated `Mesh`, `DeviceList`, or - `SingleDeviceSharding`. Backend device topology is stable for this process, - so the map is cached to avoid repeated backend scans. - - The controller serializes colocated CPU objects by controller-visible - `device.id`, which is a backend-global IFRT id in the Pathways runtime. The - worker-side remote-Python CPU backend can expose different local `device.id` - values, so we first register the backend-global IFRT id parsed from repr and - then fall back to the local `device.id` without overwriting the global entry. - - When the IFRT id and local `device.id` namespaces do not collide, this keeps - both lookups working. If the namespaces collide for different devices, the - map is ambiguous and this function fails instead of returning the wrong CPU. + JAX colocated Python serializes `Mesh`, `DeviceList`, and + `SingleDeviceSharding` by integer CPU device id. In Pathways, controller-side + CPU ids can be backend-global IFRT ids while worker-side Python can expose + local CPU ids. Worker CPU reprs include the backend-global + `PjRtIFRTDeviceId`, so register both namespaces when they are unambiguous. """ cpu_device_map: dict[int, jax.Device] = {} backend_devices = _all_backend_devices() @@ -154,98 +139,19 @@ def _get_cpu_device_map() -> Mapping[int, jax.Device]: return types.MappingProxyType(cpu_device_map) -def _normalize_mesh_to_colocated_cpu( - mesh: jax.sharding.Mesh, -) -> jax.sharding.Mesh: - devices = tuple(mesh.devices.flat) - if all(_device_platform(device) == 'cpu' for device in devices): - return mesh - cpu_devices = np.vectorize( - _to_serializable_cpu_device, otypes=[object] - )(mesh.devices) - return jax.sharding.Mesh( - cpu_devices, mesh.axis_names, axis_types=mesh.axis_types - ) - - -def _normalize_device_list_to_colocated_cpu( - device_list: cp_serialization.DeviceList, -) -> cp_serialization.DeviceList: - if all(_device_platform(device) == 'cpu' for device in device_list): - return device_list - return cp_serialization.DeviceList( - tuple(_to_serializable_cpu_device(device) for device in device_list) - ) - - -def _normalize_single_device_sharding_to_colocated_cpu( - sharding: jax.sharding.SingleDeviceSharding, -) -> jax.sharding.SingleDeviceSharding: - device = next(iter(sharding.device_set)) - if _device_platform(device) == 'cpu': - return sharding - return jax.sharding.SingleDeviceSharding( - _to_serializable_cpu_device(device), memory_kind=sharding.memory_kind - ) - - -def install_pathways_colocated_serialization_patch() -> None: - """Installs a Pathways-aware colocated-python serialization patch. +def install_pathways_colocated_cpu_device_lookup_patch() -> None: + """Installs a narrow Pathways CPU lookup patch for colocated Python. - The live Pathways failures are below Orbax checkpoint semantics. They happen - while JAX is pickling and unpickling callable specializations that contain - mesh-backed shardings. - - The patch is intentionally narrow: - - 1. Keep JAX's existing serialized representation based on integer CPU ids - 2. Normalize any non-CPU mesh/device-list/sharding to colocated CPU devices - before it reaches JAX's reducers - 3. Teach worker-side CPU lookup to recognize backend-global PjRt-IFRT ids, - which are what controller-side CPU `device.id` values correspond to in the - Pathways remote-Python runtime - - This keeps Orbax close to upstream JAX semantics while fixing the exact - controller/proxy/worker identity mismatch seen in Pathways logs. - - The important constraint is that we are not changing the checkpoint contract - or inventing a second serialized format. We are only making JAX's existing - colocated serialization contract portable across the controller/worker CPU-id - namespace split used by Pathways single-controller. - - Tracked at b/503051746 to make proper changes to JAX. + This deliberately leaves JAX's colocated Python reducers unchanged. It only + swaps the CPU id lookup used while deserializing already-CPU shardings/device + lists so Pathways workers can resolve controller-global IFRT CPU ids. """ # pylint: disable=global-statement - global _PATHWAYS_SERIALIZATION_PATCH_INSTALLED - if _PATHWAYS_SERIALIZATION_PATCH_INSTALLED: + global _PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED + if _PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED: return - - original_reduce_mesh = cp_serialization._reduce_mesh # pylint: disable=protected-access - original_reduce_device_list = cp_serialization._reduce_device_list # pylint: disable=protected-access - original_reduce_single_device_sharding = cp_serialization._reduce_single_device_sharding # pylint: disable=protected-access - - def _orbax_reduce_mesh(mesh: jax.sharding.Mesh) -> Any: - return original_reduce_mesh(_normalize_mesh_to_colocated_cpu(mesh)) - - def _orbax_reduce_device_list( - device_list: cp_serialization.DeviceList, - ) -> Any: - return original_reduce_device_list( - _normalize_device_list_to_colocated_cpu(device_list) - ) - - def _orbax_reduce_single_device_sharding( - sharding: jax.sharding.SingleDeviceSharding, - ) -> Any: - return original_reduce_single_device_sharding( - _normalize_single_device_sharding_to_colocated_cpu(sharding) - ) - - cp_serialization._reduce_mesh = _orbax_reduce_mesh # pylint: disable=protected-access - cp_serialization._reduce_device_list = _orbax_reduce_device_list # pylint: disable=protected-access - cp_serialization._reduce_single_device_sharding = _orbax_reduce_single_device_sharding # pylint: disable=protected-access cp_serialization._get_cpu_device_map = _get_cpu_device_map # pylint: disable=protected-access - _PATHWAYS_SERIALIZATION_PATCH_INSTALLED = True + _PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED = True def unique_colocated_cpu_devices( diff --git a/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport_test.py b/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport_test.py index f7823e629..34bf78427 100644 --- a/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport_test.py +++ b/checkpoint/orbax/checkpoint/_src/multihost/colocated_transport_test.py @@ -268,46 +268,35 @@ def test_get_cpu_device_map_returns_immutable_mapping(self): with self.assertRaises(TypeError): device_map[1] = cpu # pytype: disable=unsupported-operands - def test_normalize_mesh_to_colocated_cpu_remaps_non_cpu_devices(self): - cpu0 = mock.Mock(platform='cpu') - cpu1 = mock.Mock(platform='cpu') - tpu0 = mock.Mock(platform='tpu') - tpu1 = mock.Mock(platform='tpu') - - class _FakeMesh: - devices = np.array([tpu0, tpu1], dtype=object) - axis_names = ('d',) - axis_types = None - - mesh = _FakeMesh() - - with mock.patch.object( - colocated_transport, - '_to_serializable_cpu_device', - side_effect=[cpu0, cpu1], - ): - cpu_mesh = colocated_transport._normalize_mesh_to_colocated_cpu( # pytype: disable=wrong-arg-types # pylint: disable=protected-access - mesh - ) - - self.assertEqual(cpu_mesh.axis_names, mesh.axis_names) - self.assertEqual(tuple(cpu_mesh.devices.flat), (cpu0, cpu1)) - - def test_install_pathways_colocated_serialization_patch_is_idempotent(self): - original_reduce_mesh = cp_serialization._reduce_mesh # pylint: disable=protected-access + def test_install_pathways_colocated_cpu_device_lookup_patch_is_idempotent( + self, + ): original_get_cpu_device_map = cp_serialization._get_cpu_device_map # pylint: disable=protected-access + original_installed = ( + colocated_transport._PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED # pylint: disable=protected-access + ) + self.addCleanup( + setattr, + cp_serialization, + '_get_cpu_device_map', + original_get_cpu_device_map, + ) + self.addCleanup( + setattr, + colocated_transport, + '_PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED', + original_installed, + ) + colocated_transport._PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED = False # pylint: disable=protected-access - colocated_transport.install_pathways_colocated_serialization_patch() - patched_reduce_mesh = cp_serialization._reduce_mesh # pylint: disable=protected-access + colocated_transport.install_pathways_colocated_cpu_device_lookup_patch() patched_get_cpu_device_map = cp_serialization._get_cpu_device_map # pylint: disable=protected-access - colocated_transport.install_pathways_colocated_serialization_patch() + colocated_transport.install_pathways_colocated_cpu_device_lookup_patch() - self.assertIsNot(original_reduce_mesh, patched_reduce_mesh) self.assertIs( - patched_reduce_mesh, - cp_serialization._reduce_mesh, # pylint: disable=protected-access + patched_get_cpu_device_map, + colocated_transport._get_cpu_device_map, # pylint: disable=protected-access ) - self.assertIsNot(original_get_cpu_device_map, patched_get_cpu_device_map) self.assertIs( patched_get_cpu_device_map, cp_serialization._get_cpu_device_map, # pylint: disable=protected-access diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller.py index 613f66ac0..a2b2f595d 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller.py @@ -24,6 +24,7 @@ import jax.numpy as jnp from orbax.checkpoint import args as args_lib from orbax.checkpoint import checkpoint_manager +from orbax.checkpoint._src.handlers import base_pytree_checkpoint_handler from orbax.checkpoint._src.handlers import handler_registration from orbax.checkpoint._src.metadata import sharding as sharding_metadata from orbax.checkpoint._src.multihost import colocated_transport @@ -92,7 +93,7 @@ def _step_arr_on_state_devices(step: int, state: PyTree) -> jax.Array: def _scalar_on_dummy( value: Any, dummy: jax.Array, *, dtype: Any ) -> jax.Array: - return jax.device_put(jnp.asarray(value, dtype=dtype), dummy.sharding) + return colocated_utils.make_scalar_on_like(value, dummy, dtype=dtype) def _single_mapping_child(tree: PyTree) -> Any | None: @@ -127,7 +128,7 @@ def __init__( ) self._local_directory = local_directory - colocated_transport.install_pathways_colocated_serialization_patch() + colocated_transport.install_pathways_colocated_cpu_device_lookup_patch() worker_manager_cls = ( sidecar_worker_checkpoint_manager.WorkerCheckpointManager ) @@ -238,10 +239,10 @@ def is_saving_in_progress(self) -> bool: """Returns whether any worker still has async save work in flight.""" result = self._worker_manager.is_saving_in_progress(self._dummy) jax.block_until_ready(result) - value = colocated_utils.require_unanimous_scalar_result( + values = colocated_utils.scalar_result_values( result, op_name='is_saving_in_progress' ) - return bool(value) + return any(bool(value) for value in values) def save( self, @@ -300,9 +301,6 @@ def restore( if resolved_step is None: raise FileNotFoundError(f'No steps found in {self.directory}.') - target_shardings = self._prepare_restore_target_shardings( - state_restore_args - ) step_arr = _scalar_on_dummy(resolved_step, self._dummy, dtype=jnp.int32) # Restore is metadata-driven on the worker side. Avoid sending a large # synthetic restore-spec tree across colocated Python just to describe the @@ -327,17 +325,16 @@ def restore( state = self._rebuild_restored_state( result, state_restore_args.item ) - if jax.tree.structure(state) != jax.tree.structure(target_shardings): + effective_restore_args = self._prepare_effective_restore_args( + state_restore_args + ) + if jax.tree.structure(state) != jax.tree.structure(effective_restore_args): raise ValueError( 'colocated restore produced a pytree structure that does not match ' 'restore_args.' ) - target_specs = jax.tree.map( - lambda leaf, sharding: jax.ShapeDtypeStruct( - leaf.shape, leaf.dtype, sharding=sharding - ), - state, - target_shardings, + target_specs = self._prepare_restore_target_specs( + state, effective_restore_args ) state = colocated_transport.to_final_specs(state, target_specs) return self._finalize_restore_result( @@ -351,6 +348,13 @@ def wait_until_finished(self) -> None: """Blocks until worker-side async save work finishes.""" result = self._worker_manager.wait_until_finished(self._dummy) jax.block_until_ready(result) + if self._persistent_checkpoint_manager is not None: + self._persistent_checkpoint_manager.wait_until_finished() + + def check_for_errors(self) -> None: + """Raises if colocated persistent save work has failed.""" + if self._persistent_checkpoint_manager is not None: + self._persistent_checkpoint_manager.check_for_errors() def close(self) -> None: """Closes worker-side and persistent managers.""" @@ -479,7 +483,7 @@ def _prepare_state_for_save(self, args: args_lib.Composite) -> PyTree: def _prepare_restore_target_shardings( self, - state_restore_args: args_lib.PyTreeRestore, + restore_args: PyTree, ) -> PyTree: """Resolves target shardings for restore.""" def _resolve_sharding(ra: Any) -> jax.sharding.Sharding: @@ -510,7 +514,76 @@ def _resolve_sharding(ra: Any) -> jax.sharding.Sharding: ) return sharding - return jax.tree.map(_resolve_sharding, state_restore_args.restore_args) + return jax.tree.map(_resolve_sharding, restore_args) + + def _prepare_effective_restore_args( + self, + state_restore_args: args_lib.PyTreeRestore, + ) -> PyTree: + """Fills restore args from the caller template like PyTree restore does.""" + restore_args = state_restore_args.restore_args + if restore_args is None: + raise ValueError('colocated restore requires explicit restore_args.') + if state_restore_args.item is None: + return restore_args + return base_pytree_checkpoint_handler._fill_missing_save_or_restore_args( # pylint: disable=protected-access + state_restore_args.item, + restore_args, + mode='restore', + ) + + def _prepare_restore_target_specs( + self, + restored_state: PyTree, + restore_args: PyTree, + ) -> PyTree: + """Builds final restore specs and rejects unsupported restore transforms.""" + target_shardings = self._prepare_restore_target_shardings(restore_args) + + def _target_spec( + leaf: jax.Array, + restore_arg: type_handlers.ArrayRestoreArgs, + sharding: jax.sharding.Sharding, + ) -> jax.ShapeDtypeStruct: + requested_restore_type = restore_arg.restore_type + if requested_restore_type not in (None, jax.Array): + raise NotImplementedError( + 'colocated restore only supports ArrayRestoreArgs restore_type ' + f'jax.Array, got {requested_restore_type}.' + ) + + requested_shape = restore_arg.global_shape + if requested_shape is None: + requested_shape = restore_arg.shape + if requested_shape is not None and tuple(requested_shape) != tuple( + leaf.shape + ): + raise NotImplementedError( + 'colocated restore does not support ArrayRestoreArgs shape ' + 'transforms. Worker-side restore inferred shape ' + f'{leaf.shape}, but restore_args requested shape ' + f'{tuple(requested_shape)}.' + ) + + requested_dtype = restore_arg.dtype + if requested_dtype is not None and jnp.dtype( + requested_dtype + ) != jnp.dtype(leaf.dtype): + raise NotImplementedError( + 'colocated restore does not support ArrayRestoreArgs dtype ' + 'transforms. Worker-side restore inferred dtype ' + f'{leaf.dtype}, but restore_args requested dtype ' + f'{jnp.dtype(requested_dtype)}.' + ) + + return jax.ShapeDtypeStruct(leaf.shape, leaf.dtype, sharding=sharding) + + return jax.tree.map( + _target_spec, + restored_state, + restore_args, + target_shardings, + ) def _rebuild_restored_state( self, diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller_test.py index 7b68f081e..ff41120ef 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_controller_test.py @@ -33,7 +33,7 @@ def _steps_array(steps: list[int]) -> np.ndarray: - padded = steps + [0] * ( + padded = steps + [controller_lib.colocated_utils.NO_STEP_SENTINEL] * ( controller_lib.colocated_utils.MAX_TRACKED_STEPS - len(steps) ) return np.asarray(padded, dtype=np.int32) @@ -67,7 +67,7 @@ def test_init_preserves_global_mesh_device_order(self): return_value=worker_manager, ) as mock_worker_manager, mock.patch.object( controller_lib.colocated_transport, - 'install_pathways_colocated_serialization_patch', + 'install_pathways_colocated_cpu_device_lookup_patch', ) as mock_install_patch, mock.patch.object( controller_lib.colocated_transport, 'unique_colocated_cpu_devices', @@ -112,7 +112,7 @@ def test_init_rejects_empty_colocated_cpu_devices(self): return_value=mock.Mock(), ), mock.patch.object( controller_lib.colocated_transport, - 'install_pathways_colocated_serialization_patch', + 'install_pathways_colocated_cpu_device_lookup_patch', ), mock.patch.object( controller_lib.colocated_transport, 'unique_colocated_cpu_devices', @@ -223,7 +223,7 @@ def test_prepare_restore_target_shardings_uses_restore_args(self): restore_args = {'x': type_handlers.ArrayRestoreArgs(sharding=sharding)} target_shardings = controller._prepare_restore_target_shardings( - args_lib.PyTreeRestore(item=None, restore_args=restore_args), + restore_args, ) self.assertEqual(target_shardings['x'], sharding) @@ -294,10 +294,12 @@ def test_latest_step_handles_max_steps_and_sentinel(self): controller._worker_manager.all_steps.return_value = object() # Worker 1 has steps 1..MAX_TRACKED_STEPS - # Worker 2 has steps 1..MAX_TRACKED_STEPS-1 and padded with 0 + # Worker 2 has steps 1..MAX_TRACKED_STEPS-1 and one sentinel. max_steps = controller_lib.colocated_utils.MAX_TRACKED_STEPS w1_steps = list(range(1, max_steps + 1)) - w2_steps = list(range(1, max_steps)) + [0] + w2_steps = list(range(1, max_steps)) + [ + controller_lib.colocated_utils.NO_STEP_SENTINEL + ] with mock.patch.object( controller_lib.colocated_utils, @@ -370,16 +372,16 @@ def test_is_saving_in_progress_returns_worker_vote(self): with mock.patch.object( controller_lib.colocated_utils, - 'require_unanimous_scalar_result', - return_value=False, - ) as mock_unanimous: + 'scalar_result_values', + return_value=[False, True], + ) as mock_values: in_progress = controller.is_saving_in_progress() - self.assertFalse(in_progress) + self.assertTrue(in_progress) controller._worker_manager.is_saving_in_progress.assert_called_once_with( controller._dummy ) - mock_unanimous.assert_called_once() + mock_values.assert_called_once() def test_wait_until_finished_blocks_on_worker_result(self): controller, _ = self._make_controller_for_restore() @@ -397,6 +399,31 @@ def test_wait_until_finished_blocks_on_worker_result(self): ) mock_block.assert_called_once_with(worker_result) + def test_wait_until_finished_waits_for_persistent_manager(self): + controller, _ = self._make_controller_for_restore() + worker_result = object() + controller._worker_manager = mock.Mock() + controller._worker_manager.wait_until_finished.return_value = worker_result + controller._persistent_checkpoint_manager = mock.Mock() + + controller.wait_until_finished() + + wait_until_finished = ( + controller._persistent_checkpoint_manager.wait_until_finished + ) + wait_until_finished.assert_called_once_with() + + def test_check_for_errors_checks_persistent_manager(self): + controller, _ = self._make_controller_for_restore() + controller._persistent_checkpoint_manager = mock.Mock() + + controller.check_for_errors() + + check_for_errors = ( + controller._persistent_checkpoint_manager.check_for_errors + ) + check_for_errors.assert_called_once_with() + def test_close_shuts_down_persistent_and_worker_managers(self): controller, _ = self._make_controller_for_restore() worker_result = object() @@ -431,7 +458,9 @@ def test_restore_uses_worker_restore_infer_and_reshards(self): 'weights': jax.ShapeDtypeStruct((2,), jnp.float32, sharding=sharding) } restore_args = { - 'weights': type_handlers.ArrayRestoreArgs(sharding=sharding) + 'weights': type_handlers.ArrayRestoreArgs( + sharding=sharding, global_shape=(2,), dtype=jnp.float32 + ) } controller._worker_manager = mock.Mock() controller._worker_manager.restore_infer.return_value = restored_cpu_state @@ -455,6 +484,110 @@ def test_restore_uses_worker_restore_infer_and_reshards(self): np.asarray(result['weights']), np.arange(2, dtype=np.float32) ) + def test_restore_rejects_shape_transform(self): + controller, sharding = self._make_controller_for_restore() + mesh = sharding.mesh + cpu_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + restored_cpu_state = { + 'weights': jax.device_put( + jnp.arange(2, dtype=jnp.float32), + cpu_sharding, + ) + } + template_state = { + 'weights': jax.ShapeDtypeStruct((3,), jnp.float32, sharding=sharding) + } + restore_args = { + 'weights': type_handlers.ArrayRestoreArgs( + sharding=sharding, global_shape=(3,), strict=False + ) + } + controller._worker_manager = mock.Mock() + controller._worker_manager.restore_infer.return_value = restored_cpu_state + + with self.assertRaisesRegex(NotImplementedError, 'shape transforms'): + controller.restore( + 7, + args_lib.Composite( + state=args_lib.PyTreeRestore( + item=template_state, + restore_args=restore_args, + ) + ), + default_item_mode=True, + ) + + def test_restore_rejects_dtype_transform(self): + controller, sharding = self._make_controller_for_restore() + mesh = sharding.mesh + cpu_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + restored_cpu_state = { + 'weights': jax.device_put( + jnp.arange(2, dtype=jnp.float32), + cpu_sharding, + ) + } + template_state = { + 'weights': jax.ShapeDtypeStruct((2,), jnp.int32, sharding=sharding) + } + restore_args = { + 'weights': type_handlers.ArrayRestoreArgs(sharding=sharding) + } + controller._worker_manager = mock.Mock() + controller._worker_manager.restore_infer.return_value = restored_cpu_state + + with self.assertRaisesRegex(NotImplementedError, 'dtype transforms'): + controller.restore( + 7, + args_lib.Composite( + state=args_lib.PyTreeRestore( + item=template_state, + restore_args=restore_args, + ) + ), + default_item_mode=True, + ) + + def test_restore_rejects_non_array_restore_type(self): + controller, sharding = self._make_controller_for_restore() + mesh = sharding.mesh + cpu_sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec() + ) + restored_cpu_state = { + 'weights': jax.device_put( + jnp.arange(2, dtype=jnp.float32), + cpu_sharding, + ) + } + template_state = { + 'weights': jax.ShapeDtypeStruct((2,), jnp.float32, sharding=sharding) + } + restore_args = { + 'weights': type_handlers.ArrayRestoreArgs( + restore_type=np.ndarray, + sharding=sharding, + ) + } + controller._worker_manager = mock.Mock() + controller._worker_manager.restore_infer.return_value = restored_cpu_state + + with self.assertRaisesRegex(NotImplementedError, 'restore_type'): + controller.restore( + 7, + args_lib.Composite( + state=args_lib.PyTreeRestore( + item=template_state, + restore_args=restore_args, + ) + ), + default_item_mode=True, + ) + def test_restore_none_step_resolves_single_explicit_step(self): controller, sharding = self._make_controller_for_restore() mesh = sharding.mesh diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils.py index 05bf4c122..883fca4f4 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils.py @@ -40,11 +40,7 @@ def compute_distributed_to_device_ids( devices: Sequence[jax.Device], ) -> list[list[int]]: """Returns per-worker device ids in slice-major order.""" - worker_groups = pathways.group_devices_by_worker(devices) - sorted_worker_groups = sorted( - worker_groups.items(), - key=lambda item: _worker_key_sort_key(item[0]), - ) + sorted_worker_groups = _sorted_worker_groups(devices) distributed_to_device_ids = [ sorted(d.id for d in worker_devices) for _, worker_devices in sorted_worker_groups @@ -58,6 +54,50 @@ def compute_distributed_to_device_ids( return distributed_to_device_ids +def colocated_cpu_devices_by_worker( + devices: Sequence[jax.Device], +) -> tuple[jax.Device, ...]: + """Returns one colocated CPU device per Pathways worker. + + The ordering matches `compute_distributed_to_device_ids`, so controller-routed + per-worker metadata can use the returned device position as the MTC node rank. + + Args: + devices: A sequence of JAX devices representing all available TPU devices. + """ + sorted_worker_groups = _sorted_worker_groups(devices) + representative_devices = tuple( + min(worker_devices, key=lambda d: d.id) + for _, worker_devices in sorted_worker_groups + ) + cpu_devices = colocated_transport.unique_colocated_cpu_devices( + representative_devices + ) + if len(cpu_devices) != len(representative_devices): + raise ValueError( + 'Expected one unique colocated CPU device per Pathways worker, got ' + f'{len(cpu_devices)} CPU devices for {len(representative_devices)} ' + 'workers.' + ) + logging.vlog( + 1, + 'Resolved %d colocated CPU devices in Pathways worker order: %s', + len(cpu_devices), + cpu_devices, + ) + return cpu_devices + + +def _sorted_worker_groups( + devices: Sequence[jax.Device], +) -> list[tuple[tuple[int, ...], list[jax.Device]]]: + worker_groups = pathways.group_devices_by_worker(devices) + return sorted( + worker_groups.items(), + key=lambda item: _worker_key_sort_key(item[0]), + ) + + def _worker_key_sort_key(worker_key: tuple[int, ...]) -> tuple[int, ...]: """Sorts `(task, slice)` worker keys in slice-major order.""" if len(worker_key) == 2: @@ -172,6 +212,19 @@ def require_unanimous_scalar_result( return values[0] +def require_single_local_scalar_result( + result: jax.Array, *, op_name: str +) -> Any: + """Returns the single local scalar shard value or raises.""" + values = scalar_result_values(result, op_name=op_name) + if len(values) != 1: + raise ValueError( + f'{op_name}: expected exactly one local scalar shard, got ' + f'{len(values)} values: {values[:8]}.' + ) + return values[0] + + def scalar_result_values(result: jax.Array, *, op_name: str) -> list[Any]: """Returns scalar values from workers.""" values = [] @@ -203,6 +256,3 @@ def array_result_values(result: jax.Array, *, op_name: str) -> list[np.ndarray]: if value.ndim == 0: raise ValueError(f'{op_name}: expected array shard value, got scalar.') return values - - - diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils_test.py index a1e0025af..bbc9b2013 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/colocated_utils_test.py @@ -133,6 +133,28 @@ def test_require_unanimous_scalar_result_raises_on_disagreement(self): with self.assertRaisesRegex(RuntimeError, 'workers disagreed'): colocated_utils.require_unanimous_scalar_result(result, op_name='test_op') # pytype: disable=wrong-arg-types + def test_require_single_local_scalar_result_success(self): + result = _FakeResult( + addressable_shards=[_FakeShard(np.asarray(4, dtype=np.int32))] + ) + + value = colocated_utils.require_single_local_scalar_result( # pytype: disable=wrong-arg-types + result, op_name='test_op' + ) + + self.assertEqual(value, 4) + + def test_require_single_local_scalar_result_raises_on_multiple_shards(self): + result = _FakeResult( + addressable_shards=[ + _FakeShard(np.asarray(4, dtype=np.int32)), + _FakeShard(np.asarray(5, dtype=np.int32)), + ] + ) + + with self.assertRaisesRegex(ValueError, 'exactly one local scalar shard'): + colocated_utils.require_single_local_scalar_result(result, op_name='test_op') # pytype: disable=wrong-arg-types + def test_assert_arrays_on_platform(self): arr = self._replicated_array(jnp.array([1, 2], dtype=jnp.int32)) @@ -197,6 +219,51 @@ def test_compute_distributed_to_device_ids_sorted_by_worker_key(self): self.assertEqual(distributed, [[0, 1], [2], [72], [74]]) + def test_colocated_cpu_devices_by_worker_matches_worker_order(self): + devices = [ + _FakeDevice(id=74, virtual_task_index=1, slice_index=1), + _FakeDevice(id=72, virtual_task_index=0, slice_index=1), + _FakeDevice(id=2, virtual_task_index=1, slice_index=0), + _FakeDevice(id=0, virtual_task_index=0, slice_index=0), + _FakeDevice(id=1, virtual_task_index=0, slice_index=0), + ] + cpu_devices = tuple(_FakeDevice(id=i) for i in [100, 102, 172, 174]) + + with mock.patch.object( + colocated_utils.colocated_transport, + 'unique_colocated_cpu_devices', + return_value=cpu_devices, + ) as mock_unique: + result = colocated_utils.colocated_cpu_devices_by_worker( # pytype: disable=wrong-arg-types + devices + ) + + self.assertEqual(result, cpu_devices) + mock_unique.assert_called_once_with(( + _FakeDevice(id=0, virtual_task_index=0, slice_index=0), + _FakeDevice(id=2, virtual_task_index=1, slice_index=0), + _FakeDevice(id=72, virtual_task_index=0, slice_index=1), + _FakeDevice(id=74, virtual_task_index=1, slice_index=1), + )) + + def test_colocated_cpu_devices_by_worker_raises_on_duplicate_cpu(self): + devices = [ + _FakeDevice(id=0, virtual_task_index=0, slice_index=0), + _FakeDevice(id=2, virtual_task_index=1, slice_index=0), + ] + + with mock.patch.object( + colocated_utils.colocated_transport, + 'unique_colocated_cpu_devices', + return_value=(_FakeDevice(id=100),), + ): + with self.assertRaisesRegex( + ValueError, 'one unique colocated CPU device' + ): + colocated_utils.colocated_cpu_devices_by_worker( # pytype: disable=wrong-arg-types + devices + ) + if __name__ == '__main__': absltest.main() diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py index bcbb3bed7..98dab4fc6 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization.py @@ -30,6 +30,9 @@ from orbax.checkpoint._src.multihost import dispatchers from orbax.checkpoint._src.multihost import multihost from orbax.checkpoint._src.multihost import multislice +from orbax.checkpoint.experimental.emergency.multi_tier_checkpointing import ( + colocated_utils, +) _REPLICATOR_FILE = 'replicator.yaml' @@ -89,6 +92,25 @@ def _create_replicator_file( logging.info('Replicator file written and renamed successfully.') +def _node_rank_input_array( + colocated_cpu_devices: tuple[jax.Device, ...], +) -> jax.Array: + """Builds a per-worker rank array over colocated CPU devices.""" + node_ranks = np.arange(len(colocated_cpu_devices), dtype=np.int32) + mesh = jax.sharding.Mesh( + np.array(colocated_cpu_devices, dtype=object), ('worker',) + ) + sharding = jax.sharding.NamedSharding( + mesh, jax.sharding.PartitionSpec('worker') + ) + return jax.make_array_from_callback( + node_ranks.shape, + sharding, + lambda idx: node_ranks[idx], + dtype=jnp.int32, + ) + + def _initialize_mtc_colocated( local_checkpoint_directory: epath.Path, backup_interval_minutes: int, @@ -110,33 +132,39 @@ def _initialize_mtc_colocated( """ logging.info( 'Initializing colocated MTC setup: ' - f'process_count={jax.process_count()}, device_count={jax.device_count()}' + f'controller_device_count={jax.device_count()}' ) - colocated_transport.install_pathways_colocated_serialization_patch() + colocated_transport.install_pathways_colocated_cpu_device_lookup_patch() all_devices = jax.devices() - unique_cpu_devices = colocated_transport.unique_colocated_cpu_devices( + colocated_cpu_devices = colocated_utils.colocated_cpu_devices_by_worker( tuple(all_devices) ) + num_nodes = len(colocated_cpu_devices) + if num_nodes == 0: + raise ValueError('No colocated CPU devices found for MTC initialization.') logging.info( - f'Dispatching MTC initialization to {len(unique_cpu_devices)} ' - 'colocated CPU devices.' + f'Dispatching MTC initialization to {num_nodes} colocated CPU devices.' ) - dummy_in = dispatchers.get_dummy_input_array(unique_cpu_devices) + dummy_in = dispatchers.get_dummy_input_array(colocated_cpu_devices) + node_rank_in = _node_rank_input_array(colocated_cpu_devices) local_dir_str = str(local_checkpoint_directory) - def _setup(dummy_arg: jax.Array) -> jax.Array: + def _setup(dummy_arg: jax.Array, node_rank_arg: jax.Array) -> jax.Array: signaling_client.mark_pathways_colocated_runtime_active() - num_nodes = jax.process_count() if num_nodes % num_slices != 0: raise ValueError( 'num_nodes must be divisible by num_slices, got ' f'num_nodes={num_nodes}, num_slices={num_slices}.' ) nodes_per_slice = num_nodes // num_slices - node_rank = jax.process_index() + node_rank = int( + colocated_utils.require_single_local_scalar_result( + node_rank_arg, op_name='mtc_node_rank' + ) + ) if not 0 <= node_rank < num_nodes: raise ValueError( f'Invalid node_rank={node_rank} for num_nodes={num_nodes}.' @@ -181,10 +209,12 @@ def _setup(dummy_arg: jax.Array) -> jax.Array: ) wrapped_setup_fn = colocated_python.colocated_python(_setup) - wrapped_setup_fn = wrapped_setup_fn.specialize(out_specs_fn=lambda x: x) + wrapped_setup_fn = wrapped_setup_fn.specialize( + out_specs_fn=lambda dummy_arg, _node_rank_arg: dummy_arg + ) dispatch_start = time.time() - result = wrapped_setup_fn(dummy_in) + result = wrapped_setup_fn(dummy_in, node_rank_in) jax.block_until_ready(result) logging.info( 'All shards ready (%.1fs total). Setup complete on all hosts.', diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py index ddf46fdd7..c31ec71ce 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/initialization_test.py @@ -345,6 +345,13 @@ def test_initialize_multi_tier_checkpointing_colocated_success( # Verify standard multi-controller JAX init is bypassed mock_jax_distributed_initialize.assert_not_called() + def test_node_rank_input_array(self): + devices = tuple(jax.devices()[:1]) + + result = initialization._node_rank_input_array(devices) + + np.testing.assert_array_equal(np.asarray(result), np.asarray([0])) + @mock.patch.object(initialization.jax, "make_array_from_callback") @mock.patch.object(initialization.jax, "block_until_ready") @mock.patch.object(initialization, "_block_and_process_restore_dir") @@ -354,12 +361,12 @@ def test_initialize_multi_tier_checkpointing_colocated_success( @mock.patch.object(initialization.colocated_python, "colocated_python") @mock.patch.object( initialization.colocated_transport, - "install_pathways_colocated_serialization_patch", + "install_pathways_colocated_cpu_device_lookup_patch", ) - @mock.patch.object(initialization.jax, "devices") @mock.patch.object( - initialization.colocated_transport, "unique_colocated_cpu_devices" + initialization.colocated_utils, "colocated_cpu_devices_by_worker" ) + @mock.patch.object(initialization.jax, "devices") @mock.patch.object(initialization.jax, "device_count", return_value=8) @mock.patch.object(initialization.jax, "process_index", return_value=0) @mock.patch.object(initialization.jax, "process_count", return_value=1) @@ -368,8 +375,8 @@ def test_initialize_mtc_colocated_marks_sidecar_runtime( mock_process_count, mock_process_index, mock_device_count, - mock_unique_colocated_cpu_devices, mock_devices, + mock_colocated_cpu_devices_by_worker, mock_install_patch, mock_colocated_python, mock_get_dummy_input_array, @@ -385,11 +392,19 @@ def test_initialize_mtc_colocated_marks_sidecar_runtime( self.assertIsNotNone(mock_device_count) dummy_in = mock.Mock(shape=(), sharding="dummy-sharding") - mock_get_dummy_input_array.return_value = dummy_in - mock_devices.return_value = ["tpu0"] - mock_unique_colocated_cpu_devices.return_value = ( - mock.Mock(id=7, process_index=0), + rank_in = mock.Mock( + addressable_shards=[ + mock.Mock(data=np.asarray(0, dtype=np.int32)), + ] ) + mock_get_dummy_input_array.return_value = dummy_in + mock_device = mock.Mock() + mock_device.id = 7 + mock_device.virtual_task_index = 0 + mock_device.slice_index = 0 + mock_devices.return_value = [mock_device] + mock_cpu_device = mock.Mock(id=7, process_index=0) + mock_colocated_cpu_devices_by_worker.return_value = (mock_cpu_device,) mock_make_array_from_callback.return_value = np.asarray(True) def _wrap_setup(fn): @@ -402,9 +417,16 @@ def specialize(self, *, out_specs_fn): mock_colocated_python.side_effect = _wrap_setup - with mock.patch( - "orbax.checkpoint._src.futures.signaling_client.mark_pathways_colocated_runtime_active" - ) as mock_mark_sidecar_runtime: + with ( + mock.patch.object( + initialization, + "_node_rank_input_array", + return_value=rank_in, + ) as mock_node_rank_input, + mock.patch( + "orbax.checkpoint._src.futures.signaling_client.mark_pathways_colocated_runtime_active" + ) as mock_mark_sidecar_runtime, + ): initialization._initialize_mtc_colocated( local_checkpoint_directory=epath.Path("/tmp/mtc"), backup_interval_minutes=15, @@ -415,12 +437,24 @@ def specialize(self, *, out_specs_fn): ) mock_install_patch.assert_called_once_with() - mock_unique_colocated_cpu_devices.assert_called_once_with(("tpu0",)) + mock_colocated_cpu_devices_by_worker.assert_called_once_with((mock_device,)) + mock_get_dummy_input_array.assert_called_once_with((mock_cpu_device,)) + mock_node_rank_input.assert_called_once_with((mock_cpu_device,)) mock_mark_sidecar_runtime.assert_called_once_with() - mock_create_replicator_file.assert_called_once() + mock_create_replicator_file.assert_called_once_with( + epath.Path("/tmp/mtc"), + run_name="test-run", + num_nodes=1, + data_parallelism=1, + node_rank=0, + peer_ranks=[], + backup_interval_minutes=15, + ) mock_wait_for_replicator_file_to_disappear.assert_called_once() mock_block_and_process_restore_dir.assert_called_once() mock_block_until_ready.assert_called_once() + mock_process_count.assert_not_called() + mock_process_index.assert_not_called() @mock.patch.object( initialization, "_wait_for_replicator_file_to_disappear", autospec=True diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager.py index 709ec7567..87da5d84c 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/replicator_checkpoint_manager.py @@ -1011,9 +1011,10 @@ def wait_until_finished(self) -> None: return self._non_null_local_engine.wait_until_finished() def check_for_errors(self) -> None: - if self._local_engine is None: + if self._colocated_controller is not None: + self._colocated_controller.check_for_errors() return None - return self._local_engine.check_for_errors() + return self._non_null_local_engine.check_for_errors() def close(self) -> None: if self._colocated_controller is not None: diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager.py index c7aa3a2f2..1b1f521b9 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager.py @@ -46,7 +46,7 @@ def __init__( save_interval_steps: int, mesh_axis_types: tuple[jax.sharding.AxisType, ...] | None = None, ) -> None: - colocated_transport.install_pathways_colocated_serialization_patch() + colocated_transport.install_pathways_colocated_cpu_device_lookup_patch() signaling_client.mark_pathways_colocated_runtime_active() cpu_mesh = jax.sharding.Mesh( @@ -113,6 +113,7 @@ def restore_infer(self, step_array: jax.Array) -> PyTree: def latest_step(self, dummy_array: jax.Array) -> jax.Array: """Returns latest_step_or_sentinel as a scalar int32.""" + self._rcm.reload() step = self._rcm.latest_step() val = step if step is not None else colocated_utils.NO_STEP_SENTINEL return colocated_utils.make_scalar_on_like( @@ -121,16 +122,21 @@ def latest_step(self, dummy_array: jax.Array) -> jax.Array: def all_steps(self, dummy_array: jax.Array) -> jax.Array: """Returns a fixed-size array of up to colocated_utils.MAX_TRACKED_STEPS local checkpoint steps.""" - local_steps = sorted(self._rcm.all_steps()) + self._rcm.reload() + local_steps = sorted(self._rcm.all_steps(read=True)) # Keep only the latest MAX_TRACKED_STEPS steps if there are more. local_steps = local_steps[-colocated_utils.MAX_TRACKED_STEPS:] # Pad with NO_STEP_SENTINEL if fewer than MAX_TRACKED_STEPS. padded_steps = local_steps + [colocated_utils.NO_STEP_SENTINEL] * ( colocated_utils.MAX_TRACKED_STEPS - len(local_steps) ) - return jax.device_put( - jnp.asarray(padded_steps, dtype=jnp.int32), + def callback(idx): + return np.asarray(padded_steps, dtype=np.int32)[idx] + + return jax.make_array_from_callback( + (colocated_utils.MAX_TRACKED_STEPS,), dummy_array.sharding, + callback, ) def is_saving_in_progress(self, dummy_array: jax.Array) -> jax.Array: diff --git a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager_test.py b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager_test.py index 0a336cc98..280a4ce25 100644 --- a/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager_test.py +++ b/checkpoint/orbax/checkpoint/experimental/emergency/multi_tier_checkpointing/sidecar_worker_checkpoint_manager_test.py @@ -113,6 +113,7 @@ def test_latest_step_delegates_to_rcm_and_packs_result(self): result = self.manager.latest_step(dummy_array) self.assertIs(result, packed_result) + self.manager._rcm.reload.assert_called_once_with() self.manager._rcm.latest_step.assert_called_once_with() mock_make_scalar.assert_called_once_with(9, dummy_array, dtype=jnp.int32) @@ -129,7 +130,12 @@ def test_latest_step_returns_sentinel_when_no_step_exists(self): result = self.manager.latest_step(dummy_array) self.assertIs(result, packed_result) - mock_make_scalar.assert_called_once_with(0, dummy_array, dtype=jnp.int32) + self.manager._rcm.reload.assert_called_once_with() + mock_make_scalar.assert_called_once_with( + sidecar_lib.colocated_utils.NO_STEP_SENTINEL, + dummy_array, + dtype=jnp.int32, + ) def test_all_steps_returns_sorted_fixed_size_array(self): dummy_array = jnp.asarray(0, dtype=jnp.int32) @@ -141,11 +147,36 @@ def test_all_steps_returns_sorted_fixed_size_array(self): max_steps = sidecar_lib.colocated_utils.MAX_TRACKED_STEPS self.assertEqual(steps_array.shape, (max_steps,)) self.assertEqual(steps_array.dtype, np.int32) - expected = [1, 4, 5] + [0] * (max_steps - 3) + expected = [1, 4, 5] + [sidecar_lib.colocated_utils.NO_STEP_SENTINEL] * ( + max_steps - 3 + ) np.testing.assert_array_equal( steps_array, np.asarray(expected, dtype=np.int32) ) - self.manager._rcm.all_steps.assert_called_once_with() + self.manager._rcm.reload.assert_called_once_with() + self.manager._rcm.all_steps.assert_called_once_with(read=True) + + def test_all_steps_uses_callback_construction(self): + dummy_array = jnp.asarray(0, dtype=jnp.int32) + self.manager._rcm.all_steps.return_value = [1] + + with mock.patch.object( + sidecar_lib.jax, + 'make_array_from_callback', + wraps=sidecar_lib.jax.make_array_from_callback, + ) as mock_make_array, mock.patch.object( + sidecar_lib.jax, + 'device_put', + side_effect=AssertionError('device_put should not be used'), + ): + result = self.manager.all_steps(dummy_array) + + np.testing.assert_array_equal( + np.asarray(result)[:1], np.asarray([1], dtype=np.int32) + ) + self.manager._rcm.reload.assert_called_once_with() + self.manager._rcm.all_steps.assert_called_once_with(read=True) + mock_make_array.assert_called_once() def test_all_steps_limits_to_latest_max_steps(self): dummy_array = jnp.asarray(0, dtype=jnp.int32) @@ -170,10 +201,12 @@ def test_all_steps_returns_all_sentinels_when_no_steps_exist(self): steps_array = np.asarray(result) max_steps = sidecar_lib.colocated_utils.MAX_TRACKED_STEPS self.assertEqual(steps_array.shape, (max_steps,)) - expected = [0] * max_steps + expected = [sidecar_lib.colocated_utils.NO_STEP_SENTINEL] * max_steps np.testing.assert_array_equal( steps_array, np.asarray(expected, dtype=np.int32) ) + self.manager._rcm.reload.assert_called_once_with() + self.manager._rcm.all_steps.assert_called_once_with(read=True) def test_is_saving_in_progress_delegates_to_rcm_and_packs_result(self): dummy_array = jnp.asarray(True) @@ -212,7 +245,7 @@ def test_init_reconstructs_cpu_mesh_from_local_devices(self): with mock.patch.object( sidecar_lib.colocated_transport, - 'install_pathways_colocated_serialization_patch', + 'install_pathways_colocated_cpu_device_lookup_patch', ) as mock_install_patch, mock.patch.object( sidecar_lib.signaling_client, 'mark_pathways_colocated_runtime_active',