Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 15 additions & 109 deletions checkpoint/orbax/checkpoint/_src/multihost/colocated_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import functools
import re
import types
from typing import Any, cast, Sequence
from typing import Any, Sequence, cast

from absl import logging
import jax
Expand All @@ -32,17 +32,10 @@


# Alias used for pytree-typed values; it is plain `Any`, so it adds no checking.
PyTree = Any
# One-shot guards for the monkey patches in this module: the matching
# `install_pathways_colocated_*` function returns early when its flag is already
# True and sets the flag to True once the patch has been applied.
_PATHWAYS_SERIALIZATION_PATCH_INSTALLED = False
_PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED = False
# Captures the integer that follows `PjRtIFRTDeviceId=` in a string. Per the
# `_get_cpu_device_map` docstring, worker CPU device reprs include this
# backend-global id.
_PJRT_IFRT_DEVICE_ID_RE = re.compile(r'PjRtIFRTDeviceId=(\d+)')


def _to_serializable_cpu_device(device: jax.Device) -> jax.Device:
  """Returns the CPU device that colocated Python should use for `device`.

  Args:
    device: The device to normalize.

  Returns:
    `device` itself when its `platform` is `'cpu'`; otherwise the first entry
    returned by `cp.colocated_cpu_devices` for a one-element tuple holding
    `device`.
  """
  if device.platform != 'cpu':
    colocated = cp.colocated_cpu_devices((device,))
    return colocated[0]
  # Already a CPU device: hand it back untouched.
  return device


def _device_platform(device: jax.Device) -> str:
platform = getattr(device, 'platform', None)
if platform is not None:
Expand Down Expand Up @@ -108,19 +101,11 @@ def _all_backend_devices() -> tuple[jax.Device, ...]:
def _get_cpu_device_map() -> Mapping[int, jax.Device]:
"""Builds a worker-side CPU lookup for JAX colocated deserialization.

JAX calls this while unreducing every colocated `Mesh`, `DeviceList`, or
`SingleDeviceSharding`. Backend device topology is stable for this process,
so the map is cached to avoid repeated backend scans.

The controller serializes colocated CPU objects by controller-visible
`device.id`, which is a backend-global IFRT id in the Pathways runtime. The
worker-side remote-Python CPU backend can expose different local `device.id`
values, so we first register the backend-global IFRT id parsed from repr and
then fall back to the local `device.id` without overwriting the global entry.

When the IFRT id and local `device.id` namespaces do not collide, this keeps
both lookups working. If the namespaces collide for different devices, the
map is ambiguous and this function fails instead of returning the wrong CPU.
JAX colocated Python serializes `Mesh`, `DeviceList`, and
`SingleDeviceSharding` by integer CPU device id. In Pathways, controller-side
CPU ids can be backend-global IFRT ids while worker-side Python can expose
local CPU ids. Worker CPU reprs include the backend-global
`PjRtIFRTDeviceId`, so register both namespaces when they are unambiguous.
"""
cpu_device_map: dict[int, jax.Device] = {}
backend_devices = _all_backend_devices()
Expand Down Expand Up @@ -154,98 +139,19 @@ def _get_cpu_device_map() -> Mapping[int, jax.Device]:
return types.MappingProxyType(cpu_device_map)


def _normalize_mesh_to_colocated_cpu(
    mesh: jax.sharding.Mesh,
) -> jax.sharding.Mesh:
  """Returns `mesh` with every device replaced by its colocated CPU device.

  Args:
    mesh: The mesh to normalize.

  Returns:
    `mesh` itself when every device in it reports the `'cpu'` platform.
    Otherwise a new `jax.sharding.Mesh` with the same device-array shape,
    `axis_names` and `axis_types`, whose entries are the result of applying
    `_to_serializable_cpu_device` to each original device.
  """
  needs_remap = any(
      _device_platform(dev) != 'cpu' for dev in mesh.devices.flat
  )
  if not needs_remap:
    return mesh
  # `otypes=[object]` keeps the result an object array of device handles.
  to_cpu = np.vectorize(_to_serializable_cpu_device, otypes=[object])
  remapped = to_cpu(mesh.devices)
  return jax.sharding.Mesh(
      remapped, mesh.axis_names, axis_types=mesh.axis_types
  )


