Skip to content
Draft

Hooks #371

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/aws_durable_execution_sdk_python/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from aws_durable_execution_sdk_python.operation.wait_for_condition import (
WaitForConditionOperationExecutor,
)
from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin
from aws_durable_execution_sdk_python.serdes import (
PassThroughSerDes,
SerDes,
Expand Down
15 changes: 7 additions & 8 deletions src/aws_durable_execution_sdk_python/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import logging
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any

from aws_durable_execution_sdk_python.context import DurableContext
Expand All @@ -19,9 +18,11 @@
InvocationError,
SuspendExecution,
)
from aws_durable_execution_sdk_python.plugin import DurableExecutionPlugin
from aws_durable_execution_sdk_python.lambda_service import (
DurableServiceClient,
ErrorObject,
InvocationStatus,
LambdaClient,
Operation,
OperationType,
Expand Down Expand Up @@ -149,12 +150,6 @@ def from_durable_execution_invocation_input(
)


class InvocationStatus(Enum):
SUCCEEDED = "SUCCEEDED"
FAILED = "FAILED"
PENDING = "PENDING"


@dataclass(frozen=True)
class DurableExecutionInvocationOutput:
"""Representation the DurableExecutionInvocationOutput. This is what the Durable lambda handler returns.
Expand Down Expand Up @@ -212,11 +207,14 @@ def durable_execution(
func: Callable[[Any, DurableContext], Any] | None = None,
*,
boto3_client: Boto3LambdaClient | None = None,
plugins: list[DurableExecutionPlugin] | None = None,
) -> Callable[[Any, LambdaContext], Any]:
# Decorator called with parameters
if func is None:
logger.debug("Decorator called with parameters")
return functools.partial(durable_execution, boto3_client=boto3_client)
return functools.partial(
durable_execution, boto3_client=boto3_client, plugins=plugins
)

logger.debug("Starting durable execution handler...")

Expand Down Expand Up @@ -254,6 +252,7 @@ def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]:
initial_checkpoint_token=invocation_input.checkpoint_token,
operations={},
service_client=service_client,
plugins=plugins or [],
# If there are operations other than the initial EXECUTION one, current state is in replay mode
replay_status=ReplayStatus.REPLAY
if len(invocation_input.initial_execution_state.operations) > 1
Expand Down
6 changes: 6 additions & 0 deletions src/aws_durable_execution_sdk_python/lambda_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ class OperationSubType(Enum):
CHAINED_INVOKE = "ChainedInvoke"


class InvocationStatus(Enum):
    """Status values for a durable execution invocation.

    String-valued so members serialize directly to/from the service's
    wire format.
    """

    SUCCEEDED = "SUCCEEDED"
    FAILED = "FAILED"
    PENDING = "PENDING"


@dataclass(frozen=True)
class ExecutionDetails:
input_payload: str | None = None
Expand Down
203 changes: 203 additions & 0 deletions src/aws_durable_execution_sdk_python/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
import datetime
import logging
from abc import ABC
from dataclasses import dataclass

from aws_durable_execution_sdk_python.lambda_service import (
OperationType,
OperationStatus,
OperationAction,
OperationSubType,
ErrorObject,
InvocationStatus,
Operation,
)

logger = logging.getLogger(__name__)


@dataclass
class OperationStartInfo:
    """Payload handed to plugin ``on_operation_start`` hooks.

    Mirrors the identifying fields of a checkpointed ``Operation``
    (see ``get_operation_info_from_operation``).
    """

    # Unique id of the operation within the execution.
    operation_id: str
    # Kind of operation — see OperationType in lambda_service.
    operation_type: OperationType
    # Optional refinement of the operation type (e.g. ChainedInvoke).
    sub_type: OperationSubType | None = None
    # Human-readable operation name, when one was supplied.
    name: str | None = None
    # Id of the enclosing/parent operation, if any.
    parent_id: str | None = None
    # When the operation started; None if not reported.
    start_timestamp: datetime.datetime | None = None


@dataclass
class OperationEndInfo(OperationStartInfo):
    """Payload handed to plugin ``on_operation_end`` hooks.

    Extends the start info with terminal-state fields.
    """

    # Final status; defaults to SUCCEEDED.
    status: OperationStatus = OperationStatus.SUCCEEDED
    # When the operation reached its terminal state; None if not reported.
    end_timestamp: datetime.datetime | None = None
    # 1-based attempt counter.
    attempt: int = 1
    # Error details when the operation did not succeed, else None.
    error: ErrorObject | None = None


@dataclass
class AttemptStartInfo(OperationStartInfo):
    """Payload handed to plugin ``on_operation_attempt_start`` hooks.

    Produced only for operations that carry step details
    (see ``get_attempt_info_from_operation``).
    """

    # 1-based attempt counter taken from the operation's step details.
    attempt: int = 1


@dataclass
class AttemptEndInfo(AttemptStartInfo):
    """Payload handed to plugin ``on_operation_attempt_end`` hooks."""

    # SUCCEED when the operation succeeded, otherwise FAIL
    # (see get_attempt_info_from_operation).
    outcome: OperationAction = OperationAction.SUCCEED
    # Error from the attempt's step details, else None.
    error: ErrorObject | None = None
    # Delay before the next retry attempt, if one is scheduled.
    next_attempt_delay_seconds: int | None = None


@dataclass
class InvocationStartInfo:
    """Payload handed to plugin ``on_invocation_start`` hooks."""

    # Lambda request id for this invocation.
    request_id: str
    # ARN of the durable execution this invocation belongs to.
    execution_arn: str
    # When the invocation started.
    start_time: datetime.datetime


@dataclass
class InvocationEndInfo(InvocationStartInfo):
    """Payload handed to plugin ``on_invocation_end`` hooks."""

    # Final invocation status; defaults to SUCCEEDED.
    status: InvocationStatus = InvocationStatus.SUCCEEDED
    # Error details when the invocation failed, else None.
    error: ErrorObject | None = None


@dataclass
class ExecutionStartInfo(InvocationStartInfo):
    """Payload handed to plugin ``on_execution_start`` hooks.

    Same fields as InvocationStartInfo; the distinct type drives
    dispatch in ``execute_plugins``.
    """

    pass


@dataclass
class ExecutionEndInfo(InvocationEndInfo):
    """Payload handed to plugin ``on_execution_end`` hooks.

    Same fields as InvocationEndInfo; the distinct type drives
    dispatch in ``execute_plugins``.
    """

    pass


class DurableExecutionPlugin(ABC):
    """Base class for plugins. Override only the methods you need.

    Every hook is a no-op by default. Hooks are dispatched by
    ``execute_plugins``, which catches and logs any exception a plugin
    raises, so a failing plugin cannot break the execution itself.
    """

    def on_execution_start(self, info: ExecutionStartInfo) -> None:
        """Invoked when an ExecutionStartInfo event is dispatched."""
        pass

    def on_execution_end(self, info: ExecutionEndInfo) -> None:
        """Invoked when an ExecutionEndInfo event is dispatched."""
        pass

    def on_invocation_start(self, info: InvocationStartInfo) -> None:
        """Invoked when an InvocationStartInfo event is dispatched."""
        pass

    def on_invocation_end(self, info: InvocationEndInfo) -> None:
        """Invoked when an InvocationEndInfo event is dispatched."""
        pass

    def on_operation_start(self, info: OperationStartInfo) -> None:
        """Invoked when an OperationStartInfo event is dispatched."""
        pass

    def on_operation_end(self, info: OperationEndInfo) -> None:
        """Invoked when an OperationEndInfo event is dispatched."""
        pass

    def on_operation_attempt_start(self, info: AttemptStartInfo) -> None:
        """Invoked when an AttemptStartInfo event is dispatched."""
        pass

    def on_operation_attempt_end(self, info: AttemptEndInfo) -> None:
        """Invoked when an AttemptEndInfo event is dispatched."""
        pass

    # TODO: further discussions required to finalize the following interface
    # def enrich_log_context(self, info: OperationStartInfo | None) -> Dict[str, Any] | None: pass


def execute_plugins(plugins: list[DurableExecutionPlugin], info):
for plugin in plugins:
try:
match info:
case ExecutionStartInfo():
plugin.on_execution_start(info)
case ExecutionEndInfo():
plugin.on_execution_end(info)
case InvocationStartInfo():
plugin.on_invocation_start(info)
case InvocationEndInfo():
plugin.on_invocation_end(info)
case OperationStartInfo():
plugin.on_operation_start(info)
case OperationEndInfo():
plugin.on_operation_end(info)
case AttemptStartInfo():
plugin.on_operation_attempt_start(info)
case AttemptEndInfo():
plugin.on_operation_attempt_end(info)
case _:
raise ValueError(f"Unknown info type: {type(info)}")
except Exception:
logger.exception("Plugin %s failed", plugin.__class__.__name__)


def get_operation_info_from_operation(
    operation: Operation,
) -> OperationStartInfo | OperationEndInfo | None:
    """Translate a checkpointed Operation into an operation-level event.

    Returns an OperationEndInfo for a terminal status, an
    OperationStartInfo for STARTED, and None for any other status.

    Args:
        operation: the checkpointed operation to translate.

    Raises:
        ValueError: if *operation* is None.
    """
    if operation is None:
        raise ValueError("Operation is None")

    # Fields shared by both event payloads.
    common = {
        "operation_id": operation.operation_id,
        "operation_type": operation.operation_type,
        "sub_type": operation.sub_type,
        "name": operation.name,
        "parent_id": operation.parent_id,
        "start_timestamp": operation.start_timestamp,
    }

    terminal_statuses = (
        OperationStatus.SUCCEEDED,
        OperationStatus.FAILED,
        OperationStatus.TIMED_OUT,
        OperationStatus.CANCELLED,
        OperationStatus.STOPPED,
    )
    if operation.status in terminal_statuses:
        return OperationEndInfo(
            **common,
            end_timestamp=operation.end_timestamp,
            status=operation.status,
        )
    if operation.status is OperationStatus.STARTED:
        return OperationStartInfo(**common)
    return None


def get_attempt_info_from_operation(
    operation: Operation,
) -> AttemptStartInfo | AttemptEndInfo | None:
    """Translate a checkpointed Operation into an attempt-level event.

    Only operations carrying step details produce attempt events:
    AttemptStartInfo for STARTED, AttemptEndInfo for a terminal status,
    None otherwise.

    Args:
        operation: the checkpointed operation to translate.

    Raises:
        ValueError: if *operation* is None.
    """
    if operation is None:
        raise ValueError("Operation is None")
    # No step details -> no attempt-level event.
    if not operation.step_details:
        return None

    # Fields shared by both attempt payloads.
    base_fields = {
        "operation_id": operation.operation_id,
        "operation_type": operation.operation_type,
        "sub_type": operation.sub_type,
        "name": operation.name,
        "parent_id": operation.parent_id,
        "start_timestamp": operation.start_timestamp,
        "attempt": operation.step_details.attempt,
    }

    if operation.status is OperationStatus.STARTED:
        return AttemptStartInfo(**base_fields)

    terminal_statuses = (
        OperationStatus.SUCCEEDED,
        OperationStatus.FAILED,
        OperationStatus.TIMED_OUT,
        OperationStatus.CANCELLED,
        OperationStatus.STOPPED,
    )
    if operation.status in terminal_statuses:
        succeeded = operation.status is OperationStatus.SUCCEEDED
        return AttemptEndInfo(
            **base_fields,
            end_timestamp=operation.end_timestamp,
            outcome=OperationAction.SUCCEED if succeeded else OperationAction.FAIL,
            error=operation.step_details.error,
        )
    return None
33 changes: 31 additions & 2 deletions src/aws_durable_execution_sdk_python/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@
OperationUpdate,
StateOutput,
)
from aws_durable_execution_sdk_python.plugin import (
    DurableExecutionPlugin,
    execute_plugins,
    get_attempt_info_from_operation,
    get_operation_info_from_operation,
)
from aws_durable_execution_sdk_python.threading import CompletionEvent, OrderedLock

if TYPE_CHECKING:
Expand Down Expand Up @@ -229,13 +234,15 @@ def __init__(
initial_checkpoint_token: str,
operations: MutableMapping[str, Operation],
service_client: DurableServiceClient,
plugins: list[DurableExecutionPlugin],
batcher_config: CheckpointBatcherConfig | None = None,
replay_status: ReplayStatus = ReplayStatus.NEW,
):
self.durable_execution_arn: str = durable_execution_arn
self._current_checkpoint_token: str = initial_checkpoint_token
self.operations: MutableMapping[str, Operation] = operations
self._service_client: DurableServiceClient = service_client
self._plugins: list[DurableExecutionPlugin] = plugins
self._ordered_checkpoint_lock: OrderedLock = OrderedLock()
self._operations_lock: Lock = Lock()

Expand Down Expand Up @@ -267,7 +274,7 @@ def fetch_paginated_operations(
initial_operations: list[Operation],
checkpoint_token: str,
next_marker: str | None,
) -> None:
) -> list[Operation]:
"""Add initial operations and fetch all paginated operations from the Durable Functions API. This method is thread_safe.

The checkpoint_token is passed explicitly as a parameter rather than using the instance variable to ensure thread safety.
Expand All @@ -277,6 +284,9 @@ def fetch_paginated_operations(
checkpoint_token: checkpoint token used to call Durable Functions API.
next_marker: a marker indicates that there are paginated operations.

Returns:
List of all operations fetched from the Durable Functions API

Raises:
GetExecutionStateError: If the API call fails. The error is logged
with structured extras before re-raising. Callers are responsible
Expand Down Expand Up @@ -308,6 +318,7 @@ def fetch_paginated_operations(
self.operations.update(
{op.operation_id: op for op in all_operations}
)
return all_operations

def get_input_payload(self) -> str | None:
# It is possible that backend will not provide an execution operation
Expand Down Expand Up @@ -670,12 +681,15 @@ def checkpoint_batches_forever(self) -> None:
current_checkpoint_token = output.checkpoint_token

# Fetch new operations from the API before unblocking sync waiters
self.fetch_paginated_operations(
updated_operations = self.fetch_paginated_operations(
output.new_execution_state.operations,
output.checkpoint_token,
output.new_execution_state.next_marker,
)

for operation in updated_operations:
self._execute_plugin(operation)

# Signal completion for any synchronous operations
for queued_op in batch:
if queued_op.completion_event is not None:
Expand Down Expand Up @@ -861,6 +875,21 @@ def _collect_checkpoint_batch(self) -> list[QueuedOperation]:
)
return batch

def _execute_plugin(self, operation: Operation) -> None:
    """Notify registered plugins about a checkpointed operation.

    Translates the raw Operation into plugin-facing event payloads and
    dispatches each via ``execute_plugins``: first the operation-level
    start/end info, then — when the operation carries step details —
    the attempt-level info. Plugins are invoked in registration order;
    exceptions raised by plugins are logged and swallowed inside
    ``execute_plugins``.

    Args:
        operation: the operation to notify plugins about.
    """
    # Fixed: previous docstring claimed a nonexistent `on_checkpoint`
    # plugin method; dispatch actually goes through execute_plugins.
    if info := get_operation_info_from_operation(operation):
        execute_plugins(self._plugins, info)
    if attempt_info := get_attempt_info_from_operation(operation):
        execute_plugins(self._plugins, attempt_info)

@staticmethod
def _calculate_operation_size(queued_op: QueuedOperation) -> int:
"""Calculate the serialized size of a queued operation for batching limits.
Expand Down
Loading
Loading