diff --git a/.github/hooks/pre-commit b/.github/hooks/pre-commit new file mode 100755 index 0000000..a9a1af0 --- /dev/null +++ b/.github/hooks/pre-commit @@ -0,0 +1,9 @@ +#!/bin/sh + +if hatch fmt --check; then + echo "Hatch fmt check passed!" +else + hatch fmt + echo "Error: hatch fmt modified your files. Please re-stage and commit again." + exit 1 +fi \ No newline at end of file diff --git a/examples/examples-catalog.json b/examples/examples-catalog.json index e80e3ba..0387014 100644 --- a/examples/examples-catalog.json +++ b/examples/examples-catalog.json @@ -580,6 +580,17 @@ "ApplicationLogLevel": "DEBUG", "LogFormat": "JSON" } - } + }, + { + "name": "Plugin", + "description": "Test plugin", + "handler": "execution_with_plugin.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/plugin/execution_with_plugin.py" + } ] } diff --git a/examples/src/plugin/execution_with_plugin.py b/examples/src/plugin/execution_with_plugin.py new file mode 100644 index 0000000..f8a7fa1 --- /dev/null +++ b/examples/src/plugin/execution_with_plugin.py @@ -0,0 +1,70 @@ +"""Demonstrates handler execution without any durable operations.""" + +import logging +from typing import Any + +from aws_durable_execution_sdk_python import StepContext +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_step, + durable_with_child_context, +) +from aws_durable_execution_sdk_python.execution import durable_execution +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + AttemptStartInfo, +) + + +class MyPlugin(DurableExecutionPlugin): + logger = logging.getLogger("MyPlugin") + + def on_execution_start(self, info): + self.logger.info(f"Execution started: {info}") + + def on_execution_end(self, info): + self.logger.info(f"Execution ended: {info}") + + def on_operation_start(self, info): + self.logger.info(f"Operation started: {info}") + + def on_operation_end(self, 
info): + self.logger.info(f"Operation ended: {info}") + + def on_invocation_start(self, info): + self.logger.info(f"Invocation started: {info}") + + def on_invocation_end(self, info): + self.logger.info(f"Invocation ended: {info}") + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + self.logger.info(f"Attempt started: {info}") + + def on_operation_attempt_end(self, info) -> None: + self.logger.info(f"Attempt ended: {info}") + + +@durable_step +def add_numbers(_step_context: StepContext, a: int, b: int) -> int: + return a + b + + +@durable_with_child_context +def add_numbers_in_child(child_context: DurableContext, a: int, b: int): + result: int = child_context.step( + add_numbers(a, b), + name="add-a-and-b", + ) + return result + + +@durable_execution(plugins=[MyPlugin()]) +def handler(_event: Any, context: DurableContext) -> int: + result: int = context.run_in_child_context( + add_numbers_in_child(6, 4), + name="add-6-and-4", + ) + return context.step( + add_numbers(result, 2), + name="add-result-to-2", + ) diff --git a/examples/template.yaml b/examples/template.yaml index 0a9dcb9..e3dfbfd 100644 --- a/examples/template.yaml +++ b/examples/template.yaml @@ -941,6 +941,24 @@ "ExecutionTimeout": 300 } } + }, + "ExecutionWithPlugin": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "execution_with_plugin.handler", + "Description": "Test plugin", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } } } } \ No newline at end of file diff --git a/examples/test/plugin/test_plugin.py b/examples/test/plugin/test_plugin.py new file mode 100644 index 0000000..5e21ba6 --- /dev/null +++ b/examples/test/plugin/test_plugin.py @@ -0,0 +1,24 @@ +"""Tests for step example.""" + +import pytest +from aws_durable_execution_sdk_python.execution import InvocationStatus + +from src.plugin import execution_with_plugin 
+from test.conftest import deserialize_operation_payload + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=execution_with_plugin.handler, + lambda_function_name="Plugin", +) +def test_plugin(durable_runner): + """Test basic step example.""" + with durable_runner: + result = durable_runner.run(input="{}", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == 12 + + step_result = result.get_step("add-result-to-2") + assert deserialize_operation_payload(step_result.result) == 12 diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index df535b4..ae81bfc 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -6,7 +6,6 @@ import logging from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.context import DurableContext @@ -26,6 +25,13 @@ Operation, OperationType, OperationUpdate, + InvocationStatus, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + PluginExecutor, + handle_plugins, ) from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus @@ -149,62 +155,6 @@ def from_durable_execution_invocation_input( ) -class InvocationStatus(Enum): - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - PENDING = "PENDING" - - -@dataclass(frozen=True) -class DurableExecutionInvocationOutput: - """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. - - If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution, - payload must be empty for SUCCEEDED/FAILED status. 
- """ - - status: InvocationStatus - result: str | None = None - error: ErrorObject | None = None - - @classmethod - def from_dict( - cls, data: MutableMapping[str, Any] - ) -> DurableExecutionInvocationOutput: - """Create an instance from a dictionary. - - Args: - data: Dictionary with camelCase keys matching the original structure - - Returns: - A DurableExecutionInvocationOutput instance - """ - status = InvocationStatus(data.get("Status")) - error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None - return cls(status=status, result=data.get("Result"), error=error) - - def to_dict(self) -> MutableMapping[str, Any]: - """Convert to a dictionary with the original field names. - - Returns: - Dictionary with the original camelCase keys - """ - result: MutableMapping[str, Any] = {"Status": self.status.value} - - if self.result is not None: - # large payloads return "", because checkpointed already - result["Result"] = self.result - if self.error: - result["Error"] = self.error.to_dict() - - return result - - @classmethod - def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: - """Create a succeeded invocation output.""" - return cls(status=InvocationStatus.SUCCEEDED, result=result) - - # endregion Invocation models @@ -212,14 +162,29 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: Boto3LambdaClient | None = None, + plugins: list[DurableExecutionPlugin] | None = None, ) -> Callable[[Any, LambdaContext], Any]: + """ + Decorator to create a durable execution handler. + + Args: + func: The user function to decorate + boto3_client: Optional boto3 Lambda client to use + plugins: Optional list of plugins to use (EXPERIMENTAL: This + parameter may change or be removed.) 
+ """ # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, boto3_client=boto3_client, plugins=plugins + ) logger.debug("Starting durable execution handler...") + plugin_executor = PluginExecutor(plugins) + + @handle_plugins(plugin_executor) def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: invocation_input: DurableExecutionInvocationInput service_client: DurableServiceClient @@ -255,6 +220,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: operations={}, service_client=service_client, replay_status=ReplayStatus.NEW, + plugin_executor=plugin_executor, ) try: @@ -306,6 +272,13 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: ) as executor, contextlib.closing(execution_state) as execution_state, ): + # execute the plugins + plugin_executor.on_invocation_start( + durable_execution_arn=invocation_input.durable_execution_arn, + context=context, + execution_operation=execution_state.get_execution_operation(), + is_replaying=execution_state.is_replaying(), + ) # Thread 1: Run background checkpoint processing executor.submit(execution_state.checkpoint_batches_forever) diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index aa78e4e..38c4455 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -105,6 +105,70 @@ class OperationSubType(Enum): CHAINED_INVOKE = "ChainedInvoke" +class InvocationStatus(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + # Used internally only: the invocation failed and the backend will retry + RETRY = "RETRY" + + +@dataclass(frozen=True) +class DurableExecutionInvocationOutput: + """Representation the 
DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. + + If the execution has been already completed via an update to the EXECUTION operation via CheckpointDurableExecution, + payload must be empty for SUCCEEDED/FAILED status. + """ + + status: InvocationStatus + result: str | None = None + error: ErrorObject | None = None + + @classmethod + def from_dict( + cls, data: MutableMapping[str, Any] + ) -> DurableExecutionInvocationOutput: + """Create an instance from a dictionary. + + Args: + data: Dictionary with camelCase keys matching the original structure + + Returns: + A DurableExecutionInvocationOutput instance + """ + status = InvocationStatus(data.get("Status")) + error = ErrorObject.from_dict(data["Error"]) if data.get("Error") else None + return cls(status=status, result=data.get("Result"), error=error) + + def to_dict(self) -> MutableMapping[str, Any]: + """Convert to a dictionary with the original field names. + + Returns: + Dictionary with the original camelCase keys + """ + result: MutableMapping[str, Any] = {"Status": self.status.value} + + if self.result is not None: + # large payloads return "", because checkpointed already + result["Result"] = self.result + if self.error: + result["Error"] = self.error.to_dict() + + return result + + @classmethod + def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: + """Create a succeeded invocation output.""" + return cls(status=InvocationStatus.SUCCEEDED, result=result) + + @classmethod + def create_retry(cls, error: ErrorObject) -> DurableExecutionInvocationOutput: + """Create a failed invocation output.""" + return cls(status=InvocationStatus.RETRY, error=error) + + @dataclass(frozen=True) class ExecutionDetails: input_payload: str | None = None diff --git a/src/aws_durable_execution_sdk_python/plugin.py b/src/aws_durable_execution_sdk_python/plugin.py new file mode 100644 index 0000000..9a4f294 --- /dev/null +++ 
b/src/aws_durable_execution_sdk_python/plugin.py @@ -0,0 +1,368 @@ +import contextlib +import datetime +import functools +import logging +from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from typing import Any, Callable, MutableMapping + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, + InvocationStatus, + Operation, + OperationUpdate, + DurableExecutionInvocationOutput, +) +from aws_durable_execution_sdk_python.types import LambdaContext + + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class OperationStartInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None = None + name: str | None = None + parent_id: str | None = None + start_timestamp: datetime.datetime | None = None + + +@dataclass(frozen=True) +class OperationEndInfo(OperationStartInfo): + status: OperationStatus = OperationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + attempt: int | None = None + error: ErrorObject | None = None + + +@dataclass(frozen=True) +class AttemptStartInfo(OperationStartInfo): + attempt: int = 1 + + +@dataclass(frozen=True) +class AttemptEndInfo(AttemptStartInfo): + succeeded: bool | None = None + end_timestamp: datetime.datetime | None = None + error: ErrorObject | None = None + next_attempt_delay_seconds: int | None = None + + +@dataclass(frozen=True) +class InvocationStartInfo: + request_id: str | None + execution_arn: str | None + start_timestamp: datetime.datetime | None + + +@dataclass(frozen=True) +class InvocationEndInfo(InvocationStartInfo): + status: InvocationStatus = InvocationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + error: ErrorObject | None = None + + +@dataclass(frozen=True) +class ExecutionStartInfo(InvocationStartInfo): + pass + + +@dataclass(frozen=True) +class ExecutionEndInfo(ExecutionStartInfo): + status: 
InvocationStatus = InvocationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + error: ErrorObject | None = None + + +class DurableExecutionPlugin: + """Base class for plugins. Override only the methods you need.""" + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + pass + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + pass + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + pass + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + pass + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + + +class PluginExecutor: + def __init__(self, plugins: list[DurableExecutionPlugin] | None): + self._plugins = plugins or [] + self._execution_operation: Operation | None = None + self._durable_execution_arn: str | None = None + self._aws_request_id: str | None = None + self._executor: ThreadPoolExecutor | None = None + + @contextlib.contextmanager + def run(self): + if self._plugins: + self._executor = ThreadPoolExecutor( + max_workers=1, + thread_name_prefix="plugin-executor", + ) + try: + yield + finally: + # Shut down the thread pool, waiting for pending tasks to complete. + if self._executor: + self._executor.shutdown(wait=True) + + @staticmethod + def _dispatch_plugin(plugin: DurableExecutionPlugin, info) -> None: + """Invoke the appropriate plugin callback. 
Runs inside the thread pool.""" + try: + match info: + case ExecutionEndInfo(): + plugin.on_execution_end(info) + case InvocationEndInfo(): + plugin.on_invocation_end(info) + case ExecutionStartInfo(): + plugin.on_execution_start(info) + case InvocationStartInfo(): + plugin.on_invocation_start(info) + case AttemptEndInfo(): + plugin.on_operation_attempt_end(info) + case OperationEndInfo(): + plugin.on_operation_end(info) + case AttemptStartInfo(): + plugin.on_operation_attempt_start(info) + case OperationStartInfo(): + plugin.on_operation_start(info) + case _: + raise ValueError(f"Unknown info type: {type(info)}") + except Exception: + # log and ignore the exception + logger.exception("Plugin %s exception ignored", plugin.__class__.__name__) + + def execute_plugins(self, info): + if not self._executor: + return + for plugin in self._plugins: + self._executor.submit(self._dispatch_plugin, plugin, info) + + def on_invocation_start( + self, + durable_execution_arn: str, + context: LambdaContext | None, + execution_operation: Operation | None, + is_replaying: bool, + ) -> None: + self._durable_execution_arn = durable_execution_arn + self._execution_operation = execution_operation + self._aws_request_id = context.aws_request_id if context else None + start_timestamp = ( + execution_operation.start_timestamp if execution_operation else None + ) + + if not is_replaying: + self.execute_plugins( + ExecutionStartInfo( + request_id=self._aws_request_id, + execution_arn=durable_execution_arn, + start_timestamp=start_timestamp, + ) + ) + + self.execute_plugins( + InvocationStartInfo( + request_id=self._aws_request_id, + execution_arn=durable_execution_arn, + start_timestamp=start_timestamp, + ) + ) + + def on_invocation_end( + self, + output: "DurableExecutionInvocationOutput", + ) -> None: + start_timestamp = ( + self._execution_operation.start_timestamp + if self._execution_operation + else None + ) + # the actual end timestamp may be unknown because it's not checkpointed yet 
+ end_timestamp: datetime.datetime = ( + self._execution_operation.end_timestamp + if self._execution_operation + else None + ) or datetime.datetime.now(datetime.UTC) + + self.execute_plugins( + InvocationEndInfo( + request_id=self._aws_request_id, + execution_arn=self._durable_execution_arn, + start_timestamp=start_timestamp, + status=output.status, + end_timestamp=end_timestamp, + error=output.error, + ) + ) + + if output.status in [InvocationStatus.SUCCEEDED, InvocationStatus.FAILED]: + self.execute_plugins( + ExecutionEndInfo( + request_id=self._aws_request_id, + execution_arn=self._durable_execution_arn, + start_timestamp=start_timestamp, + status=output.status, + end_timestamp=end_timestamp, + error=output.error, + ) + ) + + def on_operation_action(self, operation: Operation | None, update: OperationUpdate): + """Execute any registered plugins for a given operation before it is updated. + + Args: + operation: the operation after update + update: the operation update that is checkpointed + """ + if update.action is not OperationAction.START: + return + + self.execute_plugins( + OperationStartInfo( + operation_id=update.operation_id, + operation_type=update.operation_type, + sub_type=update.sub_type, + name=update.name, + parent_id=update.parent_id, + start_timestamp=datetime.datetime.now(datetime.UTC), + ) + ) + + if update.operation_type is OperationType.STEP: + attempt = ( + operation.step_details.attempt + if operation and operation.step_details + else 1 + ) + self.execute_plugins( + AttemptStartInfo( + operation_id=update.operation_id, + operation_type=update.operation_type, + sub_type=update.sub_type, + name=update.name, + parent_id=update.parent_id, + start_timestamp=datetime.datetime.now(datetime.UTC), + attempt=attempt, + ) + ) + + def on_operation_update(self, operation): + """Execute any registered plugins for a given operation after it is updated. + + Updates such as STARTED might be omitted because START and completion action (e.g. 
SUCCEED/FAIL) may be + checkpointed in batch and the backend returns only the terminal status (e.g. SUCCEEDED/PENDING/FAILED). + + Args: + operation: the operation is just checkpointed + """ + params = dict( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + ) + if operation.step_details and ( + self._is_terminal_status(operation.status) + # PENDING in addition to terminal status + or operation.status is OperationStatus.PENDING + ): + self.execute_plugins( + AttemptEndInfo( + **params, + end_timestamp=operation.end_timestamp, + attempt=operation.step_details.attempt, + succeeded=operation.status is OperationStatus.SUCCEEDED, + error=operation.step_details.error, + next_attempt_delay_seconds=operation.step_details.next_attempt_timestamp, + ) + ) + + if self._is_terminal_status(operation.status): + attempt = operation.step_details.attempt if operation.step_details else None + self.execute_plugins( + OperationEndInfo( + **params, + end_timestamp=operation.end_timestamp, + status=operation.status, + error=self._extract_error(operation), + attempt=attempt, + ) + ) + + @staticmethod + def _extract_error(operation: Operation): + if operation.step_details and operation.step_details.error: + return operation.step_details.error + if operation.callback_details and operation.callback_details.error: + return operation.callback_details.error + if operation.chained_invoke_details and operation.chained_invoke_details.error: + return operation.chained_invoke_details.error + if operation.context_details and operation.context_details.error: + return operation.context_details.error + return None + + @staticmethod + def _is_terminal_status(status): + return status in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ] + + +def 
handle_plugins(plugin_executor: PluginExecutor): + def decorator(func: Callable[[Any, LambdaContext], MutableMapping[str, Any]]): + @functools.wraps(func) + def wrapper(event: Any, context: LambdaContext): + with plugin_executor.run(): + try: + output = func(event, context) + + plugin_executor.on_invocation_end( + output=DurableExecutionInvocationOutput.from_dict(output), + ) + return output + except Exception as e: + plugin_executor.on_invocation_end( + output=DurableExecutionInvocationOutput.create_retry( + ErrorObject.from_exception(e) + ), + ) + raise + + return wrapper + + return decorator diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 8317550..fbfad03 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -30,6 +30,7 @@ OperationUpdate, StateOutput, ) +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock if TYPE_CHECKING: @@ -236,6 +237,7 @@ def __init__( initial_checkpoint_token: str, operations: MutableMapping[str, Operation], service_client: DurableServiceClient, + plugin_executor: PluginExecutor, batcher_config: CheckpointBatcherConfig | None = None, replay_status: ReplayStatus = ReplayStatus.NEW, ): @@ -243,6 +245,7 @@ def __init__( self._current_checkpoint_token: str = initial_checkpoint_token self.operations: MutableMapping[str, Operation] = operations self._service_client: DurableServiceClient = service_client + self._plugin_executor: PluginExecutor = plugin_executor self._ordered_checkpoint_lock: OrderedLock = OrderedLock() self._operations_lock: Lock = Lock() @@ -274,7 +277,7 @@ def fetch_paginated_operations( initial_operations: list[Operation], checkpoint_token: str, next_marker: str | None, - ) -> None: + ) -> list[Operation]: """Add initial operations and fetch all paginated operations from the Durable Functions API. 
This method is thread_safe. The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety. @@ -283,6 +286,8 @@ def fetch_paginated_operations( initial_operations: initial operations to be added to ExecutionState checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + Returns: + List of all operations fetched from the Durable Functions API Raises: GetExecutionStateError: If the API call fails. The error is logged @@ -315,6 +320,7 @@ def fetch_paginated_operations( self.operations.update( {op.operation_id: op for op in all_operations} ) + return all_operations def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation @@ -689,12 +695,20 @@ def checkpoint_batches_forever(self) -> None: current_checkpoint_token = output.checkpoint_token # Fetch new operations from the API before unblocking sync waiters - self.fetch_paginated_operations( + updated_operations = self.fetch_paginated_operations( output.new_execution_state.operations, output.checkpoint_token, output.new_execution_state.next_marker, ) + for update in updates: + with self._operations_lock: + op = self.operations.get(update.operation_id) + self._plugin_executor.on_operation_action(op, update) + + for operation in updated_operations: + self._plugin_executor.on_operation_update(operation) + # Signal completion for any synchronous operations for queued_op in batch: if queued_op.completion_event is not None: diff --git a/tests/e2e/map_with_concurrent_waits_int_test.py b/tests/e2e/map_with_concurrent_waits_int_test.py index 8ad812e..62ad7c2 100644 --- a/tests/e2e/map_with_concurrent_waits_int_test.py +++ b/tests/e2e/map_with_concurrent_waits_int_test.py @@ -42,6 +42,7 @@ OperationUpdate, OperationType, ) +from aws_durable_execution_sdk_python.plugin import PluginExecutor from aws_durable_execution_sdk_python.state 
import ( CheckpointBatcherConfig, ExecutionState, @@ -68,6 +69,7 @@ def _make_state( operations={}, service_client=mock_client, batcher_config=config, + plugin_executor=PluginExecutor([]), ) diff --git a/tests/execution_test.py b/tests/execution_test.py index db13b5a..31ec7d4 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -23,7 +23,6 @@ from aws_durable_execution_sdk_python.execution import ( DurableExecutionInvocationInput, DurableExecutionInvocationInputWithClient, - DurableExecutionInvocationOutput, InitialExecutionState, InvocationStatus, durable_execution, @@ -46,7 +45,9 @@ StateOutput, StepDetails, WaitDetails, + DurableExecutionInvocationOutput, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin LARGE_RESULT = "large_success" * 1024 * 1024 @@ -2827,3 +2828,295 @@ def test_handler(event: Any, context: DurableContext) -> dict: _make_invocation_input(mock_client, next_marker="next-page-marker"), _make_lambda_context(), ) + + +# region Plugin Integration Tests + + +class _RecordingPlugin(DurableExecutionPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append(f"execution_end:{info.status.value}") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append(f"invocation_end:{info.status.value}") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info): + self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info): + self.calls.append(f"attempt_end:{info.operation_id}") + + +class 
_FailingPlugin(DurableExecutionPlugin): + """Plugin that raises on every hook call.""" + + def on_execution_start(self, info): + raise RuntimeError("plugin boom") + + def on_execution_end(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_start(self, info): + raise RuntimeError("plugin boom") + + def on_invocation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_end(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("plugin boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("plugin boom") + + +def test_durable_execution_with_plugins_success(): + """Test that plugins receive invocation start/end and execution end on success.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + mock_client.checkpoint.return_value = mock_output + + plugin = _RecordingPlugin() + + @durable_execution(plugins=[plugin]) + def test_handler(event: Any, context: DurableContext) -> dict: + return {"result": "success"} + + result = test_handler( + _make_invocation_input(mock_client), + _make_lambda_context(), + ) + + assert result["Status"] == InvocationStatus.SUCCEEDED.value + # ExecutionStartInfo dispatches to on_invocation_start in the match block + assert "invocation_start" in plugin.calls + assert "invocation_end:SUCCEEDED" in plugin.calls + assert "execution_end:SUCCEEDED" in plugin.calls + + +def test_durable_execution_with_plugins_failure(): + """Test that plugins receive invocation end and execution end on user error.""" + mock_client = Mock(spec=DurableServiceClient) + mock_output = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState(), + ) + 
def test_durable_execution_with_plugins_pending():
    """Plugins must see an invocation end with PENDING status when the handler suspends."""
    service_client = Mock(spec=DurableServiceClient)
    checkpoint_output = CheckpointOutput(
        checkpoint_token="new_token",  # noqa: S106
        new_execution_state=CheckpointUpdatedExecutionState(),
    )
    service_client.checkpoint.return_value = checkpoint_output

    plugin = _RecordingPlugin()

    @durable_execution(plugins=[plugin])
    def test_handler(event: Any, context: DurableContext) -> dict:
        raise SuspendExecution("test")

    result = test_handler(
        _make_invocation_input(service_client),
        _make_lambda_context(),
    )

    assert result["Status"] == InvocationStatus.PENDING.value
    assert "invocation_start" in plugin.calls
    assert "invocation_end:PENDING" in plugin.calls
    # A suspended (PENDING) invocation is not terminal, so execution-end
    # must not have fired.
    assert not any(c.startswith("execution_end") for c in plugin.calls)


def test_durable_execution_with_plugins_retryable_error():
    """Plugins must see an invocation end with RETRY status on a retryable error."""
    service_client = Mock(spec=DurableServiceClient)

    plugin = _RecordingPlugin()

    @durable_execution(plugins=[plugin])
    def test_handler(event: Any, context: DurableContext) -> dict:
        msg = "Retriable error"
        raise InvocationError(msg)

    with pytest.raises(InvocationError):
        test_handler(
            _make_invocation_input(service_client),
            _make_lambda_context(),
        )

    assert "invocation_start" in plugin.calls
    assert "invocation_end:RETRY" in plugin.calls


def test_durable_execution_with_multiple_plugins():
    """Every registered plugin receives the lifecycle callbacks."""
    service_client = Mock(spec=DurableServiceClient)
    checkpoint_output = CheckpointOutput(
        checkpoint_token="new_token",  # noqa: S106
        new_execution_state=CheckpointUpdatedExecutionState(),
    )
    service_client.checkpoint.return_value = checkpoint_output

    plugin1 = _RecordingPlugin()
    plugin2 = _RecordingPlugin()

    @durable_execution(plugins=[plugin1, plugin2])
    def test_handler(event: Any, context: DurableContext) -> dict:
        return {"result": "success"}

    result = test_handler(
        _make_invocation_input(service_client),
        _make_lambda_context(),
    )

    assert result["Status"] == InvocationStatus.SUCCEEDED.value
    for recorder in (plugin1, plugin2):
        assert "invocation_start" in recorder.calls
        assert "invocation_end:SUCCEEDED" in recorder.calls


def test_durable_execution_with_failing_plugin_does_not_break_execution():
    """A plugin that raises must not prevent the handler from completing."""
    service_client = Mock(spec=DurableServiceClient)
    checkpoint_output = CheckpointOutput(
        checkpoint_token="new_token",  # noqa: S106
        new_execution_state=CheckpointUpdatedExecutionState(),
    )
    service_client.checkpoint.return_value = checkpoint_output

    failing_plugin = _FailingPlugin()
    recording_plugin = _RecordingPlugin()

    @durable_execution(plugins=[failing_plugin, recording_plugin])
    def test_handler(event: Any, context: DurableContext) -> dict:
        return {"result": "success"}

    result = test_handler(
        _make_invocation_input(service_client),
        _make_lambda_context(),
    )

    # The handler result is unaffected by the failing plugin...
    assert result["Status"] == InvocationStatus.SUCCEEDED.value
    # ...and plugins after the failing one are still invoked.
    assert "invocation_start" in recording_plugin.calls
    assert "invocation_end:SUCCEEDED" in recording_plugin.calls


def test_durable_execution_with_no_plugins():
    """`plugins=None` is accepted and execution proceeds normally."""
    service_client = Mock(spec=DurableServiceClient)
    checkpoint_output = CheckpointOutput(
        checkpoint_token="new_token",  # noqa: S106
        new_execution_state=CheckpointUpdatedExecutionState(),
    )
    service_client.checkpoint.return_value = checkpoint_output

    @durable_execution(plugins=None)
    def test_handler(event: Any, context: DurableContext) -> dict:
        return {"result": "success"}

    result = test_handler(
        _make_invocation_input(service_client),
        _make_lambda_context(),
    )

    assert result["Status"] == InvocationStatus.SUCCEEDED.value


def test_durable_execution_with_empty_plugins_list():
    """An empty plugins list is accepted and execution proceeds normally."""
    service_client = Mock(spec=DurableServiceClient)
    checkpoint_output = CheckpointOutput(
        checkpoint_token="new_token",  # noqa: S106
        new_execution_state=CheckpointUpdatedExecutionState(),
    )
    service_client.checkpoint.return_value = checkpoint_output

    @durable_execution(plugins=[])
    def test_handler(event: Any, context: DurableContext) -> dict:
        return {"result": "success"}

    result = test_handler(
        _make_invocation_input(service_client),
        _make_lambda_context(),
    )

    assert result["Status"] == InvocationStatus.SUCCEEDED.value


def test_durable_execution_decorator_with_plugins_and_boto3_client():
    """The decorator accepts `plugins` alongside `boto3_client`."""
    service_client = Mock(spec=DurableServiceClient)
    checkpoint_output = CheckpointOutput(
        checkpoint_token="new_token",  # noqa: S106
        new_execution_state=CheckpointUpdatedExecutionState(),
    )
    service_client.checkpoint.return_value = checkpoint_output

    plugin = _RecordingPlugin()

    # When using DurableExecutionInvocationInputWithClient, boto3_client is
    # ignored; this only verifies the decorator accepts both parameters.
    @durable_execution(boto3_client=None, plugins=[plugin])
    def test_handler(event: Any, context: DurableContext) -> dict:
        return {"result": "success"}

    result = test_handler(
        _make_invocation_input(service_client),
        _make_lambda_context(),
    )

    assert result["Status"] == InvocationStatus.SUCCEEDED.value
    assert "invocation_start" in plugin.calls


# endregion Plugin Integration Tests
import datetime
import logging
import unittest
from unittest.mock import MagicMock, patch

from aws_durable_execution_sdk_python.lambda_service import (
    DurableExecutionInvocationOutput,
    ErrorObject,
    InvocationStatus,
    OperationAction,
    OperationStatus,
    OperationSubType,
    OperationType,
)
from aws_durable_execution_sdk_python.plugin import (
    AttemptEndInfo,
    AttemptStartInfo,
    DurableExecutionPlugin,
    ExecutionEndInfo,
    ExecutionStartInfo,
    InvocationEndInfo,
    InvocationStartInfo,
    OperationEndInfo,
    OperationStartInfo,
    PluginExecutor,
)


# region Dataclass Tests


class TestOperationStartInfo(unittest.TestCase):
    def test_required_fields(self):
        """Only the mandatory fields are set; every optional field defaults to None."""
        info = OperationStartInfo(
            operation_id="op-1", operation_type=OperationType.STEP
        )
        self.assertEqual(info.operation_id, "op-1")
        self.assertEqual(info.operation_type, OperationType.STEP)
        self.assertIsNone(info.sub_type)
        self.assertIsNone(info.name)
        self.assertIsNone(info.parent_id)
        self.assertIsNone(info.start_timestamp)

    def test_all_fields(self):
        """All optional fields round-trip through the constructor."""
        stamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
        info = OperationStartInfo(
            operation_id="op-2",
            operation_type=OperationType.CALLBACK,
            sub_type=OperationSubType.CALLBACK,
            name="my-op",
            parent_id="parent-1",
            start_timestamp=stamp,
        )
        self.assertEqual(info.sub_type, OperationSubType.CALLBACK)
        self.assertEqual(info.name, "my-op")
        self.assertEqual(info.parent_id, "parent-1")
        self.assertEqual(info.start_timestamp, stamp)


class TestOperationEndInfo(unittest.TestCase):
    def test_inherits_operation_start_info(self):
        self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo))

    def test_defaults(self):
        """Status defaults to SUCCEEDED; end/attempt/error default to None."""
        info = OperationEndInfo(operation_id="op-1", operation_type=OperationType.STEP)
        self.assertEqual(info.status, OperationStatus.SUCCEEDED)
        self.assertIsNone(info.end_timestamp)
        self.assertIsNone(info.attempt)
        self.assertIsNone(info.error)

    def test_with_error(self):
        """A failed end carries the error object and the attempt number."""
        error = ErrorObject(
            message="fail", type="RuntimeError", data=None, stack_trace=None
        )
        info = OperationEndInfo(
            operation_id="op-1",
            operation_type=OperationType.STEP,
            status=OperationStatus.FAILED,
            error=error,
            attempt=3,
        )
        self.assertEqual(info.status, OperationStatus.FAILED)
        self.assertEqual(info.attempt, 3)
        self.assertEqual(info.error.message, "fail")


class TestAttemptStartInfo(unittest.TestCase):
    def test_inherits_operation_start_info(self):
        self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo))

    def test_default_attempt(self):
        """The first attempt is numbered 1 by default."""
        info = AttemptStartInfo(operation_id="op-1", operation_type=OperationType.STEP)
        self.assertEqual(info.attempt, 1)

    def test_custom_attempt(self):
        info = AttemptStartInfo(
            operation_id="op-1", operation_type=OperationType.STEP, attempt=5
        )
        self.assertEqual(info.attempt, 5)


class TestAttemptEndInfo(unittest.TestCase):
    def test_inherits_attempt_start_info(self):
        self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo))

    def test_defaults(self):
        info = AttemptEndInfo(operation_id="op-1", operation_type=OperationType.STEP)
        self.assertIsNone(info.succeeded)
        self.assertIsNone(info.error)
        self.assertIsNone(info.next_attempt_delay_seconds)

    def test_retry_with_delay(self):
        """A failed attempt can carry the error and the delay before the next try."""
        error = ErrorObject(
            message="timeout", type="TimeoutError", data=None, stack_trace=None
        )
        info = AttemptEndInfo(
            operation_id="op-1",
            operation_type=OperationType.STEP,
            succeeded=False,
            error=error,
            next_attempt_delay_seconds=30,
        )
        self.assertFalse(info.succeeded)
        self.assertEqual(info.next_attempt_delay_seconds, 30)
        self.assertEqual(info.error.type, "TimeoutError")