def _normalize_device_list_to_colocated_cpu(
    device_list: cp_serialization.DeviceList,
) -> cp_serialization.DeviceList:
  """Returns `device_list` with every device replaced by its colocated CPU.

  Args:
    device_list: The device list to normalize.

  Returns:
    `device_list` itself when every device in it reports the `'cpu'` platform.
    Otherwise a new `cp_serialization.DeviceList` built from a tuple of
    `_to_serializable_cpu_device(device)` for each device, in order.
  """
  for dev in device_list:
    if _device_platform(dev) != 'cpu':
      break
  else:
    # No non-CPU device was found, so there is nothing to remap.
    return device_list
  remapped = tuple(map(_to_serializable_cpu_device, device_list))
  return cp_serialization.DeviceList(remapped)


def _normalize_single_device_sharding_to_colocated_cpu(
    sharding: jax.sharding.SingleDeviceSharding,
) -> jax.sharding.SingleDeviceSharding:
  """Returns `sharding` re-targeted at the colocated CPU device.

  Args:
    sharding: The single-device sharding to normalize.

  Returns:
    `sharding` itself when the first device yielded by its `device_set`
    reports the `'cpu'` platform. Otherwise a new
    `jax.sharding.SingleDeviceSharding` on
    `_to_serializable_cpu_device(device)` with the same `memory_kind`.
  """
  target = next(iter(sharding.device_set))
  if _device_platform(target) != 'cpu':
    cpu_target = _to_serializable_cpu_device(target)
    return jax.sharding.SingleDeviceSharding(
        cpu_target, memory_kind=sharding.memory_kind
    )
  return sharding


def install_pathways_colocated_serialization_patch() -> None:
"""Installs a Pathways-aware colocated-python serialization patch.
def install_pathways_colocated_cpu_device_lookup_patch() -> None:
"""Installs a narrow Pathways CPU lookup patch for colocated Python.

The live Pathways failures are below Orbax checkpoint semantics. They happen
while JAX is pickling and unpickling callable specializations that contain
mesh-backed shardings.

The patch is intentionally narrow:

1. Keep JAX's existing serialized representation based on integer CPU ids
2. Normalize any non-CPU mesh/device-list/sharding to colocated CPU devices
before it reaches JAX's reducers
3. Teach worker-side CPU lookup to recognize backend-global PjRt-IFRT ids,
which are what controller-side CPU `device.id` values correspond to in the
Pathways remote-Python runtime

This keeps Orbax close to upstream JAX semantics while fixing the exact
controller/proxy/worker identity mismatch seen in Pathways logs.

The important constraint is that we are not changing the checkpoint contract
or inventing a second serialized format. We are only making JAX's existing
colocated serialization contract portable across the controller/worker CPU-id
namespace split used by Pathways single-controller.

