Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils


Expand All @@ -36,7 +37,8 @@
_BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')


class CheckpointablesMetadataCompatibilityTest(parameterized.TestCase):
class CheckpointablesMetadataCompatibilityTestBase(parameterized.TestCase):
"""Tests for V1 checkpointables_metadata API against generated Checkpoints."""

def setUp(self) -> None:
super().setUp()
Expand Down Expand Up @@ -94,6 +96,14 @@ def test_checkpointables_metadata_compatibility(
is_direct_checkpoint: bool,
is_pytree: bool,
) -> None:
"""Tests checkpointables_metadata against various checkpoint formats.

Args:
version: v0 or v1.
metadata_present: Whether the checkpoint has metadata files.
is_direct_checkpoint: Whether the checkpoint is a direct checkpoint.
is_pytree: Whether the checkpoint is a pytree checkpoint.
"""
path = compatibility_test_utils.get_checkpoint_path(
version, metadata_present, is_direct_checkpoint, is_pytree
)
Expand Down Expand Up @@ -133,7 +143,11 @@ def test_checkpointables_metadata_compatibility(
else:
expected = self.expected_checkpointables_metadata

test_utils.assert_tree_equal(self, expected, loaded.metadata)
actual = loaded.metadata
if multihost.is_pathways_backend() or jax.process_count() > 1:
expected = compatibility_test_utils.strip_sharding_metadata(expected)
actual = compatibility_test_utils.strip_sharding_metadata(actual)
test_utils.assert_tree_equal(self, expected, actual)
else:
with self.assertRaisesRegex(error_type, expected_error_msg):
ocp.checkpointables_metadata(path)
Expand All @@ -153,6 +167,12 @@ def test_checkpointables_metadata_compatibility(
def test_checkpointables_metadata_non_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests checkpointables_metadata against non-critical corruptions.

Args:
version: The checkpoint version to test against.
alteration: The alteration to apply to the checkpoint.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -162,16 +182,24 @@ def test_checkpointables_metadata_non_critical_corruptions(
# Missing sharding metadata results in a pytree identical to expected
# values except sharding metadata is None.
loaded = ocp.checkpointables_metadata(path)
test_utils.assert_tree_equal(
self, self.expected_checkpointables_metadata, loaded.metadata
)
expected = self.expected_checkpointables_metadata
actual = loaded.metadata
if multihost.is_pathways_backend() or jax.process_count() > 1:
expected = compatibility_test_utils.strip_sharding_metadata(expected)
actual = compatibility_test_utils.strip_sharding_metadata(actual)
test_utils.assert_tree_equal(self, expected, actual)

@parameterized.product(
version=['v0', 'v1'],
)
def test_checkpointables_metadata_missing_sharding_corruption(
self, version: str
) -> None:
"""Tests checkpointables_metadata against missing sharding corruption.

Args:
version: The checkpoint version to test against.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -193,6 +221,12 @@ def test_checkpointables_metadata_missing_sharding_corruption(
def test_checkpointables_metadata_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests checkpointables_metadata against critical corruptions.

Args:
version: The checkpoint version to test against.
alteration: The alteration to apply to the checkpoint.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,14 @@
from etils import epath
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.sharding_utils import make_single_device_sharding
import orbax.checkpoint.experimental.v1 as ocp
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils


Expand All @@ -36,7 +38,8 @@
_BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')


class LoadCheckpointablesCompatibilityTest(parameterized.TestCase):
class LoadCheckpointablesCompatibilityTestBase(parameterized.TestCase):
"""Tests for V1 load_checkpointables API against generated Checkpoints."""

def setUp(self) -> None:
super().setUp()
Expand All @@ -46,6 +49,12 @@ def setUp(self) -> None:
'b': {'c': jnp.array([1, 2, 3], dtype=jnp.int32)},
}
sharding = make_single_device_sharding(jax.devices()[0])
if multihost.is_pathways_backend() or jax.process_count() > 1:
self.expected_state = compatibility_test_utils.replicate_on_mesh(
self.expected_state
)
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
self.abstract_state = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding),
self.expected_state
Expand Down Expand Up @@ -177,12 +186,32 @@ def test_load_checkpointables_compatibility(
has_pytree_metadata: bool,
handler_registered: bool,
) -> None:
"""Tests load_checkpointables compatibility with v0 and v1 checkpoints.

Args:
version: The version of the checkpoint to load.
checkpointable_names_provided: Whether to provide checkpointable_names to
load_checkpointables.
abstract_checkpointables_provided: Whether to provide
abstract_checkpointables to ocp.load_checkpointables.
names_registered: Whether to register all checkpointable names to a
handler.
metadata_present: Whether the checkpoint has metadata files.
is_direct_checkpoint: Whether the checkpoint is a direct checkpoint.
has_pytree_metadata: Whether the checkpoint has pytree metadata files.
handler_registered: Whether to register all handlers.
"""
path = compatibility_test_utils.get_checkpoint_path(
version, metadata_present, is_direct_checkpoint, has_pytree_metadata
)
if path is None or not path.exists():
self.skipTest('Checkpoint for combination does not exist.')

if (
multihost.is_pathways_backend() or jax.process_count() > 1
) and not abstract_checkpointables_provided:
self.skipTest('Sharding metadata not matching in Pathways/Multiprocess.')

if not checkpointable_names_provided and abstract_checkpointables_provided:
self.skipTest(
'Cannot provide abstract_checkpointables without'
Expand Down Expand Up @@ -252,6 +281,12 @@ def test_load_checkpointables_compatibility(
def test_load_checkpointables_non_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests load_checkpointables with non-critical corruptions.

Args:
version: The checkpoint version to test against.
alteration: The alteration to apply to the checkpoint.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -261,7 +296,9 @@ def test_load_checkpointables_non_critical_corruptions(
loaded = ocp.load_checkpointables(
path, abstract_checkpointables=self.abstract_checkpointables
)
test_utils.assert_tree_equal(self, loaded, self.expected_checkpointables)
test_utils.assert_tree_equal(
self, loaded, self.expected_checkpointables
)

@parameterized.product(
version=['v0', 'v1'],
Expand All @@ -273,6 +310,12 @@ def test_load_checkpointables_non_critical_corruptions(
def test_load_checkpointables_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests load_checkpointables with critical corruptions.

Args:
version: The checkpoint version to test against.
alteration: The alteration to apply to the checkpoint.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

"""Tests for V1 load_pytree API against generated V0 and V1 Checkpoints."""

import os
from typing import Tuple, Type

Expand All @@ -21,12 +22,14 @@
from etils import epath
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import test_utils
from orbax.checkpoint._src.sharding_utils import make_single_device_sharding
import orbax.checkpoint.experimental.v1 as ocp
from orbax.checkpoint.experimental.v1._src.context import options as options_lib
from orbax.checkpoint.experimental.v1._src.handlers import registration
from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib
from orbax.checkpoint.experimental.v1._src.synchronization import multihost
from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils


Expand All @@ -37,7 +40,8 @@
_BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')


class LoadPytreeCompatibilityTest(parameterized.TestCase):
class LoadPytreeCompatibilityTestBase(parameterized.TestCase):
"""Tests for V1 load_pytree API against generated Checkpoints."""

def setUp(self) -> None:
super().setUp()
Expand All @@ -47,6 +51,12 @@ def setUp(self) -> None:
'b': {'c': jnp.array([1, 2, 3], dtype=jnp.int32)},
}
sharding = make_single_device_sharding(jax.devices()[0])
if multihost.is_pathways_backend() or jax.process_count() > 1:
self.expected_state = compatibility_test_utils.replicate_on_mesh(
self.expected_state
)
mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',))
sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())
self.abstract_state = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding),
self.expected_state
Expand Down Expand Up @@ -194,12 +204,33 @@ def test_load_pytree_compatibility(
handler_registered: bool,
pytree_registered: bool,
) -> None:
"""Tests load_pytree against various checkpoint configurations.

Args:
version: The checkpoint version to test against.
checkpointable_name: The name of the checkpointable to load.
abstract_pytree_provided: Whether an abstract pytree is provided to
ocp.load_pytree.
name_registered: Whether a handler is registered for the
checkpointable_name.
metadata_present: Whether the checkpoint has metadata.
is_direct_checkpoint: Whether the checkpoint is a direct checkpoint.
is_pytree: Whether the checkpoint is a pytree checkpoint.
handler_registered: Whether a handler is registered for the handler
typestr derived from checkpoint metadata.
pytree_registered: Whether a handler is registered for the 'pytree' scope.
"""
path = compatibility_test_utils.get_checkpoint_path(
version, metadata_present, is_direct_checkpoint, is_pytree
)
if path is None or not path.exists():
self.skipTest('Checkpoint for combination does not exist.')

if (
multihost.is_pathways_backend() or jax.process_count() > 1
) and not abstract_pytree_provided:
self.skipTest('Sharding metadata not matching in Pathways/Multiprocess.')

registry = self.setup_registry(
path,
checkpointable_name,
Expand Down Expand Up @@ -262,7 +293,12 @@ def test_load_pytree_compatibility(
def test_load_pytree_non_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests load_pytree against checkpoints with non-critical corruptions.

Args:
version: The version of the checkpoint to load.
alteration: The non-critical corruption alteration to apply.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -284,6 +320,12 @@ def test_load_pytree_non_critical_corruptions(
def test_load_pytree_critical_corruptions(
self, version: str, alteration: str
) -> None:
"""Tests load_pytree against checkpoints with critical corruptions.

Args:
version: The version of the checkpoint to load.
alteration: The critical corruption alteration to apply.
"""
path = self.base_dir.joinpath(
f'{version}_checkpoints',
'composite_checkpoint',
Expand All @@ -303,6 +345,11 @@ def test_load_pytree_critical_corruptions(
version=['v0', 'v1'],
)
def test_load_incorrect_path(self, version: str) -> None:
"""Tests load_pytree against checkpoints with incorrect paths.

Args:
version: The version of the checkpoint to test against.
"""
checkpoint_path = (
self.base_dir
/ f'{version}_checkpoints'
Expand Down
Loading
Loading