class TestInvocationStartInfo(unittest.TestCase):
    def test_fields(self):
        stamp = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
        info = InvocationStartInfo(
            request_id="req-1",
            execution_arn="arn:aws:lambda:us-east-1:123:durable:abc",
            start_timestamp=stamp,
        )
        self.assertEqual(info.request_id, "req-1")
        self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc")
        self.assertEqual(info.start_timestamp, stamp)


class TestInvocationEndInfo(unittest.TestCase):
    def test_inherits_invocation_start_info(self):
        self.assertTrue(issubclass(InvocationEndInfo, InvocationStartInfo))

    def test_defaults(self):
        """Status defaults to SUCCEEDED with no error."""
        stamp = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
        info = InvocationEndInfo(
            request_id="req-1", execution_arn="arn:test", start_timestamp=stamp
        )
        self.assertEqual(info.status, InvocationStatus.SUCCEEDED)
        self.assertIsNone(info.error)

    def test_failed(self):
        stamp = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
        error = ErrorObject(message="boom", type="Error", data=None, stack_trace=None)
        info = InvocationEndInfo(
            request_id="req-1",
            execution_arn="arn:test",
            start_timestamp=stamp,
            status=InvocationStatus.FAILED,
            error=error,
        )
        self.assertEqual(info.status, InvocationStatus.FAILED)
        self.assertEqual(info.error.message, "boom")