Tracked at b/503051746 to make proper changes to JAX.
This deliberately leaves JAX's colocated Python reducers unchanged. It only
swaps the CPU id lookup used while deserializing already-CPU shardings/device
lists so Pathways workers can resolve controller-global IFRT CPU ids.
"""
# pylint: disable=global-statement
global _PATHWAYS_SERIALIZATION_PATCH_INSTALLED
if _PATHWAYS_SERIALIZATION_PATCH_INSTALLED:
global _PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED
if _PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED:
return

original_reduce_mesh = cp_serialization._reduce_mesh # pylint: disable=protected-access
original_reduce_device_list = cp_serialization._reduce_device_list # pylint: disable=protected-access
original_reduce_single_device_sharding = cp_serialization._reduce_single_device_sharding # pylint: disable=protected-access

def _orbax_reduce_mesh(mesh: jax.sharding.Mesh) -> Any:
return original_reduce_mesh(_normalize_mesh_to_colocated_cpu(mesh))

def _orbax_reduce_device_list(
device_list: cp_serialization.DeviceList,
) -> Any:
return original_reduce_device_list(
_normalize_device_list_to_colocated_cpu(device_list)
)

def _orbax_reduce_single_device_sharding(
sharding: jax.sharding.SingleDeviceSharding,
) -> Any:
return original_reduce_single_device_sharding(
_normalize_single_device_sharding_to_colocated_cpu(sharding)
)

cp_serialization._reduce_mesh = _orbax_reduce_mesh # pylint: disable=protected-access
cp_serialization._reduce_device_list = _orbax_reduce_device_list # pylint: disable=protected-access
cp_serialization._reduce_single_device_sharding = _orbax_reduce_single_device_sharding # pylint: disable=protected-access
cp_serialization._get_cpu_device_map = _get_cpu_device_map # pylint: disable=protected-access
_PATHWAYS_SERIALIZATION_PATCH_INSTALLED = True
_PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED = True


def unique_colocated_cpu_devices(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,46 +268,35 @@ def test_get_cpu_device_map_returns_immutable_mapping(self):
with self.assertRaises(TypeError):
device_map[1] = cpu # pytype: disable=unsupported-operands

def test_normalize_mesh_to_colocated_cpu_remaps_non_cpu_devices(self):
  """Rebuilds an all-TPU mesh on the substituted CPU devices."""
  cpu_a, cpu_b = mock.Mock(platform='cpu'), mock.Mock(platform='cpu')
  tpu_a, tpu_b = mock.Mock(platform='tpu'), mock.Mock(platform='tpu')

  class _FakeMesh:
    # Only the attributes the function under test reads.
    axis_types = None
    axis_names = ('d',)
    devices = np.array([tpu_a, tpu_b], dtype=object)

  fake_mesh = _FakeMesh()

  # Stub the per-device conversion so each TPU maps to a known CPU, in order.
  conversion_stub = mock.patch.object(
      colocated_transport,
      '_to_serializable_cpu_device',
      side_effect=[cpu_a, cpu_b],
  )
  with conversion_stub:
    result = colocated_transport._normalize_mesh_to_colocated_cpu(  # pytype: disable=wrong-arg-types  # pylint: disable=protected-access
        fake_mesh
    )

  self.assertEqual(result.axis_names, fake_mesh.axis_names)
  self.assertEqual(tuple(result.devices.flat), (cpu_a, cpu_b))

def test_install_pathways_colocated_serialization_patch_is_idempotent(self):
original_reduce_mesh = cp_serialization._reduce_mesh # pylint: disable=protected-access
def test_install_pathways_colocated_cpu_device_lookup_patch_is_idempotent(
self,
):
original_get_cpu_device_map = cp_serialization._get_cpu_device_map # pylint: disable=protected-access
original_installed = (
colocated_transport._PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED # pylint: disable=protected-access
)
self.addCleanup(
setattr,
cp_serialization,
'_get_cpu_device_map',
original_get_cpu_device_map,
)
self.addCleanup(
setattr,
colocated_transport,
'_PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED',
original_installed,
)
colocated_transport._PATHWAYS_CPU_DEVICE_LOOKUP_PATCH_INSTALLED = False # pylint: disable=protected-access

colocated_transport.install_pathways_colocated_serialization_patch()
patched_reduce_mesh = cp_serialization._reduce_mesh # pylint: disable=protected-access
colocated_transport.install_pathways_colocated_cpu_device_lookup_patch()
patched_get_cpu_device_map = cp_serialization._get_cpu_device_map # pylint: disable=protected-access
colocated_transport.install_pathways_colocated_serialization_patch()
colocated_transport.install_pathways_colocated_cpu_device_lookup_patch()

self.assertIsNot(original_reduce_mesh, patched_reduce_mesh)
self.assertIs(
patched_reduce_mesh,
cp_serialization._reduce_mesh, # pylint: disable=protected-access
patched_get_cpu_device_map,
colocated_transport._get_cpu_device_map, # pylint: disable=protected-access
)
self.assertIsNot(original_get_cpu_device_map, patched_get_cpu_device_map)
self.assertIs(
patched_get_cpu_device_map,
cp_serialization._get_cpu_device_map, # pylint: disable=protected-access
Expand Down
Loading
Loading