From 2da8645bf2abc16937ada81538b9d73df2ad1867 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 1 May 2026 15:05:16 -0700 Subject: [PATCH 1/5] add hook interface --- src/aws_durable_execution_sdk_python/hook.py | 71 ++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 src/aws_durable_execution_sdk_python/hook.py diff --git a/src/aws_durable_execution_sdk_python/hook.py b/src/aws_durable_execution_sdk_python/hook.py new file mode 100644 index 0000000..2029cbe --- /dev/null +++ b/src/aws_durable_execution_sdk_python/hook.py @@ -0,0 +1,71 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Optional, Dict, Any + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import OperationType, OperationStatus, OperationAction + + +@dataclass +class OperationStartInfo: + operation_id: str + operation_name: Optional[str] + operation_type: OperationType + parent_operation_id: Optional[str] = None + attempt: Optional[int] = None + attributes: Optional[Dict[str, Any]] = None + + +@dataclass +class OperationEndInfo(OperationStartInfo): + outcome: Optional[OperationAction] = None # None | "SUCCEED" | "FAIL" | "RETRY" | "CANCEL") + status: Optional[OperationStatus] = None + error: Optional[Exception] = None + + +@dataclass +class AttemptStartInfo(OperationStartInfo): + attempt: int = 1 + + +@dataclass +class AttemptEndInfo(AttemptStartInfo): + outcome: Optional[OperationAction] = None # None | "SUCCEED" | "FAIL" | "RETRY" + error: Optional[Exception] = None + next_attempt_delay_seconds: Optional[float] = None + + +@dataclass +class InvocationStartInfo: + request_id: str + execution_arn: str + + +@dataclass +class InvocationEndInfo(InvocationStartInfo): + status: Optional[InvocationStatus] = None + error: Optional[Exception] = None + + +@dataclass +class ExecutionStartInfo(InvocationStartInfo): + pass + + +@dataclass +class ExecutionEndInfo(InvocationEndInfo): + pass + + +class DurableExecutionPlugin(ABC): + """Base class for plugins. Override only the methods you need.""" + + def on_execution_start(self, info: ExecutionStartInfo) -> None: pass + def on_execution_end(self, info: ExecutionEndInfo) -> None: pass + def on_invocation_start(self, info: InvocationStartInfo) -> None: pass + def on_invocation_end(self, info: InvocationEndInfo) -> None: pass + def on_operation_start(self, info: OperationStartInfo) -> None: pass + def on_operation_end(self, info: OperationEndInfo) -> None: pass + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: pass + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: pass + def enrich_log_context(self, info: Optional[OperationStartInfo]) -> Optional[Dict[str, Any]]: pass From 3dad5f7cc27749282e462132726bf47047739ead Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 1 May 2026 15:38:31 -0700 Subject: [PATCH 2/5] update hook interface --- src/aws_durable_execution_sdk_python/hook.py | 36 ++++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/hook.py b/src/aws_durable_execution_sdk_python/hook.py index 2029cbe..6c3ac76 100644 --- a/src/aws_durable_execution_sdk_python/hook.py +++ b/src/aws_durable_execution_sdk_python/hook.py @@ -1,26 +1,29 @@ +import datetime from abc import ABC from dataclasses import dataclass -from typing import Optional, Dict, Any +from typing import Dict, Any from aws_durable_execution_sdk_python.execution import InvocationStatus -from aws_durable_execution_sdk_python.lambda_service import OperationType, OperationStatus, OperationAction +from aws_durable_execution_sdk_python.lambda_service import OperationType, OperationStatus, OperationAction, \ + OperationSubType, ErrorObject @dataclass class OperationStartInfo: operation_id: str - operation_name: Optional[str] operation_type: OperationType - parent_operation_id: Optional[str] = None - attempt: Optional[int] = None - attributes: Optional[Dict[str, Any]] = None + sub_type: OperationSubType | None = None + name: str | None = None + parent_id: str | None = None + start_timestamp: datetime.datetime | None = None @dataclass class OperationEndInfo(OperationStartInfo): - outcome: Optional[OperationAction] = None # None | "SUCCEED" | "FAIL" | "RETRY" | "CANCEL") - status: Optional[OperationStatus] = None - error: Optional[Exception] = None + status: OperationStatus = OperationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + attempt: int = 1 + error: ErrorObject | None = None @dataclass @@ -30,21 +33,22 @@ class AttemptStartInfo(OperationStartInfo): @dataclass class AttemptEndInfo(AttemptStartInfo): - outcome: Optional[OperationAction] = None # None | "SUCCEED" | "FAIL" | "RETRY" - error: Optional[Exception] = None - next_attempt_delay_seconds: Optional[float] = None + outcome: OperationAction = OperationAction.SUCCEED + error: ErrorObject | None = None + next_attempt_delay_seconds: int | None = None @dataclass class InvocationStartInfo: request_id: str execution_arn: str + start_time: datetime.datetime @dataclass class InvocationEndInfo(InvocationStartInfo): - status: Optional[InvocationStatus] = None - error: Optional[Exception] = None + status: InvocationStatus = InvocationStatus.SUCCEEDED + error: ErrorObject | None = None @dataclass @@ -68,4 +72,6 @@ def on_operation_start(self, info: OperationStartInfo) -> None: pass def on_operation_end(self, info: OperationEndInfo) -> None: pass def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: pass def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: pass - def enrich_log_context(self, info: Optional[OperationStartInfo]) -> Optional[Dict[str, Any]]: pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass From d6e3190e817ff45146ebab52b6835660e2d01f29 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 1 May 2026 15:40:18 -0700 Subject: [PATCH 3/5] hatch fmt --- src/aws_durable_execution_sdk_python/hook.py | 40 +++++++++++++++----- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/hook.py b/src/aws_durable_execution_sdk_python/hook.py index 6c3ac76..20b7f47 100644 --- a/src/aws_durable_execution_sdk_python/hook.py +++ b/src/aws_durable_execution_sdk_python/hook.py @@ -4,8 +4,13 @@ from typing import Dict, Any from aws_durable_execution_sdk_python.execution import InvocationStatus -from aws_durable_execution_sdk_python.lambda_service import OperationType, OperationStatus, OperationAction, \ - OperationSubType, ErrorObject +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, +) @dataclass @@ -64,14 +69,29 @@ class ExecutionEndInfo(InvocationEndInfo): class DurableExecutionPlugin(ABC): """Base class for plugins. Override only the methods you need.""" - def on_execution_start(self, info: ExecutionStartInfo) -> None: pass - def on_execution_end(self, info: ExecutionEndInfo) -> None: pass - def on_invocation_start(self, info: InvocationStartInfo) -> None: pass - def on_invocation_end(self, info: InvocationEndInfo) -> None: pass - def on_operation_start(self, info: OperationStartInfo) -> None: pass - def on_operation_end(self, info: OperationEndInfo) -> None: pass - def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: pass - def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: pass + def on_execution_start(self, info: ExecutionStartInfo) -> None: + pass + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + pass + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + pass + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + pass + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + pass # Todo: further discussions required to finalize the following interface # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass From 0e7c39909eeb083efe64501bb6e486a0bce42baf Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Fri, 1 May 2026 15:46:55 -0700 Subject: [PATCH 4/5] add test cases for hook --- src/aws_durable_execution_sdk_python/hook.py | 1 - tests/hook_test.py | 243 +++++++++++++++++++ 2 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 tests/hook_test.py diff --git a/src/aws_durable_execution_sdk_python/hook.py b/src/aws_durable_execution_sdk_python/hook.py index 20b7f47..9fde017 100644 --- a/src/aws_durable_execution_sdk_python/hook.py +++ b/src/aws_durable_execution_sdk_python/hook.py @@ -1,7 +1,6 @@ import datetime from abc import ABC from dataclasses import dataclass -from typing import Dict, Any from aws_durable_execution_sdk_python.execution import InvocationStatus from aws_durable_execution_sdk_python.lambda_service import ( diff --git a/tests/hook_test.py b/tests/hook_test.py new file mode 100644 index 0000000..edacff4 --- /dev/null +++ b/tests/hook_test.py @@ -0,0 +1,243 @@ +import datetime +import unittest + +from aws_durable_execution_sdk_python.hook import ( + AttemptEndInfo, + AttemptStartInfo, + DurableExecutionPlugin, + ExecutionEndInfo, + ExecutionStartInfo, + InvocationEndInfo, + InvocationStartInfo, + OperationEndInfo, + OperationStartInfo, +) +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + ErrorObject, + OperationAction, + OperationStatus, + OperationSubType, + OperationType, +) + + +class TestOperationStartInfo(unittest.TestCase): + def test_required_fields(self): + info = OperationStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.operation_id, "op-1") + self.assertEqual(info.operation_type, OperationType.STEP) + self.assertIsNone(info.sub_type) + self.assertIsNone(info.name) + self.assertIsNone(info.parent_id) + self.assertIsNone(info.start_timestamp) + + def test_all_fields(self): + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + info = OperationStartInfo( + operation_id="op-2", + operation_type=OperationType.CALLBACK, + sub_type=OperationSubType.CALLBACK, + name="my-op", + parent_id="parent-1", + start_timestamp=ts, + ) + self.assertEqual(info.sub_type, OperationSubType.CALLBACK) + self.assertEqual(info.name, "my-op") + self.assertEqual(info.parent_id, "parent-1") + self.assertEqual(info.start_timestamp, ts) + + +class TestOperationEndInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo)) + + def test_defaults(self): + info = OperationEndInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.status, OperationStatus.SUCCEEDED) + self.assertIsNone(info.end_timestamp) + self.assertEqual(info.attempt, 1) + self.assertIsNone(info.error) + + def test_with_error(self): + err = ErrorObject(message="fail", type="RuntimeError", data=None, stack_trace=None) + info = OperationEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + status=OperationStatus.FAILED, + error=err, + attempt=3, + ) + self.assertEqual(info.status, OperationStatus.FAILED) + self.assertEqual(info.attempt, 3) + self.assertEqual(info.error.message, "fail") + + +class TestAttemptStartInfo(unittest.TestCase): + def test_inherits_operation_start_info(self): + self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo)) + + def test_default_attempt(self): + info = AttemptStartInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.attempt, 1) + + def test_custom_attempt(self): + info = AttemptStartInfo( + operation_id="op-1", operation_type=OperationType.STEP, attempt=5 + ) + self.assertEqual(info.attempt, 5) + + +class TestAttemptEndInfo(unittest.TestCase): + def test_inherits_attempt_start_info(self): + self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo)) + + def test_defaults(self): + info = AttemptEndInfo( + operation_id="op-1", operation_type=OperationType.STEP + ) + self.assertEqual(info.outcome, OperationAction.SUCCEED) + self.assertIsNone(info.error) + self.assertIsNone(info.next_attempt_delay_seconds) + + def test_retry_with_delay(self): + err = ErrorObject(message="timeout", type="TimeoutError", data=None, stack_trace=None) + info = AttemptEndInfo( + operation_id="op-1", + operation_type=OperationType.STEP, + outcome=OperationAction.RETRY, + error=err, + next_attempt_delay_seconds=30, + ) + self.assertEqual(info.outcome, OperationAction.RETRY) + self.assertEqual(info.next_attempt_delay_seconds, 30) + self.assertEqual(info.error.type, "TimeoutError") + + +class TestInvocationStartInfo(unittest.TestCase): + def test_fields(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationStartInfo( + request_id="req-1", execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", start_time=ts + ) + self.assertEqual(info.request_id, "req-1") + self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc") + self.assertEqual(info.start_time, ts) + + +class TestInvocationEndInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(InvocationEndInfo, InvocationStartInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = InvocationEndInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + def test_failed(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + err = ErrorObject(message="boom", type="Error", data=None, stack_trace=None) + info = InvocationEndInfo( + request_id="req-1", + execution_arn="arn:test", + start_time=ts, + status=InvocationStatus.FAILED, + error=err, + ) + self.assertEqual(info.status, InvocationStatus.FAILED) + self.assertEqual(info.error.message, "boom") + + +class TestExecutionStartInfo(unittest.TestCase): + def test_inherits_invocation_start_info(self): + self.assertTrue(issubclass(ExecutionStartInfo, InvocationStartInfo)) + + def test_construction(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionStartInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.request_id, "req-1") + + +class TestExecutionEndInfo(unittest.TestCase): + def test_inherits_invocation_end_info(self): + self.assertTrue(issubclass(ExecutionEndInfo, InvocationEndInfo)) + + def test_defaults(self): + ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) + info = ExecutionEndInfo( + request_id="req-1", execution_arn="arn:test", start_time=ts + ) + self.assertEqual(info.status, InvocationStatus.SUCCEEDED) + self.assertIsNone(info.error) + + +class TestDurableExecutionPlugin(unittest.TestCase): + def test_default_methods_are_noop(self): + """All default hook methods should be callable and return None.""" + plugin = _NoOpPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + exec_start = ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts) + exec_end = ExecutionEndInfo(request_id="r", execution_arn="a", start_time=ts) + inv_start = InvocationStartInfo(request_id="r", execution_arn="a", start_time=ts) + inv_end = InvocationEndInfo(request_id="r", execution_arn="a", start_time=ts) + op_start = OperationStartInfo(operation_id="o", operation_type=OperationType.STEP) + op_end = OperationEndInfo(operation_id="o", operation_type=OperationType.STEP) + att_start = AttemptStartInfo(operation_id="o", operation_type=OperationType.STEP) + att_end = AttemptEndInfo(operation_id="o", operation_type=OperationType.STEP) + + self.assertIsNone(plugin.on_execution_start(exec_start)) + self.assertIsNone(plugin.on_execution_end(exec_end)) + self.assertIsNone(plugin.on_invocation_start(inv_start)) + self.assertIsNone(plugin.on_invocation_end(inv_end)) + self.assertIsNone(plugin.on_operation_start(op_start)) + self.assertIsNone(plugin.on_operation_end(op_end)) + self.assertIsNone(plugin.on_operation_attempt_start(att_start)) + self.assertIsNone(plugin.on_operation_attempt_end(att_end)) + + def test_subclass_override(self): + """A subclass can override specific hooks.""" + plugin = _TrackingPlugin() + ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) + + plugin.on_execution_start(ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts)) + plugin.on_operation_start(OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT)) + + self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"]) + + def test_cannot_instantiate_abc_directly(self): + """DurableExecutionPlugin is abstract but has no abstract methods, so it can be instantiated via a subclass.""" + self.assertTrue(issubclass(DurableExecutionPlugin, object)) + + +class _NoOpPlugin(DurableExecutionPlugin): + """Concrete subclass that inherits all default no-op methods.""" + pass + + +class _TrackingPlugin(DurableExecutionPlugin): + """Concrete subclass that tracks calls to specific hooks.""" + + def __init__(self): + self.calls: list[str] = [] + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + self.calls.append(f"execution_start:{info.request_id}") + + def on_operation_start(self, info: OperationStartInfo) -> None: + self.calls.append(f"operation_start:{info.operation_id}") + + +if __name__ == "__main__": + unittest.main() From 370ead54a42874c8fb05af20a4934c32fe3b3b16 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Mon, 4 May 2026 10:00:13 -0700 Subject: [PATCH 5/5] format code --- .../context.py | 1 + .../execution.py | 15 +- src/aws_durable_execution_sdk_python/hook.py | 96 --------- .../lambda_service.py | 6 + .../plugin.py | 203 ++++++++++++++++++ src/aws_durable_execution_sdk_python/state.py | 33 ++- tests/{hook_test.py => plugin_test.py} | 51 +++-- 7 files changed, 280 insertions(+), 125 deletions(-) delete mode 100644 src/aws_durable_execution_sdk_python/hook.py create mode 100644 src/aws_durable_execution_sdk_python/plugin.py rename tests/{hook_test.py => plugin_test.py} (85%) diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bfded98..27cbb90 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -37,6 +37,7 @@ from aws_durable_execution_sdk_python.operation.wait_for_condition import ( WaitForConditionOperationExecutor, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin from aws_durable_execution_sdk_python.serdes import ( PassThroughSerDes, SerDes, diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 8172263..cf375f6 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -6,7 +6,6 @@ import logging from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from enum import Enum from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.context import DurableContext @@ -19,9 +18,11 @@ InvocationError, SuspendExecution, ) +from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin from aws_durable_execution_sdk_python.lambda_service import ( DurableServiceClient, ErrorObject, + InvocationStatus, LambdaClient, Operation, OperationType, @@ -149,12 +150,6 @@ def from_durable_execution_invocation_input( ) -class InvocationStatus(Enum): - SUCCEEDED = "SUCCEEDED" - FAILED = "FAILED" - PENDING = "PENDING" - - @dataclass(frozen=True) class DurableExecutionInvocationOutput: """Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns. @@ -212,11 +207,14 @@ def durable_execution( func: Callable[[Any, DurableContext], Any] | None = None, *, boto3_client: Boto3LambdaClient | None = None, + plugins: list[DurableExecutionPlugin] | None = None, ) -> Callable[[Any, LambdaContext], Any]: # Decorator called with parameters if func is None: logger.debug("Decorator called with parameters") - return functools.partial(durable_execution, boto3_client=boto3_client) + return functools.partial( + durable_execution, boto3_client=boto3_client, plugins=plugins + ) logger.debug("Starting durable execution handler...") @@ -254,6 +252,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: initial_checkpoint_token=invocation_input.checkpoint_token, operations={}, service_client=service_client, + plugins=plugins or [], # If there are operations other than the initial EXECUTION one, current state is in replay mode replay_status=ReplayStatus.REPLAY if len(invocation_input.initial_execution_state.operations) > 1 diff --git a/src/aws_durable_execution_sdk_python/hook.py b/src/aws_durable_execution_sdk_python/hook.py deleted file mode 100644 index 9fde017..0000000 --- a/src/aws_durable_execution_sdk_python/hook.py +++ /dev/null @@ -1,96 +0,0 @@ -import datetime -from abc import ABC -from dataclasses import dataclass - -from aws_durable_execution_sdk_python.execution import InvocationStatus -from aws_durable_execution_sdk_python.lambda_service import ( - OperationType, - OperationStatus, - OperationAction, - OperationSubType, - ErrorObject, -) - - -@dataclass -class OperationStartInfo: - operation_id: str - operation_type: OperationType - sub_type: OperationSubType | None = None - name: str | None = None - parent_id: str | None = None - start_timestamp: datetime.datetime | None = None - - -@dataclass -class OperationEndInfo(OperationStartInfo): - status: OperationStatus = OperationStatus.SUCCEEDED - end_timestamp: datetime.datetime | None = None - attempt: int = 1 - error: ErrorObject | None = None - - -@dataclass -class AttemptStartInfo(OperationStartInfo): - attempt: int = 1 - - -@dataclass -class AttemptEndInfo(AttemptStartInfo): - outcome: OperationAction = OperationAction.SUCCEED - error: ErrorObject | None = None - next_attempt_delay_seconds: int | None = None - - -@dataclass -class InvocationStartInfo: - request_id: str - execution_arn: str - start_time: datetime.datetime - - -@dataclass -class InvocationEndInfo(InvocationStartInfo): - status: InvocationStatus = InvocationStatus.SUCCEEDED - error: ErrorObject | None = None - - -@dataclass -class ExecutionStartInfo(InvocationStartInfo): - pass - - -@dataclass -class ExecutionEndInfo(InvocationEndInfo): - pass - - -class DurableExecutionPlugin(ABC): - """Base class for plugins. Override only the methods you need.""" - - def on_execution_start(self, info: ExecutionStartInfo) -> None: - pass - - def on_execution_end(self, info: ExecutionEndInfo) -> None: - pass - - def on_invocation_start(self, info: InvocationStartInfo) -> None: - pass - - def on_invocation_end(self, info: InvocationEndInfo) -> None: - pass - - def on_operation_start(self, info: OperationStartInfo) -> None: - pass - - def on_operation_end(self, info: OperationEndInfo) -> None: - pass - - def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: - pass - - def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: - pass - - # Todo: further discussions required to finalize the following interface - # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass diff --git a/src/aws_durable_execution_sdk_python/lambda_service.py b/src/aws_durable_execution_sdk_python/lambda_service.py index 51a0049..d1b2381 100644 --- a/src/aws_durable_execution_sdk_python/lambda_service.py +++ b/src/aws_durable_execution_sdk_python/lambda_service.py @@ -96,6 +96,12 @@ class OperationSubType(Enum): CHAINED_INVOKE = "ChainedInvoke" +class InvocationStatus(Enum): + SUCCEEDED = "SUCCEEDED" + FAILED = "FAILED" + PENDING = "PENDING" + + @dataclass(frozen=True) class ExecutionDetails: input_payload: str | None = None diff --git a/src/aws_durable_execution_sdk_python/plugin.py b/src/aws_durable_execution_sdk_python/plugin.py new file mode 100644 index 0000000..6f63246 --- /dev/null +++ b/src/aws_durable_execution_sdk_python/plugin.py @@ -0,0 +1,203 @@ +import datetime +import logging +from abc import ABC +from dataclasses import dataclass + +from aws_durable_execution_sdk_python.lambda_service import ( + OperationType, + OperationStatus, + OperationAction, + OperationSubType, + ErrorObject, + InvocationStatus, + Operation, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class OperationStartInfo: + operation_id: str + operation_type: OperationType + sub_type: OperationSubType | None = None + name: str | None = None + parent_id: str | None = None + start_timestamp: datetime.datetime | None = None + + +@dataclass +class OperationEndInfo(OperationStartInfo): + status: OperationStatus = OperationStatus.SUCCEEDED + end_timestamp: datetime.datetime | None = None + attempt: int = 1 + error: ErrorObject | None = None + + +@dataclass +class AttemptStartInfo(OperationStartInfo): + attempt: int = 1 + + +@dataclass +class AttemptEndInfo(AttemptStartInfo): + outcome: OperationAction = OperationAction.SUCCEED + error: ErrorObject | None = None + next_attempt_delay_seconds: int | None = None + + +@dataclass +class InvocationStartInfo: + request_id: str + execution_arn: str + start_time: datetime.datetime + + +@dataclass +class InvocationEndInfo(InvocationStartInfo): + status: InvocationStatus = InvocationStatus.SUCCEEDED + error: ErrorObject | None = None + + +@dataclass +class ExecutionStartInfo(InvocationStartInfo): + pass + + +@dataclass +class ExecutionEndInfo(InvocationEndInfo): + pass + + +class DurableExecutionPlugin(ABC): + """Base class for plugins. Override only the methods you need.""" + + def on_execution_start(self, info: ExecutionStartInfo) -> None: + pass + + def on_execution_end(self, info: ExecutionEndInfo) -> None: + pass + + def on_invocation_start(self, info: InvocationStartInfo) -> None: + pass + + def on_invocation_end(self, info: InvocationEndInfo) -> None: + pass + + def on_operation_start(self, info: OperationStartInfo) -> None: + pass + + def on_operation_end(self, info: OperationEndInfo) -> None: + pass + + def on_operation_attempt_start(self, info: AttemptStartInfo) -> None: + pass + + def on_operation_attempt_end(self, info: AttemptEndInfo) -> None: + pass + + # Todo: further discussions required to finalize the following interface + # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass + + +def execute_plugins(plugins: list[DurableExecutionPlugin], info): + for plugin in plugins: + try: + match info: + case ExecutionStartInfo(): + plugin.on_execution_start(info) + case ExecutionEndInfo(): + plugin.on_execution_end(info) + case InvocationStartInfo(): + plugin.on_invocation_start(info) + case InvocationEndInfo(): + plugin.on_invocation_end(info) + case OperationStartInfo(): + plugin.on_operation_start(info) + case OperationEndInfo(): + plugin.on_operation_end(info) + case AttemptStartInfo(): + plugin.on_operation_attempt_start(info) + case AttemptEndInfo(): + plugin.on_operation_attempt_end(info) + case _: + raise ValueError(f"Unknown info type: {type(info)}") + except Exception: + logger.exception("Plugin %s failed", plugin.__class__.__name__) + + +def get_operation_info_from_operation( + operation: Operation, +) -> OperationStartInfo | OperationEndInfo | None: + if operation is None: + raise ValueError("Operation is None") + if operation.status in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ]: + return OperationEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + status=operation.status, + ) + if operation.status is OperationStatus.STARTED: + return OperationStartInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + ) + return None + + +def get_attempt_info_from_operation( + operation: Operation, +) -> AttemptStartInfo | AttemptEndInfo | None: + if operation is None: + raise ValueError("Operation is None") + if operation.status is OperationStatus.STARTED and operation.step_details: + return AttemptStartInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + attempt=operation.step_details.attempt, + ) + if ( + operation.status + in [ + OperationStatus.SUCCEEDED, + OperationStatus.FAILED, + OperationStatus.TIMED_OUT, + OperationStatus.CANCELLED, + OperationStatus.STOPPED, + ] + and operation.step_details + ): + return AttemptEndInfo( + operation_id=operation.operation_id, + operation_type=operation.operation_type, + sub_type=operation.sub_type, + name=operation.name, + parent_id=operation.parent_id, + start_timestamp=operation.start_timestamp, + end_timestamp=operation.end_timestamp, + attempt=operation.step_details.attempt, + outcome=OperationAction.SUCCEED + if operation.status is OperationStatus.SUCCEEDED + else OperationAction.FAIL, + error=operation.step_details.error, + ) + return None diff --git a/src/aws_durable_execution_sdk_python/state.py b/src/aws_durable_execution_sdk_python/state.py index 0d9cb0f..91bf005 100644 --- a/src/aws_durable_execution_sdk_python/state.py +++ b/src/aws_durable_execution_sdk_python/state.py @@ -30,6 +30,11 @@ OperationUpdate, StateOutput, ) +from aws_durable_execution_sdk_python.plugin import ( + DurableExecutionPlugin, + get_operation_info_from_operation, + execute_plugins, +) from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock if TYPE_CHECKING: @@ -229,6 +234,7 @@ def __init__( initial_checkpoint_token: str, operations: MutableMapping[str, Operation], service_client: DurableServiceClient, + plugins: list[DurableExecutionPlugin], batcher_config: CheckpointBatcherConfig | None = None, replay_status: ReplayStatus = ReplayStatus.NEW, ): @@ -236,6 +242,7 @@ def __init__( self._current_checkpoint_token: str = initial_checkpoint_token self.operations: MutableMapping[str, Operation] = operations self._service_client: DurableServiceClient = service_client + self._plugins: list[DurableExecutionPlugin] = plugins self._ordered_checkpoint_lock: OrderedLock = OrderedLock() self._operations_lock: Lock = Lock() @@ -267,7 +274,7 @@ def fetch_paginated_operations( initial_operations: list[Operation], checkpoint_token: str, next_marker: str | None, - ) -> None: + ) -> list[Operation]: """Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe. The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety. @@ -277,6 +284,9 @@ def fetch_paginated_operations( checkpoint_token: checkpoint token used to call Durable Functions API. next_marker: a marker indicates that there are paginated operations. + Returns: + List of all operations fetched from the Durable Functions API + Raises: GetExecutionStateError: If the API call fails. The error is logged with structured extras before re-raising. Callers are responsible @@ -308,6 +318,7 @@ def fetch_paginated_operations( self.operations.update( {op.operation_id: op for op in all_operations} ) + return all_operations def get_input_payload(self) -> str | None: # It is possible that backend will not provide an execution operation @@ -670,12 +681,15 @@ def checkpoint_batches_forever(self) -> None: current_checkpoint_token = output.checkpoint_token # Fetch new operations from the API before unblocking sync waiters - self.fetch_paginated_operations( + updated_operations = self.fetch_paginated_operations( output.new_execution_state.operations, output.checkpoint_token, output.new_execution_state.next_marker, ) + for operation in updated_operations: + self._execute_plugin(operation) + # Signal completion for any synchronous operations for queued_op in batch: if queued_op.completion_event is not None: @@ -861,6 +875,21 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]: ) return batch + def _execute_plugin(self, operation: Operation) -> None: + """Execute any registered plugins for the given queued operation. + + This method iterates through all registered plugins and calls their + `on_checkpoint` method with the queued operation as an argument. Plugins + are executed in the order they were registered. + + Args: + operation: the operation to execute plugins for + """ + if info := get_operation_info_from_operation(operation): + execute_plugins(self._plugins, info) + if info := get_attempt_info_from_operation(operation): + execute_plugins(self._plugins, info) + @staticmethod def _calculate_operation_size(queued_op: QueuedOperation) -> int: """Calculate the serialized size of a queued operation for batching limits. diff --git a/tests/hook_test.py b/tests/plugin_test.py similarity index 85% rename from tests/hook_test.py rename to tests/plugin_test.py index edacff4..b61f56a 100644 --- a/tests/hook_test.py +++ b/tests/plugin_test.py @@ -1,7 +1,7 @@ import datetime import unittest -from aws_durable_execution_sdk_python.hook import ( +from aws_durable_execution_sdk_python.plugin import ( AttemptEndInfo, AttemptStartInfo, DurableExecutionPlugin, @@ -55,16 +55,16 @@ def test_inherits_operation_start_info(self): self.assertTrue(issubclass(OperationEndInfo, OperationStartInfo)) def test_defaults(self): - info = OperationEndInfo( - operation_id="op-1", operation_type=OperationType.STEP - ) + info = OperationEndInfo(operation_id="op-1", operation_type=OperationType.STEP) self.assertEqual(info.status, OperationStatus.SUCCEEDED) self.assertIsNone(info.end_timestamp) self.assertEqual(info.attempt, 1) self.assertIsNone(info.error) def test_with_error(self): - err = ErrorObject(message="fail", type="RuntimeError", data=None, stack_trace=None) + err = ErrorObject( + message="fail", type="RuntimeError", data=None, stack_trace=None + ) info = OperationEndInfo( operation_id="op-1", operation_type=OperationType.STEP, @@ -82,9 +82,7 @@ def test_inherits_operation_start_info(self): self.assertTrue(issubclass(AttemptStartInfo, OperationStartInfo)) def test_default_attempt(self): - info = AttemptStartInfo( - operation_id="op-1", operation_type=OperationType.STEP - ) + info = AttemptStartInfo(operation_id="op-1", operation_type=OperationType.STEP) self.assertEqual(info.attempt, 1) def test_custom_attempt(self): @@ -99,15 +97,15 @@ def test_inherits_attempt_start_info(self): self.assertTrue(issubclass(AttemptEndInfo, AttemptStartInfo)) def test_defaults(self): - info = AttemptEndInfo( - operation_id="op-1", operation_type=OperationType.STEP - ) + info = AttemptEndInfo(operation_id="op-1", operation_type=OperationType.STEP) self.assertEqual(info.outcome, OperationAction.SUCCEED) self.assertIsNone(info.error) self.assertIsNone(info.next_attempt_delay_seconds) def test_retry_with_delay(self): - err = ErrorObject(message="timeout", type="TimeoutError", data=None, stack_trace=None) + err = ErrorObject( + message="timeout", type="TimeoutError", data=None, stack_trace=None + ) info = AttemptEndInfo( operation_id="op-1", operation_type=OperationType.STEP, @@ -124,7 +122,9 @@ class TestInvocationStartInfo(unittest.TestCase): def test_fields(self): ts = datetime.datetime(2025, 6, 1, tzinfo=datetime.UTC) info = InvocationStartInfo( - request_id="req-1", execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", start_time=ts + request_id="req-1", + execution_arn="arn:aws:lambda:us-east-1:123:durable:abc", + start_time=ts, ) self.assertEqual(info.request_id, "req-1") self.assertEqual(info.execution_arn, "arn:aws:lambda:us-east-1:123:durable:abc") @@ -188,13 +188,21 @@ def test_default_methods_are_noop(self): plugin = _NoOpPlugin() ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) - exec_start = ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts) + exec_start = ExecutionStartInfo( + request_id="r", execution_arn="a", start_time=ts + ) exec_end = ExecutionEndInfo(request_id="r", execution_arn="a", start_time=ts) - inv_start = InvocationStartInfo(request_id="r", execution_arn="a", start_time=ts) + inv_start = InvocationStartInfo( + request_id="r", execution_arn="a", start_time=ts + ) inv_end = InvocationEndInfo(request_id="r", execution_arn="a", start_time=ts) - op_start = OperationStartInfo(operation_id="o", operation_type=OperationType.STEP) + op_start = OperationStartInfo( + operation_id="o", operation_type=OperationType.STEP + ) op_end = OperationEndInfo(operation_id="o", operation_type=OperationType.STEP) - att_start = AttemptStartInfo(operation_id="o", operation_type=OperationType.STEP) + att_start = AttemptStartInfo( + operation_id="o", operation_type=OperationType.STEP + ) att_end = AttemptEndInfo(operation_id="o", operation_type=OperationType.STEP) self.assertIsNone(plugin.on_execution_start(exec_start)) @@ -211,8 +219,12 @@ def test_subclass_override(self): plugin = _TrackingPlugin() ts = datetime.datetime(2025, 1, 1, tzinfo=datetime.UTC) - plugin.on_execution_start(ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts)) - plugin.on_operation_start(OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT)) + plugin.on_execution_start( + ExecutionStartInfo(request_id="r", execution_arn="a", start_time=ts) + ) + plugin.on_operation_start( + OperationStartInfo(operation_id="o", operation_type=OperationType.WAIT) + ) self.assertEqual(plugin.calls, ["execution_start:r", "operation_start:o"]) @@ -223,6 +235,7 @@ def test_cannot_instantiate_abc_directly(self): class _NoOpPlugin(DurableExecutionPlugin): """Concrete subclass that inherits all default no-op methods.""" + pass