From e390bf78749cb6ba3fd54b01a94721748938af75 Mon Sep 17 00:00:00 2001 From: Angel Mau Date: Mon, 11 May 2026 12:03:25 -0700 Subject: [PATCH] Refactor compatiblity tests for v1 free functions to support internal multihost tests and remove older save_load testing since we are already are covering its functionality in our testing/compatibility base files. PiperOrigin-RevId: 913815631 --- ...ables_metadata_compatibility_test_base.py} | 44 ++- ...heckpointables_compatibility_test_base.py} | 47 ++- ...=> load_pytree_compatibility_test_base.py} | 49 ++- ...ytree_metadata_compatibility_test_base.py} | 51 +++- .../_src/testing/compatibility/test_utils.py | 31 ++ .../v0v1_compatibility_save_load_test_base.py | 288 ------------------ .../testing/v1_compatibility_load_test.py | 59 ++++ 7 files changed, 265 insertions(+), 304 deletions(-) rename checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/{checkpointables_metadata_compatibility_test.py => checkpointables_metadata_compatibility_test_base.py} (80%) rename checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/{load_checkpointables_compatibility_test.py => load_checkpointables_compatibility_test_base.py} (83%) rename checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/{load_pytree_compatibility_test.py => load_pytree_compatibility_test_base.py} (83%) rename checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/{pytree_metadata_compatibility_test.py => pytree_metadata_compatibility_test_base.py} (81%) delete mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py create mode 100644 checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py similarity index 80% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py index 41f64811a..54f601de9 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/checkpointables_metadata_compatibility_test_base.py @@ -26,6 +26,7 @@ from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils @@ -36,7 +37,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class CheckpointablesMetadataCompatibilityTest(parameterized.TestCase): +class CheckpointablesMetadataCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 checkpointables_metadata API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -94,6 +96,14 @@ def test_checkpointables_metadata_compatibility( is_direct_checkpoint: bool, is_pytree: bool, ) -> None: + """Tests checkpointables_metadata against various checkpoint formats. + + Args: + version: v0 or v1. + metadata_present: Whether the checkpoint has metadata files. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + is_pytree: Whether the checkpoint is a pytree checkpoint. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, is_pytree ) @@ -133,7 +143,11 @@ def test_checkpointables_metadata_compatibility( else: expected = self.expected_checkpointables_metadata - test_utils.assert_tree_equal(self, expected, loaded.metadata) + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) else: with self.assertRaisesRegex(error_type, expected_error_msg): ocp.checkpointables_metadata(path) @@ -153,6 +167,12 @@ def test_checkpointables_metadata_compatibility( def test_checkpointables_metadata_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests checkpointables_metadata against non-critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -162,9 +182,12 @@ def test_checkpointables_metadata_non_critical_corruptions( # Missing sharding metadata results in a pytree identical to expected # values except sharding metadata is None. loaded = ocp.checkpointables_metadata(path) - test_utils.assert_tree_equal( - self, self.expected_checkpointables_metadata, loaded.metadata - ) + expected = self.expected_checkpointables_metadata + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) @parameterized.product( version=['v0', 'v1'], @@ -172,6 +195,11 @@ def test_checkpointables_metadata_non_critical_corruptions( def test_checkpointables_metadata_missing_sharding_corruption( self, version: str ) -> None: + """Tests checkpointables_metadata against missing sharding corruption. + + Args: + version: The checkpoint version to test against. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -193,6 +221,12 @@ def test_checkpointables_metadata_missing_sharding_corruption( def test_checkpointables_metadata_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests checkpointables_metadata against critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py similarity index 83% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py index 2868df6c8..ffe275ea0 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_checkpointables_compatibility_test_base.py @@ -20,12 +20,14 @@ from etils import epath import jax import jax.numpy as jnp +import numpy as np from orbax.checkpoint import test_utils from orbax.checkpoint._src.sharding_utils import make_single_device_sharding import orbax.checkpoint.experimental.v1 as ocp from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils @@ -36,7 +38,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class LoadCheckpointablesCompatibilityTest(parameterized.TestCase): +class LoadCheckpointablesCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 load_checkpointables API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -46,6 +49,12 @@ def setUp(self) -> None: 'b': {'c': jnp.array([1, 2, 3], dtype=jnp.int32)}, } sharding = make_single_device_sharding(jax.devices()[0]) + if multihost.is_pathways_backend() or jax.process_count() > 1: + self.expected_state = compatibility_test_utils.replicate_on_mesh( + self.expected_state + ) + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) self.abstract_state = jax.tree.map( lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding), self.expected_state @@ -177,12 +186,32 @@ def test_load_checkpointables_compatibility( has_pytree_metadata: bool, handler_registered: bool, ) -> None: + """Tests load_checkpointables compatibility with v0 and v1 checkpoints. + + Args: + version: The version of the checkpoint to load. + checkpointable_names_provided: Whether to provide checkpointable_names to + load_checkpointables. + abstract_checkpointables_provided: Whether to provide + abstract_checkpointables to ocp.load_checkpointables. + names_registered: Whether to register all checkpointable names to a + handler. + metadata_present: Whether the checkpoint has metadata files. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + has_pytree_metadata: Whether the checkpoint has pytree metadata files. + handler_registered: Whether to register all handlers. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, has_pytree_metadata ) if path is None or not path.exists(): self.skipTest('Checkpoint for combination does not exist.') + if ( + multihost.is_pathways_backend() or jax.process_count() > 1 + ) and not abstract_checkpointables_provided: + self.skipTest('Sharding metadata not matching in Pathways/Multiprocess.') + if not checkpointable_names_provided and abstract_checkpointables_provided: self.skipTest( 'Cannot provide abstract_checkpointables without' @@ -252,6 +281,12 @@ def test_load_checkpointables_compatibility( def test_load_checkpointables_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_checkpointables with non-critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -261,7 +296,9 @@ def test_load_checkpointables_non_critical_corruptions( loaded = ocp.load_checkpointables( path, abstract_checkpointables=self.abstract_checkpointables ) - test_utils.assert_tree_equal(self, loaded, self.expected_checkpointables) + test_utils.assert_tree_equal( + self, loaded, self.expected_checkpointables + ) @parameterized.product( version=['v0', 'v1'], @@ -273,6 +310,12 @@ def test_load_checkpointables_non_critical_corruptions( def test_load_checkpointables_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_checkpointables with critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py similarity index 83% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py index 5d82a33f5..c9a63def2 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/load_pytree_compatibility_test_base.py @@ -13,6 +13,7 @@ # limitations under the License. """Tests for V1 load_pytree API against generated V0 and V1 Checkpoints.""" + import os from typing import Tuple, Type @@ -21,12 +22,14 @@ from etils import epath import jax import jax.numpy as jnp +import numpy as np from orbax.checkpoint import test_utils from orbax.checkpoint._src.sharding_utils import make_single_device_sharding import orbax.checkpoint.experimental.v1 as ocp from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils @@ -37,7 +40,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class LoadPytreeCompatibilityTest(parameterized.TestCase): +class LoadPytreeCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 load_pytree API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -47,6 +51,12 @@ def setUp(self) -> None: 'b': {'c': jnp.array([1, 2, 3], dtype=jnp.int32)}, } sharding = make_single_device_sharding(jax.devices()[0]) + if multihost.is_pathways_backend() or jax.process_count() > 1: + self.expected_state = compatibility_test_utils.replicate_on_mesh( + self.expected_state + ) + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) self.abstract_state = jax.tree.map( lambda x: jax.ShapeDtypeStruct(x.shape, x.dtype, sharding=sharding), self.expected_state @@ -194,12 +204,33 @@ def test_load_pytree_compatibility( handler_registered: bool, pytree_registered: bool, ) -> None: + """Tests load_pytree against various checkpoint configurations. + + Args: + version: The checkpoint version to test against. + checkpointable_name: The name of the checkpointable to load. + abstract_pytree_provided: Whether an abstract pytree is provided to + ocp.load_pytree. + name_registered: Whether a handler is registered for the + checkpointable_name. + metadata_present: Whether the checkpoint has metadata. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + is_pytree: Whether the checkpoint is a pytree checkpoint. + handler_registered: Whether a handler is registered for the handler + typestr derived from checkpoint metadata. + pytree_registered: Whether a handler is registered for the 'pytree' scope. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, is_pytree ) if path is None or not path.exists(): self.skipTest('Checkpoint for combination does not exist.') + if ( + multihost.is_pathways_backend() or jax.process_count() > 1 + ) and not abstract_pytree_provided: + self.skipTest('Sharding metadata not matching in Pathways/Multiprocess.') + registry = self.setup_registry( path, checkpointable_name, @@ -262,7 +293,12 @@ def test_load_pytree_compatibility( def test_load_pytree_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_pytree against checkpoints with non-critical corruptions. + Args: + version: The version of the checkpoint to load. + alteration: The non-critical corruption alteration to apply. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -284,6 +320,12 @@ def test_load_pytree_non_critical_corruptions( def test_load_pytree_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests load_pytree against checkpoints with critical corruptions. + + Args: + version: The version of the checkpoint to load. + alteration: The critical corruption alteration to apply. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -303,6 +345,11 @@ def test_load_pytree_critical_corruptions( version=['v0', 'v1'], ) def test_load_incorrect_path(self, version: str) -> None: + """Tests load_pytree against checkpoints with incorrect paths. + + Args: + version: The version of the checkpoint to test against. + """ checkpoint_path = ( self.base_dir / f'{version}_checkpoints' diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py similarity index 81% rename from checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test.py rename to checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py index 6cfa1f4f4..986c4633e 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/pytree_metadata_compatibility_test_base.py @@ -27,9 +27,9 @@ from orbax.checkpoint.experimental.v1._src.context import options as options_lib from orbax.checkpoint.experimental.v1._src.handlers import registration from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout as checkpoint_layout_lib +from orbax.checkpoint.experimental.v1._src.synchronization import multihost from orbax.checkpoint.experimental.v1._src.testing.compatibility import test_utils as compatibility_test_utils - CheckpointLayoutEnum = options_lib.CheckpointLayout InvalidLayoutError = checkpoint_layout_lib.InvalidLayoutError @@ -37,7 +37,8 @@ _BASE_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints') -class PytreeMetadataCompatibilityTest(parameterized.TestCase): +class PytreeMetadataCompatibilityTestBase(parameterized.TestCase): + """Tests for V1 pytree_metadata API against generated Checkpoints.""" def setUp(self) -> None: super().setUp() @@ -155,6 +156,17 @@ def test_pytree_metadata_compatibility( is_pytree: bool, handler_registered: bool, ) -> None: + """Tests pytree_metadata compatibility across V0 and V1 checkpoints. + + Args: + version: The checkpoint version to test against. + checkpointable_name: The name of the checkpointable to load. + name_registered: Whether the checkpointable name is registered. + metadata_present: Whether the checkpoint metadata file is present. + is_direct_checkpoint: Whether the checkpoint is a direct checkpoint. + is_pytree: Whether the checkpointable is a pytree. + handler_registered: Whether the handler is registered. + """ path = compatibility_test_utils.get_checkpoint_path( version, metadata_present, is_direct_checkpoint, is_pytree ) @@ -184,9 +196,12 @@ def test_pytree_metadata_compatibility( path, checkpointable_name=checkpointable_name, ) - test_utils.assert_tree_equal( - self, self.expected_state_metadata, loaded.metadata - ) + expected = self.expected_state_metadata + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) else: with self.assertRaisesRegex(error_type, error_msg): ocp.pytree_metadata( @@ -209,6 +224,12 @@ def test_pytree_metadata_compatibility( def test_pytree_metadata_non_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests pytree_metadata with non-critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -216,9 +237,12 @@ def test_pytree_metadata_non_critical_corruptions( alteration, ) loaded = ocp.pytree_metadata(path, checkpointable_name='state') - test_utils.assert_tree_equal( - self, self.expected_state_metadata, loaded.metadata - ) + expected = self.expected_state_metadata + actual = loaded.metadata + if multihost.is_pathways_backend() or jax.process_count() > 1: + expected = compatibility_test_utils.strip_sharding_metadata(expected) + actual = compatibility_test_utils.strip_sharding_metadata(actual) + test_utils.assert_tree_equal(self, expected, actual) @parameterized.product( version=['v0', 'v1'], @@ -226,6 +250,11 @@ def test_pytree_metadata_non_critical_corruptions( def test_pytree_metadata_missing_sharding_corruption( self, version: str ) -> None: + """Tests pytree_metadata with missing sharding corruption. + + Args: + version: The checkpoint version to test against. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', @@ -247,6 +276,12 @@ def test_pytree_metadata_missing_sharding_corruption( def test_pytree_metadata_critical_corruptions( self, version: str, alteration: str ) -> None: + """Tests pytree_metadata with critical corruptions. + + Args: + version: The checkpoint version to test against. + alteration: The alteration to apply to the checkpoint. + """ path = self.base_dir.joinpath( f'{version}_checkpoints', 'composite_checkpoint', diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py index 9b3c7eba6..40eb81ea8 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/compatibility/test_utils.py @@ -82,3 +82,34 @@ def create_value_metadata(value: Any) -> Any: return 0.0 else: raise TypeError(f'Unsupported type: {type(value)}') + + +def strip_sharding_metadata(tree: Any) -> Any: + """Strips concrete sharding_metadata from Metadata to decouple from topologies.""" + def _strip(x): + # Check for sharding_metadata attribute since it may reach leaves that are + # not arrays. + if hasattr(x, 'sharding_metadata'): + return array_leaf_handler.ArrayMetadata( + shape=x.shape, + dtype=x.dtype, + sharding_metadata=None, + storage_metadata=x.storage_metadata, + ) + return jax.tree.map( + _strip, + tree, + is_leaf=lambda leaf: hasattr(leaf, 'sharding_metadata'), + ) + + +def replicate_on_mesh(tree: Any) -> Any: + """Replicates a PyTree of arrays across all devices in the current mesh.""" + mesh = jax.sharding.Mesh(np.asarray(jax.devices()), ('devices',)) + sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec()) + return jax.tree.map( + lambda x: jax.device_put(x, sharding) + if isinstance(x, (jax.Array, np.ndarray)) + else x, + tree, + ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py deleted file mode 100644 index 6ea0bfb66..000000000 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v0v1_compatibility_save_load_test_base.py +++ /dev/null @@ -1,288 +0,0 @@ -# Copyright 2026 The Orbax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Base class for v0/v1 compatibility save/load tests.""" - -# pylint: disable=missing-class-docstring,protected-access,missing-function-docstring - -from __future__ import annotations - -from absl.testing import parameterized -from etils import epath -from orbax.checkpoint import args as args_lib -from orbax.checkpoint import test_utils -from orbax.checkpoint._src.checkpointers import checkpointer as v0_checkpointer -from orbax.checkpoint._src.checkpointers import standard_checkpointer -from orbax.checkpoint._src.handlers import composite_checkpoint_handler -import orbax.checkpoint.experimental.v1 as ocp -from orbax.checkpoint.experimental.v1._src.layout import checkpoint_layout -from orbax.checkpoint.experimental.v1._src.path import types as path_types -from orbax.checkpoint.experimental.v1._src.synchronization import multihost -from orbax.checkpoint.experimental.v1._src.testing import array_utils as array_test_utils -from orbax.checkpoint.experimental.v1._src.tree import types as tree_types - -PyTree = tree_types.PyTree -Path = path_types.Path -InvalidLayoutError = checkpoint_layout.InvalidLayoutError - -create_sharded_pytree = array_test_utils.create_sharded_pytree - - -class CompatibilitySaveLoadTestBase: - - class Test(parameterized.TestCase): - - def setUp(self): - super().setUp() - - self.root_directory = epath.Path( - self.create_tempdir(name='root').full_path - ) - self.ckpt_directory = ( - epath.Path(self.create_tempdir(name='direct').full_path) / 'ckpt' - ) - self.pytree, self.abstract_pytree = create_sharded_pytree() - - test_utils.set_tensorstore_driver_for_test() - test_utils.sync_global_processes('CompatibilityTest:setup_complete') - - def tearDown(self): - super().tearDown() - test_utils.sync_global_processes('CompatibilityTest:teardown_complete') - - def save_v0_checkpoint(self, directory: Path): - with standard_checkpointer.StandardCheckpointer() as checkpointer: - checkpointer.save(directory, self.pytree) - - def save_v0_checkpoints( - self, base_dir: Path, *, checkpointable_names: list[str] - ): - args = args_lib.Composite(**{ - checkpointable_name: args_lib.StandardSave(self.pytree) - for checkpointable_name in checkpointable_names - }) - with v0_checkpointer.Checkpointer( - composite_checkpoint_handler.CompositeCheckpointHandler() - ) as checkpointer: - checkpointer.save(base_dir, args) - - def test_async_load(self): - with self.assertRaises(NotImplementedError): - ocp.load_pytree_async( - self.root_directory, - ) - with self.assertRaises(NotImplementedError): - ocp.load_checkpointables_async( - self.root_directory, - ) - - @parameterized.product( - with_abstract_pytree=[True, False], - ) - def test_load_v0_checkpoint_with_v1_load_pytree( - self, with_abstract_pytree: bool - ): - - checkpointable_names = ['default', 'state', 'pytree'] - step_dir = self.root_directory / 'load_pytree_0' - self.save_v0_checkpoints( - step_dir, - checkpointable_names=checkpointable_names, - ) - self.save_v0_checkpoint(self.ckpt_directory) - - with self.subTest('no_checkpointable_name'): - loaded = ocp.load_pytree( - step_dir, - self.abstract_pytree if with_abstract_pytree else None, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - - with self.subTest('flat_layout_no_checkpointable_name'): - loaded = ocp.load_pytree( - self.ckpt_directory, - self.abstract_pytree if with_abstract_pytree else None, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - - with self.subTest('root_path_no_checkpointable_name_error'): - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.root_directory, - self.abstract_pytree if with_abstract_pytree else None, - ) - - for checkpointable_name in checkpointable_names: - with self.subTest(f'pass_{checkpointable_name}'): - loaded = ocp.load_pytree( - step_dir, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=checkpointable_name, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - - with self.subTest(f'pass_{checkpointable_name}_error'): - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.ckpt_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=checkpointable_name, - ) - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.root_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=checkpointable_name, - ) - - with self.subTest('pass_none'): - loaded = ocp.load_pytree( - self.ckpt_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=None, - ) - test_utils.assert_tree_equal(self, self.pytree, loaded) - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - step_dir, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=None, - ) - with self.assertRaises(InvalidLayoutError): - ocp.load_pytree( - self.root_directory, - self.abstract_pytree if with_abstract_pytree else None, - checkpointable_name=None, - ) - - def test_load_v0_checkpoint_with_v1_pytree_metadata(self): - checkpointable_names = ['default', 'state', 'pytree'] - step_dir = self.root_directory / 'load_pytree_0' - self.save_v0_checkpoints( - step_dir, - checkpointable_names=checkpointable_names, - ) - self.save_v0_checkpoint(self.ckpt_directory) - - with self.subTest('no_checkpointable_name'): - loaded = ocp.pytree_metadata(step_dir) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, loaded.metadata - ) - - with self.subTest('flat_layout_no_checkpointable_name'): - metadata = ocp.pytree_metadata(self.ckpt_directory) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, metadata.metadata - ) - - with self.subTest('root_path_no_checkpointable_name_error'): - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata(self.root_directory) - - for checkpointable_name in checkpointable_names: - with self.subTest(f'pass_{checkpointable_name}'): - loaded = ocp.pytree_metadata( - step_dir, - checkpointable_name=checkpointable_name, - ) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, loaded.metadata - ) - - with self.subTest(f'pass_{checkpointable_name}_error'): - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - self.ckpt_directory, - checkpointable_name=checkpointable_name, - ) - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - self.root_directory, - checkpointable_name=checkpointable_name, - ) - - with self.subTest('pass_none'): - loaded = ocp.pytree_metadata( - self.ckpt_directory, - checkpointable_name=None, - ) - test_utils.assert_tree_same_structure( - self, self.abstract_pytree, loaded.metadata - ) - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - step_dir, - checkpointable_name=None, - ) - with self.assertRaises(InvalidLayoutError): - ocp.pytree_metadata( - self.root_directory, - checkpointable_name=None, - ) - - @parameterized.product( - checkpointable_name=['default', 'state'], - with_abstract_pytree=[True, False], - ) - def test_load_v0_checkpoint_with_v1_load_checkpointables( - self, - checkpointable_name: str, - with_abstract_pytree: bool, - ): - - checkpointable_names = ['default', 'state'] - step_dir = self.root_directory / 'load_checkpointables_0' - self.save_v0_checkpoints( - step_dir, - checkpointable_names=checkpointable_names, - ) - self.save_v0_checkpoint(self.ckpt_directory) - - abstract_checkpointables = ( - {checkpointable_name: self.abstract_pytree} - ) - - with self.subTest('with_context'): - checkpointables_options = ( - ocp.options.CheckpointablesOptions.create_with_handlers( - **{checkpointable_name: ocp.handlers.PyTreeHandler} - ) - ) - with ocp.Context(checkpointables_options=checkpointables_options): - loaded = ocp.load_checkpointables(step_dir, abstract_checkpointables) - test_utils.assert_tree_equal( - self, self.pytree, loaded[checkpointable_name] - ) - - with self.subTest('without_context'): - loaded = ocp.load_checkpointables(step_dir, abstract_checkpointables) - test_utils.assert_tree_equal( - self, self.pytree, loaded[checkpointable_name] - ) - # TODO(b/484400394): Find a better way to inform the user that they need - # to use load_pytree(..., checkpointable_name=None) when item_handlers is - # a str. - with self.subTest('error_with_checkpoint_path'): - with self.assertRaisesRegex( - KeyError, 'Requested checkpointables:' - ): - ocp.load_checkpointables( - self.ckpt_directory, abstract_checkpointables - ) - with self.subTest('error_with_root_path'): - with self.assertRaises(InvalidLayoutError): - ocp.load_checkpointables( - self.root_directory, abstract_checkpointables - ) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py new file mode 100644 index 000000000..b4c2409e8 --- /dev/null +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/testing/v1_compatibility_load_test.py @@ -0,0 +1,59 @@ +# Copyright 2026 The Orbax Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl import flags +from absl.testing import absltest +from absl.testing import parameterized +import jax +from orbax.checkpoint.experimental.v1._src.testing.compatibility import checkpointables_metadata_compatibility_test_base +from orbax.checkpoint.experimental.v1._src.testing.compatibility import load_checkpointables_compatibility_test_base +from orbax.checkpoint.experimental.v1._src.testing.compatibility import load_pytree_compatibility_test_base +from orbax.checkpoint.experimental.v1._src.testing.compatibility import pytree_metadata_compatibility_test_base + + +FLAGS = flags.FLAGS + +jax.config.update('jax_enable_x64', True) + + +class CheckpointablesMetadataTest( + checkpointables_metadata_compatibility_test_base.CheckpointablesMetadataCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +class LoadCheckpointablesTest( + load_checkpointables_compatibility_test_base.LoadCheckpointablesCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +class LoadPytreeTest( + load_pytree_compatibility_test_base.LoadPytreeCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +class PytreeMetadataTest( + pytree_metadata_compatibility_test_base.PytreeMetadataCompatibilityTestBase, + parameterized.TestCase, +): + pass + + +if __name__ == '__main__': + absltest.main()