Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion checkpoint/orbax/checkpoint/_src/multihost/dispatchers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Provides dispatchers for running functions on multiple workers."""

import abc
from collections.abc import Sequence
from collections.abc import Mapping, Sequence
from typing import Any, Callable

from absl import logging
Expand Down Expand Up @@ -148,6 +148,10 @@ def dispatch(
"""
...

def split_by_slice(self, arrays: PyTree) -> Mapping[int, PyTree]:
  """Groups `arrays` by slice index.

  This implementation performs no partitioning: the entire input is placed
  under the single slice index 0.

  Args:
    arrays: PyTree of arrays to split.

  Returns:
    A mapping from slice index to the arrays belonging to that slice. Here it
    always has exactly one entry, `{0: arrays}`, and the value is the same
    object that was passed in (no copy is made).
  """
  per_slice = {0: arrays}
  return per_slice


class ColocatedPythonDispatcher(Dispatcher):
"""Dispatches functions using colocated Python."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from orbax.checkpoint._src.serialization import types
from orbax.checkpoint._src.serialization import worker_memory_utils
from orbax.checkpoint._src.tree import utils as tree_utils

import tensorstore as ts

Pytree: TypeAlias = Any
Expand Down Expand Up @@ -217,9 +218,15 @@ def _get_replica_slices(
)
for arr in arrays
]
start_d2h = time.perf_counter()
logging.info(
'[process=%s] Starting D2H numpy array conversion for %d arrays.',
multihost.process_index(),
len(arrays),
)
# D2H copy is performed automatically as part of dispatcher call, but
# we must set properties correctly to pass later consistency checks.
return [
res = [
dataclasses.replace(
rslices,
is_on_host=True,
Expand All @@ -234,6 +241,15 @@ def _get_replica_slices(
)
for rslices in rslices_per_array
]
d2h_duration = time.perf_counter() - start_d2h
logging.info(
'[process=%s] Completed D2H numpy array conversion for %d arrays in %.4f'
' seconds.',
multihost.process_index(),
len(arrays),
d2h_duration,
)
return res


def _worker_serialize_arrays(
Expand All @@ -251,6 +267,8 @@ def _worker_serialize_arrays(
ext_metadata: Dict[str, Any],
):
"""Worker function to serialize arrays."""
if replica_id is None and use_replica_parallel:
replica_id = multihost.process_index()
rslices_per_array = _get_replica_slices(
arrays,
replica_id,
Expand All @@ -270,6 +288,7 @@ def _worker_serialize_arrays(
enable_replica_parallel_separate_folder=enable_replica_parallel_separate_folder,
use_replica_parallel=use_replica_parallel,
ext_metadata=ext_metadata,
replica_id=replica_id,
)
)

Expand Down Expand Up @@ -486,30 +505,42 @@ def _serialize_batch(
batch_args: Sequence[types.SaveArgs],
batch_arrays: Sequence[jax.Array],
):
ret = dispatcher.dispatch(
_worker_serialize_arrays,
input_arrays=batch_arrays,
func_kwargs={
'infos': batch_infos,
'args': batch_args,
'replica_id': replica_id,
'use_replica_parallel': use_replica_parallel,
'min_slice_bytes_for_replica_parallel': (
min_slice_bytes_for_replica_parallel
),
'max_replicas_for_replica_parallel': (
max_replicas_for_replica_parallel
),
'primary_host': primary_host,
'metadata_key': metadata_key,
'array_metadata_store': array_metadata_store,
'enable_replica_parallel_separate_folder': (
enable_replica_parallel_separate_folder
),
'ext_metadata': ext_metadata,
},
)
jax.block_until_ready(ret)
if use_replica_parallel:
per_slice_arrays = dispatcher.split_by_slice(batch_arrays)
else:
per_slice_arrays = {0: batch_arrays}
rets = []
for slice_id, slice_arrays in per_slice_arrays.items():
if replica_id is None and use_replica_parallel:
slice_replica_id = slice_id
else:
slice_replica_id = replica_id
ret = dispatcher.dispatch(
_worker_serialize_arrays,
input_arrays=slice_arrays,
func_kwargs={
'infos': batch_infos,
'args': batch_args,
'replica_id': slice_replica_id,
'use_replica_parallel': use_replica_parallel,
'min_slice_bytes_for_replica_parallel': (
min_slice_bytes_for_replica_parallel
),
'max_replicas_for_replica_parallel': (
max_replicas_for_replica_parallel
),
'primary_host': primary_host,
'metadata_key': metadata_key,
'array_metadata_store': array_metadata_store,
'enable_replica_parallel_separate_folder': (
enable_replica_parallel_separate_folder
),
'ext_metadata': ext_metadata,
},
)
rets.append(ret)
for ret in rets:
jax.block_until_ready(ret)

# Enqueue D2H operation for prioritized values.
if prioritized:
Expand Down Expand Up @@ -572,6 +603,7 @@ async def _async_serialize_replica_slices(
enable_replica_parallel_separate_folder: bool,
use_replica_parallel: bool,
ext_metadata: Dict[str, Any],
replica_id: int | None = None,
) -> None:
"""This function contains the logic from ArrayHandler._background_serialize."""
write_coros = []
Expand All @@ -592,16 +624,20 @@ async def _async_serialize_replica_slices(
1,
)
await info.await_path_creation()
if replica_id is not None:
process_index = replica_id
else:
process_index = ocdbt_utils.get_process_index_for_subdir(
info.is_ocdbt_checkpoint
)
array_write_spec = ts_utils.build_array_write_spec(
info=info,
arg=arg,
global_shape=value.global_shape,
local_shape=value.local_shape,
dtype=value.dtype,
use_ocdbt=info.is_ocdbt_checkpoint,
process_index=ocdbt_utils.get_process_index_for_subdir(
info.is_ocdbt_checkpoint
),
process_index=process_index,
replica_separate_folder=replica_separate_folder,
metadata_key=metadata_key,
ext_metadata=ext_metadata.get(info.name),
Expand Down
27 changes: 24 additions & 3 deletions checkpoint/orbax/checkpoint/_src/serialization/replica_slices.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import dataclasses
import functools
import math
import time
from typing import Optional, Sequence

from absl import logging
Expand Down Expand Up @@ -262,6 +263,8 @@ def get_replica_slices(
Returns:
ReplicaSlices object.
"""
if replica_id is None and use_replica_parallel:
replica_id = multihost.process_index()
Result = tuple[list[ReplicaSlice], Shape]
shard0 = arr.addressable_shards[0]