class TestExecutionStartInfo(unittest.TestCase):
    def test_inherits_invocation_start_info(self):
        self.assertTrue(issubclass(ExecutionStartInfo, InvocationStartInfo))

    def test_construction(self):
        stamp = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
        info = ExecutionStartInfo(
            request_id="req-1", execution_arn="arn:test", start_timestamp=stamp
        )
        self.assertEqual(info.request_id, "req-1")


class TestExecutionEndInfo(unittest.TestCase):
    def test_inherits_execution_start_info(self):
        self.assertTrue(issubclass(ExecutionEndInfo, ExecutionStartInfo))

    def test_defaults(self):
        stamp = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
        info = ExecutionEndInfo(
            request_id="req-1", execution_arn="arn:test", start_timestamp=stamp
        )
        self.assertEqual(info.status, InvocationStatus.SUCCEEDED)
        self.assertIsNone(info.error)

    def test_with_error(self):
        stamp = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC)
        error = ErrorObject(message="crash", type="Error", data=None, stack_trace=None)
        info = ExecutionEndInfo(
            request_id="req-1",
            execution_arn="arn:test",
            start_timestamp=stamp,
            status=InvocationStatus.FAILED,
            end_timestamp=stamp,
            error=error,
        )
        self.assertEqual(info.status, InvocationStatus.FAILED)
        self.assertEqual(info.end_timestamp, stamp)
        self.assertEqual(info.error.message, "crash")


