From 0b0f4bea0eedfde8a016537d0ac355092a23ad7b Mon Sep 17 00:00:00 2001 From: Angel Mau Date: Thu, 7 May 2026 08:51:22 -0700 Subject: [PATCH] Refactor save_async to return None when no save is initiated. PiperOrigin-RevId: 911986116 --- .../v1/_src/training/checkpointer.py | 45 ++++++++----- .../_src/training/checkpointer_test_base.py | 67 +++++++++++++++++++ 2 files changed, 96 insertions(+), 16 deletions(-) diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py index 81f079942..e25d08936 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer.py @@ -52,12 +52,14 @@ class _AsyncSaveResponse(async_types.AsyncResponse[bool]): """Response for asynchronous saving.""" def __init__( - self, manager: checkpoint_manager.CheckpointManager, saved: bool + self, manager: checkpoint_manager.CheckpointManager ): async def _wait() -> bool: + # If a background operation fails wait_until_finished() will re-raise the + # exception back to caller. manager.wait_until_finished() - return saved + return True self._thread_runner = thread_utils.BackgroundThreadRunner[bool](_wait()) @@ -366,7 +368,7 @@ def save_pytree( Returns: Whether a checkpoint was saved or not. """ - return self.save_pytree_async( + response = self.save_pytree_async( step, pytree, checkpointable_name=checkpointable_name, @@ -374,7 +376,10 @@ def save_pytree( overwrite=overwrite, metrics=metrics, custom_metadata=custom_metadata, - ).result() + ) + if response is None: + return False + return response.result() def save_checkpointables( self, @@ -453,14 +458,17 @@ def save_checkpointables( Returns: bool: True if the checkpoint was successfully saved, False otherwise. """ - return self.save_checkpointables_async( + response = self.save_checkpointables_async( step, checkpointables, force=force, overwrite=overwrite, metrics=metrics, custom_metadata=custom_metadata, - ).result() + ) + if response is None: + return False + return response.result() def save_pytree_async( self, @@ -472,7 +480,7 @@ def save_pytree_async( overwrite: bool = False, metrics: tree_types.JsonType | None = None, custom_metadata: tree_types.JsonType | None = None, - ) -> async_types.AsyncResponse[bool]: + ) -> async_types.AsyncResponse[bool] | None: """Saves a checkpoint asynchronously. This function is the asynchronous equivalent of @@ -486,7 +494,8 @@ def save_pytree_async( :: async_response = ckptr.save_pytree_async(step=0, pytree=tree) - saved = async_response.result() + if async_response is not None: + saved = async_response.result() Args: step: The step number to save. @@ -500,7 +509,8 @@ def save_pytree_async( Returns: An `AsyncResponse`, which can be awaited via `result()`, which returns a - bool indicating whether a checkpoint was saved or not. + bool indicating whether a checkpoint was saved or not, or None if the save + was skipped by policy. """ return self.save_checkpointables_async( step, @@ -520,7 +530,7 @@ def save_checkpointables_async( overwrite: bool = False, metrics: tree_types.JsonType | None = None, custom_metadata: tree_types.JsonType | None = None, - ) -> async_types.AsyncResponse[bool]: + ) -> async_types.AsyncResponse[bool] | None: """Saves checkpointable objects asynchronously. This function is the asynchronous equivalent of @@ -534,7 +544,8 @@ def save_checkpointables_async( step=0, checkpointables=items_to_save ) - saved = async_response.result() + if async_response is not None: + saved = async_response.result() Args: step: The step number to save. @@ -545,9 +556,9 @@ def save_checkpointables_async( custom_metadata: See `save_checkpointables`. Returns: - An object representing the background operation. Call `.result()` on it - to block and return a boolean indicating whether the checkpoint was - successfully saved. + An object representing the background operation, or None if the save was + skipped by policy. Call `.result()` on it to block and return a boolean + indicating whether the checkpoint was successfully saved. Raises: StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the @@ -573,14 +584,16 @@ def save_checkpointables_async( checkpointables, metrics=metrics ) self._manager._checkpointer = checkpointer # pylint: disable=protected-access - saved = self._manager.save( + save_initiated = self._manager.save( step, args=args, metrics=metrics, force=force, custom_metadata=custom_metadata, ) - return _AsyncSaveResponse(self._manager, saved) + if not save_initiated: + return None + return _AsyncSaveResponse(self._manager) def load_pytree( self, diff --git a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py index 3d952f84d..8f4afbc08 100644 --- a/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py +++ b/checkpoint/orbax/checkpoint/experimental/v1/_src/training/checkpointer_test_base.py @@ -264,6 +264,7 @@ def test_skips_when_ongoing_save(self): checkpointer = Checkpointer(self.directory) self.enter_context(checkpointer) saved_0 = checkpointer.save_pytree_async(0, self.pytree) + self.assertIsNotNone(saved_0) saved_1 = checkpointer.save_pytree(1, self.pytree) self.assertTrue(saved_0.result()) self.assertFalse(saved_1) @@ -291,6 +292,7 @@ def mock_serialize(*args, **kwargs): step = 0 response = checkpointer.save_pytree_async(step, self.pytree) + self.assertIsNotNone(response) initial_d_files_mtimes = tree_test_utils.get_d_files_mtimes( self.directory / str(step) ) @@ -313,6 +315,71 @@ def mock_serialize(*args, **kwargs): ) ) + def test_save_pytree_async_on_complete(self): + save_policy = save_decision_policies.FixedIntervalPolicy(100) + checkpointer = Checkpointer( + self.directory, save_decision_policy=save_policy + ) + self.enter_context(checkpointer) + + results = [] + condition = threading.Condition() + + def callback(saved): + with condition: + results.append(saved) + condition.notify_all() + + # Step 0 should save + response = checkpointer.save_pytree_async(0, self.pytree) + self.assertIsNotNone(response) + response.on_complete(callback) + + with condition: + while not results: + if not condition.wait(timeout=10): + self.fail('Timed out waiting for callback.') + + self.assertEqual(results, [True]) + self.assertTrue(response.result()) + results.clear() + + # Step 1 should not save + response = checkpointer.save_pytree_async(1, self.pytree) + self.assertIsNone(response) + + def test_save_pytree_async_result(self): + save_policy = save_decision_policies.FixedIntervalPolicy(100) + checkpointer = Checkpointer( + self.directory, save_decision_policy=save_policy + ) + self.enter_context(checkpointer) + + # Step 0 should save + response = checkpointer.save_pytree_async(0, self.pytree) + self.assertIsNotNone(response) + self.assertTrue(response.result()) + + # Save step 1 (should be skipped by FixedIntervalPolicy(100)) + response2 = checkpointer.save_pytree_async(1, self.pytree) + self.assertIsNone(response2) + + def test_save_pytree_async_raises_on_background_failure(self): + checkpointer = Checkpointer(self.directory) + self.enter_context(checkpointer) + + with mock.patch.object( + checkpointer._manager, + 'wait_until_finished', + side_effect=[None, RuntimeError('Mocked background save error')], + ): + response = checkpointer.save_pytree_async(0, self.pytree) + self.assertIsNotNone(response) + with self.assertRaisesRegex( + RuntimeError, 'Mocked background save error' + ): + response.result() + def test_close(self): checkpointer = Checkpointer(self.directory) step_path = self.directory / '0'