Expand All @@ -276,6 +279,7 @@ def pick_single_replica() -> Result:
)
for shard in arr.addressable_shards
if shard.replica_id == target_replica_id
or (replica_id is not None and shard.replica_id == 0)
]
local_shape = shard0.data.shape
return rslices, local_shape
Expand Down Expand Up @@ -312,7 +316,10 @@ def maybe_pick_replica_parallel() -> Optional[Result]:
assert shard.data.shape == shard0.data.shape

# Parallelize saving across only `replica_count` replicas.
if shard.replica_id >= replica_count:
eff_replica_id = (
replica_id if replica_id is not None else shard.replica_id
)
if eff_replica_id >= replica_count:
continue

size = local_shape[axis]
Expand All @@ -321,7 +328,7 @@ def maybe_pick_replica_parallel() -> Optional[Result]:
assert slize.step is None
assert slize.stop is None or slize.stop == start + shard.data.shape[axis]

start_offset = shard.replica_id * size
start_offset = eff_replica_id * size
end_offset = start_offset + size
new_slice = slice(start + start_offset, start + end_offset)

Expand Down Expand Up @@ -467,13 +474,19 @@ def async_transfer_slice(
)
for arr in arrays
]
start_d2h_time = time.perf_counter()
logging.info(
'[process=%s] Starting D2H transfer for %d arrays.',
multihost.process_index(),
len(arrays),
)
# Kick off transfers for all replica slices to be saved.
transfers_per_array = [
[async_transfer_slice(rslice) for rslice in rslices.replica_slices]
for rslices in rslices_per_array
]
# Wait for all the transferred data to be ready.
return [
res = [
dataclasses.replace(
rslices,
is_on_host=True,
Expand All @@ -489,3 +502,11 @@ def async_transfer_slice(
)
for rslices, transfers in zip(rslices_per_array, transfers_per_array)
]
d2h_duration = time.perf_counter() - start_d2h_time
logging.info(
'[process=%s] Completed D2H transfer for %d arrays in %.4f seconds.',
multihost.process_index(),
len(arrays),
d2h_duration,
)
return res
46 changes: 28 additions & 18 deletions checkpoint/orbax/checkpoint/_src/serialization/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@
from collections.abc import Mapping
import os
import re
import time
from typing import Any, Dict, Optional, Sequence, Union

from absl import logging

import jax
from jax.experimental import layout
import jax.numpy as jnp
Expand Down Expand Up @@ -212,6 +215,8 @@ async def async_serialize_from_host(
Raises:
KeyError: If `metadata` or `dtype` is not found in the tensorstore spec.
"""
del primary_host

if not rslices_on_host.is_on_host:
raise ValueError('Replica slices have not been transferred to host.')
byte_limiter = byte_limiter or limits.get_byte_limiter()
Expand All @@ -222,30 +227,21 @@ async def async_serialize_from_host(
raise KeyError('`dtype` not found in tensorstore spec.')
context = context or ts_utils.get_ts_context(use_ocdbt=False)

# If primary_host is None, all hosts will checkpoint. This is used
# for checkpointing to local filesystem.
if primary_host is None or multihost.process_index() == primary_host:
await ts.open(
ts.Spec(tensorstore_spec),
create=True,
open=True,
context=context,
transaction=transaction,
)

# `ts.open` runs twice for process `primary_host` because for the first time,
# we just get the future to be awaited upon in the background thread. The
# second one runs with `assume_metadata=True` which does no I/O operation and
# returns the tensorstore object.
# For every process other than `primary_host`, we open with
# `assume_metadata=True`.
start_ts_open = time.perf_counter()
t = await ts.open(
ts.Spec(tensorstore_spec),
create=True,
open=True,
assume_metadata=True,
context=context,
transaction=transaction,
)
ts_open_duration = time.perf_counter() - start_ts_open
logging.info(
'[process=%s] Completed TensorStore open for %s in %.4f seconds.',
multihost.process_index(),
tensorstore_spec,
ts_open_duration,
)

async def write_fragment(fragment: fragments.ConcreteFragment):
"""Writes a single fragment using TensorStore. No copy is performed."""
Expand All @@ -260,11 +256,25 @@ async def write_fragment(fragment: fragments.ConcreteFragment):
can_reference_source_data_indefinitely=True,
)

start_ts_write = time.perf_counter()
logging.info(
'[process=%s] Starting TensorStore disk write for %d fragments.',
multihost.process_index(),
len(rslices_on_host.to_fragments().fragments),
)
write_coros = [
write_fragment(fragment)
for fragment in rslices_on_host.to_fragments().fragments
]
await asyncio.gather(*write_coros)
ts_duration = time.perf_counter() - start_ts_write
logging.info(
'[process=%s] Completed TensorStore disk write for %d fragments in %.4f'
' seconds.',
multihost.process_index(),
len(rslices_on_host.to_fragments().fragments),
ts_duration,
)


def estimate_write_memory_footprint(arr: np.ndarray) -> int:
Expand Down
Loading