# endregion Dataclass Tests


# region DurableExecutionPlugin Tests


class TestDurableExecutionPlugin(unittest.TestCase):
    def test_default_methods_are_noop(self):
        """Every default hook on the base class is callable and returns None."""
        plugin = _NoOpPlugin()
        stamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

        invocation_kwargs = {
            "request_id": "r",
            "execution_arn": "a",
            "start_timestamp": stamp,
        }
        operation_kwargs = {"operation_id": "o", "operation_type": OperationType.STEP}

        hook_calls = [
            (plugin.on_execution_start, ExecutionStartInfo(**invocation_kwargs)),
            (plugin.on_execution_end, ExecutionEndInfo(**invocation_kwargs)),
            (plugin.on_invocation_start, InvocationStartInfo(**invocation_kwargs)),
            (plugin.on_invocation_end, InvocationEndInfo(**invocation_kwargs)),
            (plugin.on_operation_start, OperationStartInfo(**operation_kwargs)),
            (plugin.on_operation_end, OperationEndInfo(**operation_kwargs)),
            (plugin.on_operation_attempt_start, AttemptStartInfo(**operation_kwargs)),
            (plugin.on_operation_attempt_end, AttemptEndInfo(**operation_kwargs)),
        ]
        for hook, info in hook_calls:
            self.assertIsNone(hook(info))

    def test_subclass_override(self):
        """A subclass may override only the hooks it cares about."""
        plugin = _TrackingPlugin()
        stamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

        plugin.on_execution_start(
            ExecutionStartInfo(request_id="r", execution_arn="a", start_timestamp=stamp)
        )
        plugin.on_operation_start(
            OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT)
        )

        self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"])


# endregion DurableExecutionPlugin Tests
# region PluginExecutor Tests


class TestPluginExecutorInit(unittest.TestCase):
    def test_init_with_none(self):
        """`plugins=None` is normalized to an empty list."""
        executor = PluginExecutor(plugins=None)
        self.assertEqual(executor._plugins, [])

    def test_init_with_empty_list(self):
        executor = PluginExecutor(plugins=[])
        self.assertEqual(executor._plugins, [])

    def test_init_with_plugins(self):
        first = _NoOpPlugin()
        second = _TrackingPlugin()
        executor = PluginExecutor(plugins=[first, second])
        self.assertEqual(len(executor._plugins), 2)


