diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bfded98..27cbb90 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -37,6 +37,7 @@ from aws_durable_execution_sdk_python.operation.wait_for_condition import ( WaitForConditionOperationExecutor, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin from aws_durable_execution_sdk_python.serdes import ( PassThroughSerDes, SerDes, diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 8172263..cf375f6 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -6,7 +6,6 @@ import logging from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.context import DurableContext @@ -19,9 +18,11 @@ InvocationError, SuspendExecution, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin from aws_durable_execution_sdk_python.lambda_service import ( DurableServiceClient, ErrorObject, + InvocationStatus, LambdaClient, Operation, OperationType, @@ -149,12 +150,6 @@ def from_durable_execution_invocation_input( ) -class InvocationStatus(Enum): - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - PENDING = "PENDING" - - @dataclass(frozen=True) class DurableExecutionInvocationOutput: """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. @@ -212,11 +207,14 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: Boto3LambdaClient | None = None, + plugins: list[DurableExecutionPlugin] | None = None, ) -> Callable[[Any, LambdaContext], Any]: # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, boto3_client=boto3_client, plugins=plugins + ) logger.debug("Starting durable execution handler...") @@ -254,6 +252,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: initial_checkpoint_token=invocation_input.checkpoint_token, operations={}, service_client=service_client, + plugins=plugins or [], # If there are operations other than the initial EXECUTION one, current state is in replay mode replay_status=ReplayStatus.REPLAY if len(invocation_input.initial_execution_state.operations) > 1 diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index 51a0049..d1b2381 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -96,6 +96,12 @@ class OperationSubType(Enum): CHAINED_INVOKE = "ChainedInvoke" +class InvocationStatus(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + @dataclass(frozen=True) class ExecutionDetails: input_payload: str | None = None diff --git a/src/aws_durable_execution_sdk_python/plugin.py b/src/aws_durable_execution_sdk_python/plugin.py new file mode 100644 index 0000000..6f63246 --- /dev/null +++ b/src/aws_durable_execution_sdk_python/plugin.py @@ -0,0 +1,203 @@ +import datetime +import logging +from abc import ABC +from dataclasses import dataclass + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, + InvocationStatus, + Operation, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class OperationStartInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None = None + name: str | None = None + parent_id: str | None = None + start_timestamp: datetime.datetime | None = None + + +@dataclass +class OperationEndInfo(OperationStartInfo): + status: OperationStatus = OperationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + attempt: int = 1 + error: ErrorObject | None = None + + +@dataclass +class AttemptStartInfo(OperationStartInfo): + attempt: int = 1 + + +@dataclass +class AttemptEndInfo(AttemptStartInfo): + outcome: OperationAction = OperationAction.SUCCEED + error: ErrorObject | None = None + next_attempt_delay_seconds: int | None = None + + +@dataclass +class InvocationStartInfo: + request_id: str + execution_arn: str + start_time: datetime.datetime + + +@dataclass +class InvocationEndInfo(InvocationStartInfo): + status: InvocationStatus = InvocationStatus.SUCCEEDED + error: ErrorObject | None = None + + +@dataclass +class ExecutionStartInfo(InvocationStartInfo): + pass + + +@dataclass +class ExecutionEndInfo(InvocationEndInfo): + pass + + +class DurableExecutionPlugin(ABC): + """Base class for plugins. Override only the methods you need.""" + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + pass + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + pass + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + pass + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + pass + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + + +def execute_plugins(plugins: list[DurableExecutionPlugin], info): + for plugin in plugins: + try: + match info: + case ExecutionStartInfo(): + plugin.on_execution_start(info) + case ExecutionEndInfo(): + plugin.on_execution_end(info) + case InvocationStartInfo(): + plugin.on_invocation_start(info) + case InvocationEndInfo(): + plugin.on_invocation_end(info) + case OperationStartInfo(): + plugin.on_operation_start(info) + case OperationEndInfo(): + plugin.on_operation_end(info) + case AttemptStartInfo(): + plugin.on_operation_attempt_start(info) + case AttemptEndInfo(): + plugin.on_operation_attempt_end(info) + case _: + raise ValueError(f"Unknown info type: {type(info)}") + except Exception: + logger.exception("Plugin %s failed", plugin.__class__.__name__) + + +def get_operation_info_from_operation( + operation: Operation, +) -> OperationStartInfo | OperationEndInfo | None: + if operation is None: + raise ValueError("Operation is None") + if operation.status in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ]: + return OperationEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + status=operation.status, + ) + if operation.status is OperationStatus.STARTED: + return OperationStartInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + ) + return None + + +def get_attempt_info_from_operation( + operation: Operation, +) -> AttemptStartInfo | AttemptEndInfo | None: + if operation is None: + raise ValueError("Operation is None") + if operation.status is OperationStatus.STARTED and operation.step_details: + return AttemptStartInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + attempt=operation.step_details.attempt, + ) + if ( + operation.status + in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ] + and operation.step_details + ): + return AttemptEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + attempt=operation.step_details.attempt, + outcome=OperationAction.SUCCEED + if operation.status is OperationStatus.SUCCEEDED + else OperationAction.FAIL, + error=operation.step_details.error, + ) + return None diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 0d9cb0f..91bf005 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -30,6 +30,11 @@ OperationUpdate, StateOutput, ) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + get_operation_info_from_operation, + execute_plugins, +) from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock if TYPE_CHECKING: @@ -229,6 +234,7 @@ def __init__( initial_checkpoint_token: str, operations: MutableMapping[str, Operation], service_client: DurableServiceClient, + plugins: list[DurableExecutionPlugin], batcher_config: CheckpointBatcherConfig | None = None, replay_status: ReplayStatus = ReplayStatus.NEW, ): @@ -236,6 +242,7 @@ def __init__( self._current_checkpoint_token: str = initial_checkpoint_token self.operations: MutableMapping[str, Operation] = operations self._service_client: DurableServiceClient = service_client + self._plugins: list[DurableExecutionPlugin] = plugins self._ordered_checkpoint_lock: OrderedLock = OrderedLock() self._operations_lock: Lock = Lock() @@ -267,7 +274,7 @@ def fetch_paginated_operations( initial_operations: list[Operation], checkpoint_token: str, next_marker: str | None, - ) -> None: + ) -> list[Operation]: """Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe. The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety. @@ -277,6 +284,9 @@ def fetch_paginated_operations( checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + Returns: + List of all operations fetched from the Durable Functions API + Raises: GetExecutionStateError: If the API call fails. The error is logged with structured extras before re-raising. Callers are responsible @@ -308,6 +318,7 @@ def fetch_paginated_operations( self.operations.update( {op.operation_id: op for op in all_operations} ) + return all_operations def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation @@ -670,12 +681,15 @@ def checkpoint_batches_forever(self) -> None: current_checkpoint_token = output.checkpoint_token # Fetch new operations from the API before unblocking sync waiters - self.fetch_paginated_operations( + updated_operations = self.fetch_paginated_operations( output.new_execution_state.operations, output.checkpoint_token, output.new_execution_state.next_marker, ) + for operation in updated_operations: + self._execute_plugin(operation) + # Signal completion for any synchronous operations for queued_op in batch: if queued_op.completion_event is not None: @@ -861,6 +875,21 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]: ) return batch + def _execute_plugin(self, operation: Operation) -> None: + """Execute any registered plugins for the given queued operation. + + This method iterates through all registered plugins and calls their + `on_checkpoint` method with the queued operation as an argument. Plugins + are executed in the order they were registered. + + Args: + operation: the operation to execute plugins for + """ + if info := get_operation_info_from_operation(operation): + execute_plugins(self._plugins, info) + if info := get_attempt_info_from_operation(operation): + execute_plugins(self._plugins, info) + @staticmethod def _calculate_operation_size(queued_op: QueuedOperation) -> int: """Calculate the serialized size of a queued operation for batching limits. diff --git a/tests/plugin_test.py b/tests/plugin_test.py new file mode 100644 index 0000000..b61f56a --- /dev/null +++ b/tests/plugin_test.py @@ -0,0 +1,256 @@ +import datetime +import unittest + +from aws_durable_execution_sdk_python.plugin import ( + AttemptEndInfo, + AttemptStartInfo, + DurableExecutionPlugin, + ExecutionEndInfo, + ExecutionStartInfo, + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, +) +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, +) + + +class TestOperationStartInfo(unittest.TestCase): + def test_required_fields(self): + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.operation_id, "op-1") + self.assertEqual(info.operation_type, OperationType.STEP) + self.assertIsNone(info.sub_type) + self.assertIsNone(info.name) + self.assertIsNone(info.parent_id) + self.assertIsNone(info.start_timestamp) + + def test_all_fields(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = OperationStartInfo( + operation_id="op-2", + operation_type=OperationType.CALLBACK, + sub_type=OperationSubType.CALLBACK, + name="my-op", + parent_id="parent-1", + start_timestamp=ts, + ) + self.assertEqual(info.sub_type, OperationSubType.CALLBACK) + self.assertEqual(info.name, "my-op") + self.assertEqual(info.parent_id, "parent-1") + self.assertEqual(info.start_timestamp, ts) + + +class TestOperationEndInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo)) + + def test_defaults(self): + info = OperationEndInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.assertEqual(info.status, OperationStatus.SUCCEEDED) + self.assertIsNone(info.end_timestamp) + self.assertEqual(info.attempt, 1) + self.assertIsNone(info.error) + + def test_with_error(self): + err = ErrorObject( + message="fail", type="RuntimeError", data=None, stack_trace=None + ) + info = OperationEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.FAILED, + error=err, + attempt=3, + ) + self.assertEqual(info.status, OperationStatus.FAILED) + self.assertEqual(info.attempt, 3) + self.assertEqual(info.error.message, "fail") + + +class TestAttemptStartInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo)) + + def test_default_attempt(self): + info = AttemptStartInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.assertEqual(info.attempt, 1) + + def test_custom_attempt(self): + info = AttemptStartInfo( + operation_id="op-1", operation_type=OperationType.STEP, attempt=5 + ) + self.assertEqual(info.attempt, 5) + + +class TestAttemptEndInfo(unittest.TestCase): + def test_inherits_attempt_start_info(self): + self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo)) + + def test_defaults(self): + info = AttemptEndInfo(operation_id="op-1", operation_type=OperationType.STEP) + self.assertEqual(info.outcome, OperationAction.SUCCEED) + self.assertIsNone(info.error) + self.assertIsNone(info.next_attempt_delay_seconds) + + def test_retry_with_delay(self): + err = ErrorObject( + message="timeout", type="TimeoutError", data=None, stack_trace=None + ) + info = AttemptEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + outcome=OperationAction.RETRY, + error=err, + next_attempt_delay_seconds=30, + ) + self.assertEqual(info.outcome, OperationAction.RETRY) + self.assertEqual(info.next_attempt_delay_seconds, 30) + self.assertEqual(info.error.type, "TimeoutError") + + +class TestInvocationStartInfo(unittest.TestCase): + def test_fields(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationStartInfo( + request_id="req-1", + execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", + start_time=ts, + ) + self.assertEqual(info.request_id, "req-1") + self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc") + self.assertEqual(info.start_time, ts) + + +class TestInvocationEndInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(InvocationEndInfo, InvocationStartInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationEndInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + def test_failed(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + err = ErrorObject(message="boom", type="Error", data=None, stack_trace=None) + info = InvocationEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_time=ts, + status=InvocationStatus.FAILED, + error=err, + ) + self.assertEqual(info.status, InvocationStatus.FAILED) + self.assertEqual(info.error.message, "boom") + + +class TestExecutionStartInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(ExecutionStartInfo, InvocationStartInfo)) + + def test_construction(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionStartInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.request_id, "req-1") + + +class TestExecutionEndInfo(unittest.TestCase): + def test_inherits_invocation_end_info(self): + self.assertTrue(issubclass(ExecutionEndInfo, InvocationEndInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionEndInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + +class TestDurableExecutionPlugin(unittest.TestCase): + def test_default_methods_are_noop(self): + """All default hook methods should be callable and return None.""" + plugin = _NoOpPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + exec_start = ExecutionStartInfo( + request_id="r", execution_arn="a", start_time=ts + ) + exec_end = ExecutionEndInfo(request_id="r", execution_arn="a", start_time=ts) + inv_start = InvocationStartInfo( + request_id="r", execution_arn="a", start_time=ts + ) + inv_end = InvocationEndInfo(request_id="r", execution_arn="a", start_time=ts) + op_start = OperationStartInfo( + operation_id="o", operation_type=OperationType.STEP + ) + op_end = OperationEndInfo(operation_id="o", operation_type=OperationType.STEP) + att_start = AttemptStartInfo( + operation_id="o", operation_type=OperationType.STEP + ) + att_end = AttemptEndInfo(operation_id="o", operation_type=OperationType.STEP) + + self.assertIsNone(plugin.on_execution_start(exec_start)) + self.assertIsNone(plugin.on_execution_end(exec_end)) + self.assertIsNone(plugin.on_invocation_start(inv_start)) + self.assertIsNone(plugin.on_invocation_end(inv_end)) + self.assertIsNone(plugin.on_operation_start(op_start)) + self.assertIsNone(plugin.on_operation_end(op_end)) + self.assertIsNone(plugin.on_operation_attempt_start(att_start)) + self.assertIsNone(plugin.on_operation_attempt_end(att_end)) + + def test_subclass_override(self): + """A subclass can override specific hooks.""" + plugin = _TrackingPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + plugin.on_execution_start( + ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts) + ) + plugin.on_operation_start( + OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT) + ) + + self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"]) + + def test_cannot_instantiate_abc_directly(self): + """DurableExecutionPlugin is abstract but has no abstract methods, so it can be instantiated via a subclass.""" + self.assertTrue(issubclass(DurableExecutionPlugin, object)) + + +class _NoOpPlugin(DurableExecutionPlugin): + """Concrete subclass that inherits all default no-op methods.""" + + pass + + +class _TrackingPlugin(DurableExecutionPlugin): + """Concrete subclass that tracks calls to specific hooks.""" + + def __init__(self): + self.calls: list[str] = [] + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + self.calls.append(f"execution_start:{info.request_id}") + + def on_operation_start(self, info: OperationStartInfo) -> None: + self.calls.append(f"operation_start:{info.operation_id}") + + +if __name__ == "__main__": + unittest.main()