From 30ea18e99e8d50ed2df3adcbcaa5df07751b79d0 Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Thu, 7 May 2026 11:01:49 -0700 Subject: [PATCH] Revert "[refactor]: use an executor class instead of closure (#375)" This reverts commit 5f4b23e502e93b0681d00810ed5b6cf84e8a4487. --- .../execution.py | 340 ++++++++---------- 1 file changed, 146 insertions(+), 194 deletions(-) diff --git a/src/aws_durable_execution_sdk_python/execution.py b/src/aws_durable_execution_sdk_python/execution.py index 977834d..df535b4 100644 --- a/src/aws_durable_execution_sdk_python/execution.py +++ b/src/aws_durable_execution_sdk_python/execution.py @@ -7,13 +7,14 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from enum import Enum -from typing import TYPE_CHECKING, Any, cast, Callable +from typing import TYPE_CHECKING, Any from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.exceptions import ( BackgroundThreadError, BotoClientError, CheckpointError, + DurableExecutionsError, ExecutionError, InvocationError, SuspendExecution, @@ -23,12 +24,13 @@ ErrorObject, LambdaClient, Operation, + OperationType, OperationUpdate, ) from aws_durable_execution_sdk_python.state import ExecutionState, ReplayStatus if TYPE_CHECKING: - from collections.abc import MutableMapping + from collections.abc import Callable, MutableMapping from mypy_boto3_lambda import LambdaClient as Boto3LambdaClient @@ -49,9 +51,9 @@ class InitialExecutionState: @staticmethod def from_dict(input_dict: MutableMapping[str, Any]) -> InitialExecutionState: - operations = [ - Operation.from_dict(op) for op in input_dict.get("Operations", []) - ] + operations = [] + if input_operations := input_dict.get("Operations"): + operations = [Operation.from_dict(op) for op in input_operations] return InitialExecutionState( operations=operations, next_marker=input_dict.get("NextMarker", ""), @@ -59,9 +61,9 @@ def from_dict(input_dict: MutableMapping[str, Any]) -> InitialExecutionState: @staticmethod def from_json_dict(input_dict: MutableMapping[str, Any]) -> InitialExecutionState: - operations = [ - Operation.from_json_dict(op) for op in input_dict.get("Operations", []) - ] + operations = [] + if input_operations := input_dict.get("Operations"): + operations = [Operation.from_json_dict(op) for op in input_operations] return InitialExecutionState( operations=operations, next_marker=input_dict.get("NextMarker", ""), @@ -197,6 +199,11 @@ def to_dict(self) -> MutableMapping[str, Any]: return result + @classmethod + def create_succeeded(cls, result: str) -> DurableExecutionInvocationOutput: + """Create a succeeded invocation output.""" + return cls(status=InvocationStatus.SUCCEEDED, result=result) + # endregion Invocation models @@ -210,85 +217,51 @@ def durable_execution( if func is None: logger.debug("Decorator called with parameters") return functools.partial(durable_execution, boto3_client=boto3_client) - else: - logger.debug("Starting durable execution handler...") - - def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: - executor = DurableExecutionExecutor( - cast(Callable[[Any, DurableContext], Any], func), - boto3_client, - event, - context, - ) - return executor.execute() - return wrapper + logger.debug("Starting durable execution handler...") + def wrapper(event: Any, context: LambdaContext) -> MutableMapping[str, Any]: + invocation_input: DurableExecutionInvocationInput + service_client: DurableServiceClient -class DurableExecutionExecutor: - def __init__( - self, - func: Callable[[Any, DurableContext], Any], - boto3_client: Boto3LambdaClient | None, - event: Any, - context: LambdaContext, - ): - self.func = func - self.boto3_client = boto3_client - self.event = event - self.context = context - self.invocation_input = self._parse_invocation_input(event) - self.service_client = self._parse_service_client(event, boto3_client) - - def _parse_invocation_input(self, event: Any) -> DurableExecutionInvocationInput: # event likely only to be DurableExecutionInvocationInputWithClient when directly injected by test framework - invocation_input: ( - DurableExecutionInvocationInputWithClient | DurableExecutionInvocationInput - ) if isinstance(event, DurableExecutionInvocationInputWithClient): + logger.debug("durableExecutionArn: %s", event.durable_execution_arn) invocation_input = event + service_client = invocation_input.service_client else: try: + logger.debug( + "durableExecutionArn: %s", event.get("DurableExecutionArn") + ) invocation_input = DurableExecutionInvocationInput.from_json_dict(event) - except (KeyError, TypeError, AttributeError): + except (KeyError, TypeError, AttributeError) as e: msg = ( "Unexpected payload provided to start the durable execution. " "Check your resource configurations to confirm the durability is set." ) - # throws ExecutionError to terminate the invocation - self._handle_execution_output( - exception=ExecutionError(msg), retryable=True - ) - # add a redundant raise to make type checker happy - raise ExecutionError(msg) - - logger.debug("durableExecutionArn: %s", invocation_input.durable_execution_arn) - return invocation_input + raise ExecutionError(msg) from e - @staticmethod - def _parse_service_client(event, boto3_client): - if isinstance(event, DurableExecutionInvocationInputWithClient): - return event.service_client - elif boto3_client: - return LambdaClient(boto3_client) - else: # Use custom client if provided, otherwise initialize from environment - return LambdaClient.initialize_client() + service_client = ( + LambdaClient(client=boto3_client) + if boto3_client is not None + else LambdaClient.initialize_client() + ) - def execute(self): execution_state: ExecutionState = ExecutionState( - durable_execution_arn=self.invocation_input.durable_execution_arn, - initial_checkpoint_token=self.invocation_input.checkpoint_token, + durable_execution_arn=invocation_input.durable_execution_arn, + initial_checkpoint_token=invocation_input.checkpoint_token, operations={}, - service_client=self.service_client, + service_client=service_client, replay_status=ReplayStatus.NEW, ) try: execution_state.fetch_paginated_operations( - self.invocation_input.initial_execution_state.operations, - self.invocation_input.checkpoint_token, - self.invocation_input.initial_execution_state.next_marker, + invocation_input.initial_execution_state.operations, + invocation_input.checkpoint_token, + invocation_input.initial_execution_state.next_marker, ) except BotoClientError as e: # Non-retryable Durable API errors (e.g., customer configuration issues, @@ -299,9 +272,11 @@ def execute(self): "without retry.", extra=e.build_logger_extras(), ) - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() - ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() + raise execution_state.mark_replaying_if_prior_operations_exist() @@ -313,15 +288,15 @@ def execute(self): if raw_input_payload and raw_input_payload.strip(): try: input_event = json.loads(raw_input_payload) - except json.JSONDecodeError as e: + except json.JSONDecodeError: logger.exception( "Failed to parse input payload as JSON: payload: %r", raw_input_payload, ) - self._handle_execution_output(exception=e, retryable=True) + raise durable_context: DurableContext = DurableContext.from_lambda_context( - state=execution_state, lambda_context=self.context + state=execution_state, lambda_context=context ) # Use ThreadPoolExecutor for concurrent execution of user code and background checkpoint processing @@ -336,13 +311,13 @@ def execute(self): # Thread 2: Execute user function logger.debug( - "%s entering user-space...", self.invocation_input.durable_execution_arn + "%s entering user-space...", invocation_input.durable_execution_arn ) - user_future = executor.submit(self.func, input_event, durable_context) + user_future = executor.submit(func, input_event, durable_context) logger.debug( "%s waiting for user code completion...", - self.invocation_input.durable_execution_arn, + invocation_input.durable_execution_arn, ) try: @@ -352,44 +327,71 @@ def execute(self): # done with userland logger.debug( "%s exiting user-space...", - self.invocation_input.durable_execution_arn, + invocation_input.durable_execution_arn, ) - serialized_result = self._handle_large_result(execution_state, result) + serialized_result = json.dumps(result) + # large response handling here. Remember if checkpointing to complete, NOT to include + # payload in response + if ( + serialized_result + and len(serialized_result) > LAMBDA_RESPONSE_SIZE_LIMIT + ): + logger.debug( + "Response size (%s bytes) exceeds Lambda limit (%s) bytes). Checkpointing result.", + len(serialized_result), + LAMBDA_RESPONSE_SIZE_LIMIT, + ) + success_operation = OperationUpdate.create_execution_succeed( + payload=serialized_result + ) + # Checkpoint large result with blocking (is_sync=True, default). + # Must ensure the result is persisted before returning to Lambda. + # Large results exceed Lambda response limits and must be stored durably + # before the execution completes. + try: + execution_state.create_checkpoint( + success_operation, is_sync=True + ) + except CheckpointError as e: + return handle_checkpoint_error(e).to_dict() + return DurableExecutionInvocationOutput.create_succeeded( + result="" + ).to_dict() - return self._handle_execution_output(result=serialized_result) + return DurableExecutionInvocationOutput.create_succeeded( + result=serialized_result + ).to_dict() except BackgroundThreadError as bg_error: # Background checkpoint system failed - propagated through CompletionEvent # Do not attempt to checkpoint anything, just terminate immediately - cause = bg_error.source_exception - - if isinstance(cause, BotoClientError): + if isinstance(bg_error.source_exception, BotoClientError): logger.exception( "Checkpoint processing failed", - extra=cause.build_logger_extras(), + extra=bg_error.source_exception.build_logger_extras(), ) # Non-retryable Durable API errors (e.g., customer configuration issues, # 4xx client errors) will never succeed on retry — fail the execution immediately. - if not cause.is_retryable(): + if not bg_error.source_exception.is_retryable(): logger.exception( "Non-retryable Durable API error from background thread. Must fail execution " "without retry.", - extra=cause.build_logger_extras(), + extra=bg_error.source_exception.build_logger_extras(), ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(bg_error.source_exception), + ).to_dict() else: logger.exception("Checkpoint processing failed") - - retryable = ( - not isinstance(cause, BotoClientError) or cause.is_retryable() - ) - return self._handle_execution_output( - exception=cause, retryable=retryable - ) + raise bg_error.source_exception from bg_error except SuspendExecution: # User code suspended - stop background checkpointing thread logger.debug("Suspending execution...") - return self._handle_execution_output(status=InvocationStatus.PENDING) + return DurableExecutionInvocationOutput( + status=InvocationStatus.PENDING + ).to_dict() except CheckpointError as e: # Checkpoint system is broken - stop background thread and exit immediately @@ -397,121 +399,71 @@ def execute(self): "Checkpoint system failed", extra=e.build_logger_extras(), ) - # Terminate Lambda invocation immediately and have it be retried if retryable - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() - ) + return handle_checkpoint_error(e).to_dict() except InvocationError as e: - if e.is_retryable(): - logger.exception("Invocation error. Must terminate.") - else: - # Non-retryable Durable API errors (e.g., customer configuration issues, - # 4xx client errors) will never succeed on retry — fail the execution immediately. + # Non-retryable Durable API errors (e.g., customer configuration issues, + # 4xx client errors) will never succeed on retry — fail the execution immediately. + if not e.is_retryable(): logger.exception( "Non-retryable Durable API error. Must fail execution without retry.", extra=e.build_logger_extras(), # type: ignore[attr-defined] ) - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() - ) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() + logger.exception("Invocation error. Must terminate.") + # Throw the error to trigger Lambda retry + raise except ExecutionError as e: logger.exception("Execution error. Must fail execution without retry.") - return self._handle_execution_output(exception=e) + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, + error=ErrorObject.from_exception(e), + ).to_dict() except Exception as e: # all user-space errors go here logger.exception("Execution failed") - try: - error = self._handle_large_error(execution_state, exception=e) - except CheckpointError as e: - # Terminate Lambda invocation immediately and have it be retried if retryable - return self._handle_execution_output( - exception=e, retryable=e.is_retryable() - ) - - # fail without an ErrorObject - return self._handle_execution_output( - status=InvocationStatus.FAILED, error=error - ) + result = DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=ErrorObject.from_exception(e) + ).to_dict() - @staticmethod - def _handle_large_result(execution_state: ExecutionState, result: Any) -> str: - # large response handling here. Remember if checkpointing to complete, NOT to include - # payload in response - serialized_result = json.dumps(result) - if serialized_result and len(serialized_result) > LAMBDA_RESPONSE_SIZE_LIMIT: - logger.debug( - "Response size (%s bytes) exceeds Lambda limit (%s) bytes). Checkpointing result.", - len(serialized_result), - LAMBDA_RESPONSE_SIZE_LIMIT, - ) - success_operation = OperationUpdate.create_execution_succeed( - payload=serialized_result - ) - # Checkpoint large result with blocking (is_sync=True, default). - # Must ensure the result is persisted before returning to Lambda. - # Large results exceed Lambda response limits and must be stored durably - # before the execution completes. - execution_state.create_checkpoint(success_operation, is_sync=True) - return "" + serialized_result = json.dumps(result) - return serialized_result - - @staticmethod - def _handle_large_error( - execution_state: ExecutionState, exception: Exception - ) -> ErrorObject | None: - # large response handling here. Remember if checkpointing to complete, NOT to include - # payload in response - error = ErrorObject.from_exception(exception) - serialized_error = json.dumps(error.to_dict()) - if serialized_error and len(serialized_error) > LAMBDA_RESPONSE_SIZE_LIMIT: - logger.debug( - "Response size (%s bytes) exceeds Lambda limit (%s) bytes). Checkpointing result.", - len(serialized_error), - LAMBDA_RESPONSE_SIZE_LIMIT, - ) - failed_operation = OperationUpdate.create_execution_fail(error=error) - # Checkpoint large result with blocking (is_sync=True, default). - # Must ensure the result is persisted before returning to Lambda. - # Large results exceed Lambda response limits and must be stored durably - # before the execution completes. - execution_state.create_checkpoint_sync(failed_operation) - - # return fail without an ErrorObject - return None - - return error - - def _handle_execution_output( - self, - result: str | None = None, - error: ErrorObject | None = None, - exception: Exception | None = None, - retryable: bool = False, - status: InvocationStatus | None = None, - ) -> MutableMapping[str, Any]: - if exception: - if retryable: - # Throw the error to trigger Lambda retry - raise exception - else: - return self._handle_execution_output( - result=result, - error=ErrorObject.from_exception(exception), - status=status, - ) + if ( + serialized_result + and len(serialized_result) > LAMBDA_RESPONSE_SIZE_LIMIT + ): + logger.debug( + "Response size (%s bytes) exceeds Lambda limit (%s) bytes). Checkpointing result.", + len(serialized_result), + LAMBDA_RESPONSE_SIZE_LIMIT, + ) + failed_operation = OperationUpdate.create_execution_fail( + error=ErrorObject.from_exception(e) + ) - if error: - output = DurableExecutionInvocationOutput( - status=InvocationStatus.FAILED, result=result, error=error - ) - elif result is not None: - output = DurableExecutionInvocationOutput( - status=InvocationStatus.SUCCEEDED, result=result - ) - elif status: - output = DurableExecutionInvocationOutput(status=status) - else: - raise ValueError("Unexpected durable execution output") - return output.to_dict() + # Checkpoint large result with blocking (is_sync=True, default). + # Must ensure the result is persisted before returning to Lambda. + # Large results exceed Lambda response limits and must be stored durably + # before the execution completes. + try: + execution_state.create_checkpoint_sync(failed_operation) + except CheckpointError as e: + return handle_checkpoint_error(e).to_dict() + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED + ).to_dict() + + return result + + return wrapper + + +def handle_checkpoint_error(error: CheckpointError) -> DurableExecutionInvocationOutput: + if error.is_retryable(): + raise error from None # Terminate Lambda immediately and have it be retried + return DurableExecutionInvocationOutput( + status=InvocationStatus.FAILED, error=ErrorObject.from_exception(error) + )