class TestPluginExecutor(unittest.TestCase):
    def test_no_thread_pool_when_plugins_is_none(self):
        """No thread pool is created when `plugins` is None."""
        executor = PluginExecutor(plugins=None)
        self.assertIsNone(executor._executor)

    def test_no_thread_pool_when_plugins_is_empty_list(self):
        executor = PluginExecutor(plugins=[])
        self.assertIsNone(executor._executor)

    def test_thread_pool_created_when_plugins_provided(self):
        """A thread pool exists while `run()` is active and plugins are registered."""
        executor = PluginExecutor(plugins=[_NoOpPlugin()])
        with executor.run():
            self.assertIsNotNone(executor._executor)

    def test_start_is_noop_when_empty(self):
        executor = PluginExecutor(plugins=[])
        # Entering/exiting run() with no plugins must not raise.
        with executor.run():
            pass

    def test_on_invocation_start_is_safe_when_empty(self):
        executor = PluginExecutor(plugins=[])
        ctx = MagicMock()
        ctx.aws_request_id = "req-1"
        op = MagicMock()
        op.start_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

        # Must not raise even though no plugins are registered.
        executor.on_invocation_start(
            durable_execution_arn="arn:exec",
            context=ctx,
            execution_operation=op,
            is_replaying=False,
        )

    def test_on_invocation_end_is_safe_when_empty(self):
        executor = PluginExecutor(plugins=[])
        ctx = MagicMock()
        ctx.aws_request_id = "req-1"
        op = MagicMock()
        op.start_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
        op.end_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
        output = DurableExecutionInvocationOutput(
            status=InvocationStatus.SUCCEEDED, result=None, error=None
        )

        # Must not raise even though no plugins are registered.
        executor.on_invocation_end(
            output=output,
        )

    def test_on_operation_action_is_safe_when_empty(self):
        executor = PluginExecutor(plugins=[])
        update = MagicMock()
        update.action = OperationAction.START
        update.operation_id = "op-1"
        update.operation_type = OperationType.STEP
        update.sub_type = OperationSubType.STEP
        update.name = "my-step"
        update.parent_id = None

        # Must not raise even though no plugins are registered.
        executor.on_operation_action(None, update)

    def test_on_operation_update_is_safe_when_empty(self):
        executor = PluginExecutor(plugins=[])
        op = MagicMock()
        op.operation_id = "op-1"
        op.operation_type = OperationType.STEP
        op.sub_type = OperationSubType.STEP
        op.name = "my-step"
        op.parent_id = None
        op.start_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
        op.end_timestamp = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)
        op.status = OperationStatus.SUCCEEDED
        op.step_details = MagicMock()
        op.step_details.attempt = 1
        op.step_details.error = None
        op.callback_details = None
        op.chained_invoke_details = None
        op.context_details = None

        # Must not raise even though no plugins are registered.
        executor.on_operation_update(op)


class TestPluginExecutorExecutePlugins(unittest.TestCase):
    """Tests for the execute_plugins dispatch method."""

    def setUp(self):
        self.plugin = _TrackingPlugin()
        self.executor = PluginExecutor(plugins=[self.plugin])

    def _dispatch(self, info):
        """Run the executor and dispatch a single info object."""
        with self.executor.run():
            self.executor.execute_plugins(info)

    @staticmethod
    def _invocation_info(info_cls):
        """Build an invocation-level info object with fixed test values."""
        return info_cls(
            request_id="req-1",
            execution_arn="arn:test",
            start_timestamp=datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC),
        )

    def test_dispatch_execution_start_info(self):
        self._dispatch(self._invocation_info(ExecutionStartInfo))
        self.assertIn("execution_start:req-1", self.plugin.calls)

    def test_dispatch_execution_end_info(self):
        self._dispatch(self._invocation_info(ExecutionEndInfo))
        self.assertIn("execution_end:req-1", self.plugin.calls)

    def test_dispatch_invocation_start_info(self):
        self._dispatch(self._invocation_info(InvocationStartInfo))
        self.assertIn("invocation_start:req-1", self.plugin.calls)

    def test_dispatch_invocation_end_info(self):
        self._dispatch(self._invocation_info(InvocationEndInfo))
        self.assertIn("invocation_end:req-1", self.plugin.calls)

    def test_dispatch_operation_end_info(self):
        self._dispatch(
            OperationEndInfo(operation_id="op-1", operation_type=OperationType.STEP)
        )
        self.assertIn("operation_end:op-1", self.plugin.calls)

    def test_dispatch_operation_start_info(self):
        self._dispatch(
            OperationStartInfo(operation_id="op-1", operation_type=OperationType.STEP)
        )
        self.assertIn("operation_start:op-1", self.plugin.calls)

    def test_dispatch_attempt_start_info(self):
        self._dispatch(
            AttemptStartInfo(operation_id="op-1", operation_type=OperationType.STEP)
        )
        self.assertIn("attempt_start:op-1", self.plugin.calls)

    def test_dispatch_attempt_end_info(self):
        self._dispatch(
            AttemptEndInfo(operation_id="op-1", operation_type=OperationType.STEP)
        )
        self.assertIn("attempt_end:op-1", self.plugin.calls)

    def test_dispatch_unknown_type_logs_exception(self):
        """An unrecognized info type is caught and logged, not raised."""
        with self.assertLogs(
            "aws_durable_execution_sdk_python.plugin", level=logging.ERROR
        ):
            self._dispatch("not a valid info type")

    def test_plugin_exception_is_swallowed(self):
        """A raising plugin is logged and does not stop later plugins."""
        failing_plugin = _FailingPlugin()
        tracking_plugin = _TrackingPlugin()
        executor = PluginExecutor(plugins=[failing_plugin, tracking_plugin])

        info = OperationStartInfo(
            operation_id="op-1", operation_type=OperationType.STEP
        )
        with self.assertLogs(
            "aws_durable_execution_sdk_python.plugin", level=logging.ERROR
        ):
            with executor.run():
                executor.execute_plugins(info)

        # The plugin registered after the failing one still ran.
        self.assertIn("operation_start:op-1", tracking_plugin.calls)

    def test_multiple_plugins_all_called(self):
        first = _TrackingPlugin()
        second = _TrackingPlugin()
        executor = PluginExecutor(plugins=[first, second])

        info = OperationStartInfo(
            operation_id="op-1", operation_type=OperationType.STEP
        )
        with executor.run():
            executor.execute_plugins(info)

        self.assertIn("operation_start:op-1", first.calls)
        self.assertIn("operation_start:op-1", second.calls)
class TestPluginExecutorOnInvocationStart(unittest.TestCase):
    """Tests for PluginExecutor.on_invocation_start."""

    def setUp(self):
        self.plugin = _TrackingPlugin()
        self.executor = PluginExecutor(plugins=[self.plugin])
        self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

    def _make_context(self, request_id="req-123"):
        ctx = MagicMock()
        ctx.aws_request_id = request_id
        return ctx

    def _make_operation(self, start_timestamp=None):
        op = MagicMock()
        op.start_timestamp = start_timestamp or self.ts
        return op

    def _start_calls(self):
        """All recorded execution-start / invocation-start calls."""
        return [
            c
            for c in self.plugin.calls
            if c.startswith(("invocation_start", "execution_start"))
        ]

    def test_first_invocation_fires_execution_start_and_invocation_start(self):
        ctx = self._make_context()
        op = self._make_operation()

        with self.executor.run():
            self.executor.on_invocation_start(
                durable_execution_arn="arn:exec",
                context=ctx,
                execution_operation=op,
                is_replaying=False,
            )

        self.assertEqual("arn:exec", self.executor._durable_execution_arn)
        self.assertEqual(ctx.aws_request_id, self.executor._aws_request_id)
        self.assertEqual(op, self.executor._execution_operation)

        # A non-replay start dispatches both ExecutionStartInfo (to
        # on_execution_start) and InvocationStartInfo (to on_invocation_start),
        # so exactly two start calls are recorded.
        self.assertEqual(len(self._start_calls()), 2)

    def test_replay_invocation_skips_execution_start(self):
        ctx = self._make_context()
        op = self._make_operation()

        with self.executor.run():
            self.executor.on_invocation_start(
                durable_execution_arn="arn:exec",
                context=ctx,
                execution_operation=op,
                is_replaying=True,
            )

        # On replay only InvocationStartInfo is dispatched; ExecutionStartInfo
        # is suppressed.
        self.assertEqual(len(self._start_calls()), 1)

    def test_none_context_uses_none_request_id(self):
        op = self._make_operation()

        with self.executor.run():
            self.executor.on_invocation_start(
                durable_execution_arn="arn:exec",
                context=None,
                execution_operation=op,
                is_replaying=False,
            )

        # Both start infos are still dispatched...
        self.assertEqual(len(self._start_calls()), 2)
        # ...with a None request id taken from the missing context.
        self.assertIn("invocation_start:None", self.plugin.calls)


