Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,14 @@ class _AsyncSaveResponse(async_types.AsyncResponse[bool]):
"""Response for asynchronous saving."""

def __init__(
self, manager: checkpoint_manager.CheckpointManager, saved: bool
self, manager: checkpoint_manager.CheckpointManager
):

async def _wait() -> bool:
# If a background operation fails wait_until_finished() will re-raise the
# exception back to caller.
manager.wait_until_finished()
return saved
return True

self._thread_runner = thread_utils.BackgroundThreadRunner[bool](_wait())

Expand Down Expand Up @@ -366,15 +368,18 @@ def save_pytree(
Returns:
Whether a checkpoint was saved or not.
"""
return self.save_pytree_async(
response = self.save_pytree_async(
step,
pytree,
checkpointable_name=checkpointable_name,
force=force,
overwrite=overwrite,
metrics=metrics,
custom_metadata=custom_metadata,
).result()
)
if response is None:
return False
return response.result()

def save_checkpointables(
self,
Expand Down Expand Up @@ -453,14 +458,17 @@ def save_checkpointables(
Returns:
bool: True if the checkpoint was successfully saved, False otherwise.
"""
return self.save_checkpointables_async(
response = self.save_checkpointables_async(
step,
checkpointables,
force=force,
overwrite=overwrite,
metrics=metrics,
custom_metadata=custom_metadata,
).result()
)
if response is None:
return False
return response.result()

def save_pytree_async(
self,
Expand All @@ -472,7 +480,7 @@ def save_pytree_async(
overwrite: bool = False,
metrics: tree_types.JsonType | None = None,
custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[bool]:
) -> async_types.AsyncResponse[bool] | None:
"""Saves a checkpoint asynchronously.

This function is the asynchronous equivalent of
Expand All @@ -486,7 +494,8 @@ def save_pytree_async(
::

async_response = ckptr.save_pytree_async(step=0, pytree=tree)
saved = async_response.result()
if async_response is not None:
saved = async_response.result()

Args:
step: The step number to save.
Expand All @@ -500,7 +509,8 @@ def save_pytree_async(

Returns:
An `AsyncResponse`, which can be awaited via `result()`, which returns a
bool indicating whether a checkpoint was saved or not.
bool indicating whether a checkpoint was saved or not, or None if the save
was skipped by policy.
"""
return self.save_checkpointables_async(
step,
Expand All @@ -520,7 +530,7 @@ def save_checkpointables_async(
overwrite: bool = False,
metrics: tree_types.JsonType | None = None,
custom_metadata: tree_types.JsonType | None = None,
) -> async_types.AsyncResponse[bool]:
) -> async_types.AsyncResponse[bool] | None:
"""Saves checkpointable objects asynchronously.

This function is the asynchronous equivalent of
Expand All @@ -534,7 +544,8 @@ def save_checkpointables_async(
step=0,
checkpointables=items_to_save
)
saved = async_response.result()
if async_response is not None:
saved = async_response.result()

Args:
step: The step number to save.
Expand All @@ -545,9 +556,9 @@ def save_checkpointables_async(
custom_metadata: See `save_checkpointables`.

Returns:
An object representing the background operation. Call `.result()` on it
to block and return a boolean indicating whether the checkpoint was
successfully saved.
An object representing the background operation, or None if the save was
skipped by policy. Call `.result()` on it to block and return a boolean
indicating whether the checkpoint was successfully saved.

Raises:
StepAlreadyExistsError: If `overwrite` is False and a checkpoint at the
Expand All @@ -573,14 +584,16 @@ def save_checkpointables_async(
checkpointables, metrics=metrics
)
self._manager._checkpointer = checkpointer # pylint: disable=protected-access
saved = self._manager.save(
save_initiated = self._manager.save(
step,
args=args,
metrics=metrics,
force=force,
custom_metadata=custom_metadata,
)
return _AsyncSaveResponse(self._manager, saved)
if not save_initiated:
return None
return _AsyncSaveResponse(self._manager)

def load_pytree(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def test_skips_when_ongoing_save(self):
checkpointer = Checkpointer(self.directory)
self.enter_context(checkpointer)
saved_0 = checkpointer.save_pytree_async(0, self.pytree)
self.assertIsNotNone(saved_0)
saved_1 = checkpointer.save_pytree(1, self.pytree)
self.assertTrue(saved_0.result())
self.assertFalse(saved_1)
Expand Down Expand Up @@ -291,6 +292,7 @@ def mock_serialize(*args, **kwargs):

step = 0
response = checkpointer.save_pytree_async(step, self.pytree)
self.assertIsNotNone(response)
initial_d_files_mtimes = tree_test_utils.get_d_files_mtimes(
self.directory / str(step)
)
Expand All @@ -313,6 +315,71 @@ def mock_serialize(*args, **kwargs):
)
)

def test_save_pytree_async_on_complete(self):
  """Checks `on_complete` callbacks on the response of an async save.

  With `FixedIntervalPolicy(100)` this test expects step 0 to return a
  response whose completion callback receives True, and step 1 to return
  None, leaving no response object to attach a callback to.
  """
  save_policy = save_decision_policies.FixedIntervalPolicy(100)
  checkpointer = Checkpointer(
      self.directory, save_decision_policy=save_policy
  )
  self.enter_context(checkpointer)

  results = []
  condition = threading.Condition()

  def callback(saved):
    # Presumably invoked off the test thread (TODO confirm); the condition
    # guards `results` and wakes the waiting test either way.
    with condition:
      results.append(saved)
      condition.notify_all()

  # Step 0 should save.
  response = checkpointer.save_pytree_async(0, self.pytree)
  self.assertIsNotNone(response)
  response.on_complete(callback)

  # `wait_for` re-checks the predicate on every wakeup and bounds the total
  # wait by `timeout`; it returns the last predicate value, which is falsy
  # only if the callback never ran in time.
  with condition:
    if not condition.wait_for(lambda: results, timeout=10):
      self.fail('Timed out waiting for callback.')

  self.assertEqual(results, [True])
  self.assertTrue(response.result())

  # Step 1 should not save, so no response object is returned at all.
  response = checkpointer.save_pytree_async(1, self.pytree)
  self.assertIsNone(response)

def test_save_pytree_async_result(self):
  """Checks what `save_pytree_async` returns for a saved and a skipped step.

  Under `FixedIntervalPolicy(100)` the test expects step 0 to produce a
  response whose `result()` is truthy, and step 1 to produce None.
  """
  checkpointer = Checkpointer(
      self.directory,
      save_decision_policy=save_decision_policies.FixedIntervalPolicy(100),
  )
  self.enter_context(checkpointer)

  # Step 0: a response comes back and reports a successful save.
  saved_response = checkpointer.save_pytree_async(0, self.pytree)
  self.assertIsNotNone(saved_response)
  self.assertTrue(saved_response.result())

  # Step 1: expected to be skipped by the interval policy, so no response.
  skipped_response = checkpointer.save_pytree_async(1, self.pytree)
  self.assertIsNone(skipped_response)

def test_save_pytree_async_raises_on_background_failure(self):
  """Checks that `result()` re-raises an error from `wait_until_finished`.

  The manager's `wait_until_finished` is patched so its first call returns
  None and its second call raises; the test expects that second error to
  surface from `response.result()`.
  """
  checkpointer = Checkpointer(self.directory)
  self.enter_context(checkpointer)

  error_message = 'Mocked background save error'
  # NOTE(review): the leading None presumably absorbs a `wait_until_finished`
  # call made while the save is being initiated -- confirm against the
  # manager implementation.
  patched_wait = mock.patch.object(
      checkpointer._manager,
      'wait_until_finished',
      side_effect=[None, RuntimeError(error_message)],
  )
  with patched_wait:
    response = checkpointer.save_pytree_async(0, self.pytree)
    self.assertIsNotNone(response)
    with self.assertRaisesRegex(RuntimeError, error_message):
      response.result()

def test_close(self):
checkpointer = Checkpointer(self.directory)
step_path = self.directory / '0'
Expand Down
Loading