class TestPluginExecutorOnInvocationEnd(unittest.TestCase):
    """Tests for PluginExecutor.on_invocation_end."""

    def setUp(self):
        self.plugin = _TrackingPlugin()
        self.executor = PluginExecutor(plugins=[self.plugin])
        self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

    def _make_context(self, request_id="req-123"):
        ctx = MagicMock()
        ctx.aws_request_id = request_id
        return ctx

    def _make_operation(self, start_ts=None, end_ts=None):
        op = MagicMock()
        op.start_timestamp = start_ts or self.ts
        op.end_timestamp = end_ts
        return op

    def _run_invocation(self, ctx, op, output):
        """Drive a full start+end cycle through the executor."""
        with self.executor.run():
            self.executor.on_invocation_start(
                durable_execution_arn="arn:exec",
                context=ctx,
                execution_operation=op,
                is_replaying=False,
            )
            self.executor.on_invocation_end(
                output=output,
            )

    def test_succeeded_fires_invocation_end_and_execution_end(self):
        ctx = self._make_context()
        op = self._make_operation(end_ts=self.ts)
        output = DurableExecutionInvocationOutput(
            status=InvocationStatus.SUCCEEDED, result=None, error=None
        )

        self._run_invocation(ctx, op, output)

        self.assertIn("invocation_end:req-123", self.plugin.calls)
        self.assertIn("execution_end:req-123", self.plugin.calls)

    def test_failed_fires_invocation_end_and_execution_end(self):
        ctx = self._make_context()
        op = self._make_operation(end_ts=self.ts)
        error = ErrorObject(message="oops", type="Error", data=None, stack_trace=None)
        output = DurableExecutionInvocationOutput(
            status=InvocationStatus.FAILED, result=None, error=error
        )

        self._run_invocation(ctx, op, output)

        self.assertIn("invocation_end:req-123", self.plugin.calls)
        self.assertIn("execution_end:req-123", self.plugin.calls)

    def test_pending_fires_only_invocation_end(self):
        ctx = self._make_context()
        op = self._make_operation(end_ts=self.ts)
        output = DurableExecutionInvocationOutput(
            status=InvocationStatus.PENDING, result=None, error=None
        )

        self._run_invocation(ctx, op, output)

        self.assertIn("invocation_end:req-123", self.plugin.calls)
        # PENDING is not terminal, so no execution_end may fire.
        self.assertEqual(
            [c for c in self.plugin.calls if c.startswith("execution_end")], []
        )

    def test_none_execution_operation_uses_now_for_end_timestamp(self):
        ctx = self._make_context()
        output = DurableExecutionInvocationOutput(
            status=InvocationStatus.SUCCEEDED, result=None, error=None
        )

        # With no execution operation, the executor falls back to
        # datetime.now for the end timestamp.
        with patch("aws_durable_execution_sdk_python.plugin.datetime") as mock_dt:
            mock_dt.datetime.now.return_value = self.ts
            self._run_invocation(ctx, None, output)

        self.assertIn("invocation_end:req-123", self.plugin.calls)

    def test_none_end_timestamp_on_operation_uses_now(self):
        ctx = self._make_context()
        op = self._make_operation(end_ts=None)
        output = DurableExecutionInvocationOutput(
            status=InvocationStatus.SUCCEEDED, result=None, error=None
        )

        self._run_invocation(ctx, op, output)

        self.assertIn("invocation_end:req-123", self.plugin.calls)
class TestPluginExecutorOnOperationAction(unittest.TestCase):
    """Tests for PluginExecutor.on_operation_action."""

    def setUp(self):
        self.plugin = _TrackingPlugin()
        self.executor = PluginExecutor(plugins=[self.plugin])

    @staticmethod
    def _make_start_update(
        operation_type=OperationType.STEP,
        sub_type=OperationSubType.STEP,
        name="my-step",
    ):
        """Build a mock START update for operation op-1."""
        update = MagicMock()
        update.action = OperationAction.START
        update.operation_id = "op-1"
        update.operation_type = operation_type
        update.sub_type = sub_type
        update.name = name
        update.parent_id = "parent-1"
        return update

    def test_start_action_fires_operation_start(self):
        update = self._make_start_update()

        with self.executor.run():
            self.executor.on_operation_action(None, update)

        self.assertIn("operation_start:op-1", self.plugin.calls)

    def test_start_action_for_step_fires_attempt_start(self):
        update = self._make_start_update()

        with self.executor.run():
            self.executor.on_operation_action(None, update)

        self.assertIn("attempt_start:op-1", self.plugin.calls)

    def test_start_action_for_step_with_existing_operation_uses_attempt(self):
        update = self._make_start_update()

        # An existing operation already carries an attempt counter.
        operation = MagicMock()
        operation.step_details = MagicMock()
        operation.step_details.attempt = 3

        with self.executor.run():
            self.executor.on_operation_action(operation, update)

        self.assertIn("attempt_start:op-1", self.plugin.calls)

    def test_start_action_for_non_step_does_not_fire_attempt_start(self):
        update = self._make_start_update(
            operation_type=OperationType.WAIT,
            sub_type=OperationSubType.WAIT,
            name="my-wait",
        )

        with self.executor.run():
            self.executor.on_operation_action(None, update)

        self.assertIn("operation_start:op-1", self.plugin.calls)
        # Attempt hooks are reserved for STEP operations.
        self.assertEqual(
            [c for c in self.plugin.calls if c.startswith("attempt")], []
        )

    def test_non_start_action_does_not_fire(self):
        update = MagicMock()
        update.action = OperationAction.SUCCEED
        update.operation_id = "op-1"

        self.executor.on_operation_action(None, update)

        self.assertEqual(self.plugin.calls, [])

    def test_fail_action_does_not_fire(self):
        update = MagicMock()
        update.action = OperationAction.FAIL
        update.operation_id = "op-1"

        self.executor.on_operation_action(None, update)

        self.assertEqual(self.plugin.calls, [])


class TestPluginExecutorOnOperationUpdate(unittest.TestCase):
    """Tests for PluginExecutor.on_operation_update."""

    def setUp(self):
        self.plugin = _TrackingPlugin()
        self.executor = PluginExecutor(plugins=[self.plugin])
        self.ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC)

    def _make_operation(
        self,
        status=OperationStatus.SUCCEEDED,
        step_details=None,
        callback_details=None,
        chained_invoke_details=None,
        context_details=None,
    ):
        """Build a mock operation op-1 in the given status."""
        op = MagicMock()
        op.operation_id = "op-1"
        op.operation_type = OperationType.STEP
        op.sub_type = OperationSubType.STEP
        op.name = "my-step"
        op.parent_id = "parent-1"
        op.start_timestamp = self.ts
        op.end_timestamp = self.ts
        op.status = status
        op.step_details = step_details
        op.callback_details = callback_details
        op.chained_invoke_details = chained_invoke_details
        op.context_details = context_details
        return op

    def _update(self, op):
        with self.executor.run():
            self.executor.on_operation_update(op)

    def test_terminal_status_with_step_details_fires_attempt_and_operation(self):
        step_details = MagicMock()
        step_details.attempt = 2
        step_details.error = None
        op = self._make_operation(
            status=OperationStatus.SUCCEEDED, step_details=step_details
        )

        self._update(op)

        self.assertIn("attempt_end:op-1", self.plugin.calls)
        self.assertIn("operation_end:op-1", self.plugin.calls)

    def test_pending_status_with_step_details_fires_attempt_only(self):
        step_details = MagicMock()
        step_details.attempt = 1
        step_details.error = ErrorObject(
            message="retry", type="Error", data=None, stack_trace=None
        )
        op = self._make_operation(
            status=OperationStatus.PENDING, step_details=step_details
        )

        self._update(op)

        self.assertIn("attempt_end:op-1", self.plugin.calls)
        # PENDING means another attempt is coming; the operation is not over.
        self.assertEqual(
            [c for c in self.plugin.calls if c.startswith("operation_end")], []
        )

    def test_terminal_status_without_step_details_fires_operation_only(self):
        op = self._make_operation(status=OperationStatus.FAILED, step_details=None)

        self._update(op)

        self.assertIn("operation_end:op-1", self.plugin.calls)
        self.assertEqual(
            [c for c in self.plugin.calls if c.startswith("attempt")], []
        )

    def test_non_terminal_status_without_step_details_fires_nothing(self):
        op = self._make_operation(status=OperationStatus.STARTED, step_details=None)

        self._update(op)

        self.assertEqual(self.plugin.calls, [])

    def test_ready_status_fires_nothing(self):
        op = self._make_operation(status=OperationStatus.READY, step_details=None)

        self._update(op)

        self.assertEqual(self.plugin.calls, [])

    def test_timed_out_is_terminal(self):
        op = self._make_operation(status=OperationStatus.TIMED_OUT, step_details=None)

        self._update(op)

        self.assertIn("operation_end:op-1", self.plugin.calls)

    def test_cancelled_is_terminal(self):
        op = self._make_operation(status=OperationStatus.CANCELLED, step_details=None)

        self._update(op)

        self.assertIn("operation_end:op-1", self.plugin.calls)

    def test_stopped_is_terminal(self):
        op = self._make_operation(status=OperationStatus.STOPPED, step_details=None)

        self._update(op)

        self.assertIn("operation_end:op-1", self.plugin.calls)
class TestPluginExecutorExtractError(unittest.TestCase):
    """Tests for PluginExecutor._extract_error static method."""

    @staticmethod
    def _make_error(msg="error"):
        return ErrorObject(message=msg, type="Error", data=None, stack_trace=None)

    @staticmethod
    def _make_operation(
        step_details=None,
        callback_details=None,
        chained_invoke_details=None,
        context_details=None,
    ):
        """Build a mock operation with the given detail objects (None by default)."""
        op = MagicMock()
        op.step_details = step_details
        op.callback_details = callback_details
        op.chained_invoke_details = chained_invoke_details
        op.context_details = context_details
        return op

    def test_extract_error_from_step_details(self):
        op = self._make_operation(
            step_details=MagicMock(error=self._make_error("step error"))
        )

        result = PluginExecutor._extract_error(op)
        self.assertEqual(result.message, "step error")

    def test_extract_error_from_callback_details(self):
        op = self._make_operation(
            callback_details=MagicMock(error=self._make_error("callback error"))
        )

        result = PluginExecutor._extract_error(op)
        self.assertEqual(result.message, "callback error")

    def test_extract_error_from_chained_invoke_details(self):
        op = self._make_operation(
            chained_invoke_details=MagicMock(error=self._make_error("invoke error"))
        )

        result = PluginExecutor._extract_error(op)
        self.assertEqual(result.message, "invoke error")

    def test_extract_error_from_context_details(self):
        op = self._make_operation(
            context_details=MagicMock(error=self._make_error("context error"))
        )

        result = PluginExecutor._extract_error(op)
        self.assertEqual(result.message, "context error")

    def test_extract_error_returns_none_when_no_error(self):
        op = self._make_operation()

        result = PluginExecutor._extract_error(op)
        self.assertIsNone(result)

    def test_extract_error_step_details_no_error(self):
        """step_details exists but has no error - falls through to callback."""
        op = self._make_operation(
            step_details=MagicMock(error=None),
            callback_details=MagicMock(error=self._make_error("callback error")),
        )

        result = PluginExecutor._extract_error(op)
        self.assertEqual(result.message, "callback error")

    def test_extract_error_priority_step_over_callback(self):
        """step_details error takes priority over callback error."""
        op = self._make_operation(
            step_details=MagicMock(error=self._make_error("step error")),
            callback_details=MagicMock(error=self._make_error("callback error")),
        )

        result = PluginExecutor._extract_error(op)
        self.assertEqual(result.message, "step error")


class TestPluginExecutorIsTerminalStatus(unittest.TestCase):
    """Tests for PluginExecutor._is_terminal_status static method."""

    def test_succeeded_is_terminal(self):
        self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.SUCCEEDED))

    def test_failed_is_terminal(self):
        self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.FAILED))

    def test_timed_out_is_terminal(self):
        self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.TIMED_OUT))

    def test_cancelled_is_terminal(self):
        self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.CANCELLED))

    def test_stopped_is_terminal(self):
        self.assertTrue(PluginExecutor._is_terminal_status(OperationStatus.STOPPED))

    def test_started_is_not_terminal(self):
        self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.STARTED))

    def test_pending_is_not_terminal(self):
        self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.PENDING))

    def test_ready_is_not_terminal(self):
        self.assertFalse(PluginExecutor._is_terminal_status(OperationStatus.READY))


# endregion PluginExecutor Tests


# region Helper Classes


class _NoOpPlugin(DurableExecutionPlugin):
    """Concrete subclass that inherits all default no-op methods."""


class _TrackingPlugin(DurableExecutionPlugin):
    """Concrete subclass that records every hook invocation as a tagged string."""

    def __init__(self) -> None:
        self.calls: list[str] = []

    def on_execution_start(self, info: ExecutionStartInfo) -> None:
        self.calls.append(f"execution_start:{info.request_id}")

    def on_execution_end(self, info: ExecutionEndInfo) -> None:
        self.calls.append(f"execution_end:{info.request_id}")

    def on_invocation_start(self, info: InvocationStartInfo) -> None:
        self.calls.append(f"invocation_start:{info.request_id}")

    def on_invocation_end(self, info: InvocationEndInfo) -> None:
        self.calls.append(f"invocation_end:{info.request_id}")

    def on_operation_start(self, info: OperationStartInfo) -> None:
        self.calls.append(f"operation_start:{info.operation_id}")

    def on_operation_end(self, info: OperationEndInfo) -> None:
        self.calls.append(f"operation_end:{info.operation_id}")

    def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
        self.calls.append(f"attempt_start:{info.operation_id}")

    def on_operation_attempt_end(self, info: AttemptEndInfo) -> None:
        self.calls.append(f"attempt_end:{info.operation_id}")
info): + raise RuntimeError("boom") + + def on_invocation_end(self, info): + raise RuntimeError("boom") + + def on_operation_start(self, info): + raise RuntimeError("boom") + + def on_operation_end(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_start(self, info): + raise RuntimeError("boom") + + def on_operation_attempt_end(self, info): + raise RuntimeError("boom") + + +# endregion Helper Classes + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/state_test.py b/tests/state_test.py index 0152ca6..684aa21 100644 --- a/tests/state_test.py +++ b/tests/state_test.py @@ -37,6 +37,10 @@ StateOutput, StepDetails, ) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + PluginExecutor, +) from aws_durable_execution_sdk_python.state import ( CheckpointBatcherConfig, CheckpointedResult, @@ -332,7 +336,7 @@ def test_checkpointerd_result_is_pending(): assert result_no_op.is_pending() is False -def test_checkpointerd_result_is_ready(): +def test_checkpointed_result_is_ready(): """Test CheckpointedResult.is_ready method.""" operation = Operation( operation_id="op1", @@ -405,6 +409,7 @@ def test_execution_state_creation(): initial_checkpoint_token="test_token", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) assert state.durable_execution_arn == "test_arn" assert state.operations == {} @@ -425,6 +430,7 @@ def test_get_checkpoint_result_success_with_result(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -446,6 +452,7 @@ def test_get_checkpoint_result_success_without_step_details(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = 
state.get_checkpoint_result("op1") @@ -467,6 +474,7 @@ def test_get_checkpoint_result_operation_not_succeeded(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -483,6 +491,7 @@ def test_get_checkpoint_result_operation_not_found(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("nonexistent") @@ -500,6 +509,7 @@ def test_create_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -530,6 +540,7 @@ def test_create_checkpoint_with_none(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with None and is_sync=False enqueues an empty checkpoint @@ -554,6 +565,7 @@ def test_create_checkpoint_with_no_args(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # create_checkpoint with no args and is_sync=False enqueues an empty checkpoint @@ -582,6 +594,7 @@ def test_get_checkpoint_result_started(): initial_checkpoint_token="token123", # noqa: S106 operations={"op1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) result = state.get_checkpoint_result("op1") @@ -675,6 +688,7 @@ def mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) state.fetch_paginated_operations( @@ -773,6 +787,7 @@ def 
mock_get_execution_state(durable_execution_arn, checkpoint_token, next_marke initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -811,6 +826,7 @@ def test_fetch_paginated_operations_logs_error(caplog): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) with pytest.raises(GetExecutionStateError): @@ -920,6 +936,7 @@ def test_checkpoint_batch_respects_default_max_items_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -988,6 +1005,7 @@ def test_collect_checkpoint_batch_respects_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1021,6 +1039,7 @@ def test_collect_checkpoint_batch_uses_overflow_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Put operations in overflow queue @@ -1072,6 +1091,7 @@ def test_collect_checkpoint_batch_handles_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue empty checkpoint @@ -1107,6 +1127,7 @@ def test_collect_checkpoint_batch_returns_empty_when_stopped(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal stop before collecting @@ -1128,6 +1149,7 @@ def test_parent_child_relationship_building(): initial_checkpoint_token="token123", # noqa: S106 operations={}, 
service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1169,6 +1191,7 @@ def test_descendant_cancellation_when_parent_completes(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1208,6 +1231,7 @@ def test_rejection_of_operations_from_completed_parents(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child hierarchy @@ -1257,6 +1281,7 @@ def test_nested_parallel_operations_deep_hierarchy(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build deep hierarchy: grandparent -> parent -> child @@ -1313,6 +1338,7 @@ def test_synchronous_checkpoint_blocks_until_complete(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -1361,6 +1387,7 @@ def test_concurrent_access_to_operations_dictionary(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add initial operation @@ -1430,6 +1457,7 @@ def test_stop_checkpointing_signals_background_thread(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Verify event is not set initially @@ -1523,6 +1551,7 @@ def test_create_checkpoint_sync_with_parent_id(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create parent operation @@ -1574,6 +1603,7 @@ def 
test_create_checkpoint_sync_rejects_orphaned_operation(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Build parent-child relationship @@ -1638,6 +1668,7 @@ def test_mark_orphans_handles_cycles(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Manually create a cycle (shouldn't happen in practice, but test defensive code) @@ -1668,6 +1699,7 @@ def test_checkpoint_batches_forever_exception_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation @@ -1715,6 +1747,7 @@ def test_collect_checkpoint_batch_shutdown_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Add operation to queue (would be a non-essential async checkpoint in practice) @@ -1744,6 +1777,7 @@ def test_collect_checkpoint_batch_shutdown_empty_queue(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Signal shutdown with empty queue @@ -1771,6 +1805,7 @@ def test_collect_checkpoint_batch_overflow_put_back(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1816,6 +1851,7 @@ def test_create_checkpoint_sync_with_none_operation_update(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Simulate background processor @@ -1848,6 +1884,7 @@ def test_checkpoint_batches_forever_exception_with_no_sync_operations(): 
initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -1887,6 +1924,7 @@ def test_collect_checkpoint_batch_size_limit_during_time_window(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1940,6 +1978,7 @@ def test_collect_checkpoint_batch_respects_max_operations_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -1983,6 +2022,7 @@ def test_collect_checkpoint_batch_time_window_expires(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2030,6 +2070,7 @@ def test_collect_checkpoint_batch_empty_overflow_queue_path(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Ensure overflow queue is empty (it should be by default) @@ -2067,6 +2108,7 @@ def test_collect_checkpoint_batch_overflow_queue_hits_operation_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2106,6 +2148,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2155,6 +2198,7 @@ def test_checkpoint_error_signals_completion_events_with_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + 
plugin_executor=PluginExecutor(plugins=None), ) # Create synchronous operation with completion event @@ -2211,6 +2255,7 @@ def test_synchronous_caller_receives_error_on_background_thread_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2288,6 +2333,7 @@ def test_exception_propagates_through_threadpoolexecutor(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Enqueue an operation @@ -2321,6 +2367,7 @@ def test_multiple_sync_operations_all_remain_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create multiple synchronous operations @@ -2372,6 +2419,7 @@ def test_async_operations_not_affected_by_error_handling(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create async operation (no completion event) @@ -2409,6 +2457,7 @@ def test_mixed_sync_async_operations_only_sync_blocked_on_error(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Create sync operation with completion event @@ -2469,6 +2518,7 @@ def test_create_checkpoint_accepts_is_sync_parameter(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2503,6 +2553,7 @@ def test_create_checkpoint_default_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = 
OperationUpdate( @@ -2549,6 +2600,7 @@ def test_create_checkpoint_explicit_is_sync_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2590,6 +2642,7 @@ def test_create_checkpoint_is_sync_false_no_completion_event(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2620,6 +2673,7 @@ def test_create_checkpoint_is_sync_false_returns_immediately(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2658,6 +2712,7 @@ def test_create_checkpoint_with_none_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with None (will block) @@ -2694,6 +2749,7 @@ def test_create_checkpoint_no_args_defaults_to_sync(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Use a thread to call with no arguments (will block) @@ -2733,6 +2789,7 @@ def test_collect_checkpoint_batch_overflow_queue_size_limit_final(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -2788,6 +2845,7 @@ def test_create_checkpoint_blocks_until_completion_default(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2859,6 +2917,7 @@ def 
test_create_checkpoint_blocks_until_completion_explicit_true(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2930,6 +2989,7 @@ def test_create_checkpoint_completion_event_created_and_signaled(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -2994,6 +3054,7 @@ def test_create_checkpoint_completion_event_not_signaled_on_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3080,6 +3141,7 @@ def test_create_checkpoint_caller_remains_blocked_on_background_failure(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) operation_update = OperationUpdate( @@ -3162,6 +3224,7 @@ def test_create_checkpoint_multiple_sync_calls_all_block(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) num_callers = 3 @@ -3238,6 +3301,7 @@ def test_create_checkpoint_sync_with_empty_checkpoint(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), ) # Track timing and completion @@ -3296,6 +3360,7 @@ def test_create_checkpoint_sync_success(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3330,6 +3395,7 @@ def test_create_checkpoint_sync_unwraps_background_thread_error(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, 
service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3363,6 +3429,7 @@ def test_create_checkpoint_sync_always_synchronous(): initial_checkpoint_token="initial-token", # noqa: S106 operations={}, service_client=mock_client, + plugin_executor=PluginExecutor(plugins=None), ) # Start background thread @@ -3400,6 +3467,7 @@ def test_state_replay_mode(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3433,6 +3501,7 @@ def test_state_replay_mode_with_timed_out(): initial_checkpoint_token="test_token", # noqa: S106 operations={"op1": operation1, "op2": operation2}, service_client=Mock(), + plugin_executor=PluginExecutor(plugins=None), replay_status=ReplayStatus.REPLAY, ) assert execution_state.is_replaying() is True @@ -3464,6 +3533,7 @@ def test_collect_checkpoint_batch_coalesces_many_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3497,6 +3567,7 @@ def test_collect_checkpoint_batch_empty_checkpoints_with_real_ops_respects_limit initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3536,6 +3607,7 @@ def test_collect_checkpoint_batch_overflow_coalesces_empty_checkpoints(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3576,6 +3648,7 @@ def test_checkpoint_batches_forever_single_api_call_for_many_empty_checkpoints() initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, 
+ plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3624,6 +3697,7 @@ def test_collect_checkpoint_batch_first_empty_counts_toward_limit(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3676,6 +3750,7 @@ def test_execution_state_get_execution_operation_no_operations(): initial_checkpoint_token="token123", # noqa: S106 operations={}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3707,6 +3782,7 @@ def test_initial_execution_state_get_execution_operation_wrong_type(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) @@ -3743,8 +3819,447 @@ def test_initial_execution_state_get_input_payload_none(): initial_checkpoint_token="token123", # noqa: S106 operations={"step1": operation}, service_client=mock_lambda_client, + plugin_executor=PluginExecutor(plugins=None), batcher_config=config, ) result = state.get_input_payload() assert result is None + + +# region Plugin Executor Integration Tests + + +class _RecordingPlugin(DurableExecutionPlugin): + """Plugin that records all hook calls for assertion.""" + + def __init__(self) -> None: + self.calls: list[str] = [] + + def on_execution_start(self, info): + self.calls.append("execution_start") + + def on_execution_end(self, info): + self.calls.append("execution_end") + + def on_invocation_start(self, info): + self.calls.append("invocation_start") + + def on_invocation_end(self, info): + self.calls.append("invocation_end") + + def on_operation_start(self, info): + self.calls.append(f"operation_start:{info.operation_id}") + + def on_operation_end(self, info): + self.calls.append(f"operation_end:{info.operation_id}") + + def on_operation_attempt_start(self, info): + 
self.calls.append(f"attempt_start:{info.operation_id}") + + def on_operation_attempt_end(self, info): + self.calls.append(f"attempt_end:{info.operation_id}") + + +def test_execution_state_accepts_plugin_executor_parameter(): + """Test that ExecutionState can be created with a plugin_executor parameter.""" + mock_client = Mock(spec=LambdaClient) + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + assert state._plugin_executor is plugin_executor + + +def test_plugin_executor_on_operation_action_called_on_checkpoint(): + """Test that plugin_executor.on_operation_action is called for each update after checkpoint.""" + mock_client = Mock(spec=LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + # Start background thread + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, 
is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_operation_action is called for START updates + assert "operation_start:step-1" in plugin.calls + assert "attempt_start:step-1" in plugin.calls + + +def test_plugin_executor_on_operation_update_called_for_terminal_operations(): + """Test that plugin_executor.on_operation_update is called for terminal operations.""" + mock_client = Mock(spec=LambdaClient) + + # Return a succeeded step operation from checkpoint + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.SUCCEED, + name="my-step", + payload='"done"', + ) + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + assert "operation_end:step-1" in plugin.calls + assert "attempt_end:step-1" in plugin.calls + + +def test_plugin_executor_not_called_for_non_terminal_operations(): + """Test that plugin_executor.on_operation_update does not fire for non-terminal operations.""" + mock_client = Mock(spec=LambdaClient) + + # Return a STARTED step operation from checkpoint + step_op 
= Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.STARTED, + step_details=None, + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_operation_action fires for START + assert "operation_start:step-1" in plugin.calls + # But on_operation_update should NOT fire operation_end for STARTED status + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + + +def test_plugin_executor_called_for_multiple_updates_in_batch(): + """Test that plugin_executor is called for each update in a batch.""" + mock_client = Mock(spec=LambdaClient) + + # Return multiple operations from checkpoint + step_op1 = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result1"'), + ) + step_op2 = Operation( + operation_id="step-2", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"result2"'), + ) + mock_client.checkpoint.return_value = 
CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op1, step_op2], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + config = CheckpointBatcherConfig( + max_batch_time_seconds=0.2, + max_batch_operations=10, + ) + + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + batcher_config=config, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + op1 = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-1", + ) + op2 = OperationUpdate( + operation_id="step-2", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="step-2", + ) + # Enqueue both without blocking so they batch together + state.create_checkpoint(op1, is_sync=False) + state.create_checkpoint(op2, is_sync=True) + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # Both operations should have triggered on_operation_action + assert "operation_start:step-1" in plugin.calls + assert "operation_start:step-2" in plugin.calls + # Both terminal operations should have triggered on_operation_update + assert "operation_end:step-1" in plugin.calls + assert "operation_end:step-2" in plugin.calls + + +def test_plugin_executor_not_called_on_checkpoint_failure(): + """Test that plugin_executor is NOT called when checkpoint API fails.""" + mock_client = Mock(spec=LambdaClient) + mock_client.checkpoint.side_effect = RuntimeError("API error") + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + 
initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + + with pytest.raises(BackgroundThreadError): + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # Plugin should NOT have been called since checkpoint failed + assert "operation_start:step-1" not in plugin.calls + assert "operation_end:step-1" not in plugin.calls + + +def test_plugin_executor_exception_does_not_break_checkpointing(): + """Test that a plugin exception does not break the checkpoint processing loop.""" + mock_client = Mock(spec=LambdaClient) + + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.SUCCEEDED, + step_details=StepDetails(attempt=1, result='"done"'), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + class _ExplodingPlugin(DurableExecutionPlugin): + def on_operation_start(self, info): + raise RuntimeError("plugin exploded") + + def on_operation_end(self, info): + raise RuntimeError("plugin exploded") + + exploding_plugin = _ExplodingPlugin() + plugin_executor = PluginExecutor(plugins=[exploding_plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + 
operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + # Should not raise even though plugin explodes + state.create_checkpoint(operation_update, is_sync=True) + + # Checkpoint should still have been called successfully + assert mock_client.checkpoint.call_count == 1 + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + +def test_plugin_executor_called_for_pending_operations(): + """Test that plugin_executor.on_operation_update fires on_attempt_end for PENDING operations.""" + mock_client = Mock(spec=LambdaClient) + + # Return a PENDING step operation from checkpoint (simulates a retry scenario) + step_op = Operation( + operation_id="step-1", + operation_type=OperationType.STEP, + status=OperationStatus.PENDING, + step_details=StepDetails( + attempt=1, + result=None, + error=ErrorObject( + message="transient failure", + type="RetryableError", + data=None, + stack_trace=None, + ), + ), + ) + mock_client.checkpoint.return_value = CheckpointOutput( + checkpoint_token="new_token", # noqa: S106 + new_execution_state=CheckpointUpdatedExecutionState( + operations=[step_op], + next_marker=None, + ), + ) + + plugin = _RecordingPlugin() + plugin_executor = PluginExecutor(plugins=[plugin]) + with plugin_executor.run(): + state = ExecutionState( + durable_execution_arn="test_arn", + initial_checkpoint_token="token123", # noqa: S106 + operations={}, + service_client=mock_client, + plugin_executor=plugin_executor, + ) + + executor = ThreadPoolExecutor(max_workers=1) + executor.submit(state.checkpoint_batches_forever) + + try: + operation_update = OperationUpdate( + operation_id="step-1", + operation_type=OperationType.STEP, + action=OperationAction.START, + name="my-step", + ) + state.create_checkpoint(operation_update, is_sync=True) + + finally: + state.stop_checkpointing() + executor.shutdown(wait=True) + + # on_attempt_end should fire for PENDING operations 
with step_details + assert "attempt_end:step-1" in plugin.calls + # operation_end should NOT fire for PENDING (only for terminal statuses) + operation_end_calls = [c for c in plugin.calls if c.startswith("operation_end")] + assert len(operation_end_calls) == 0 + + +# endregion Plugin Executor Integration Tests