diff --git a/.gitignore b/.gitignore index 7d2b20a..bfc52e9 100644 --- a/.gitignore +++ b/.gitignore @@ -31,4 +31,6 @@ dist/ .kiro/ /examples/build/* -/examples/*.zip \ No newline at end of file +/examples/*.zip + +.env \ No newline at end of file diff --git a/examples/examples-catalog.json b/examples/examples-catalog.json index e80e3ba..fb9ab78 100644 --- a/examples/examples-catalog.json +++ b/examples/examples-catalog.json @@ -580,6 +580,28 @@ "ApplicationLogLevel": "DEBUG", "LogFormat": "JSON" } - } + }, + { + "name": "Map with Item Namer", + "description": "Map operation with custom item_namer for iteration naming", + "handler": "map_with_item_namer.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/map/map_with_item_namer.py" + }, + { + "name": "Parallel with Named Branches", + "description": "Parallel operation with named branches using ParallelBranch", + "handler": "parallel_with_named_branches.handler", + "integration": true, + "durableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + }, + "path": "./src/parallel/parallel_with_named_branches.py" + } ] } diff --git a/examples/src/map/map_with_item_namer.py b/examples/src/map/map_with_item_namer.py new file mode 100644 index 0000000..331faab --- /dev/null +++ b/examples/src/map/map_with_item_namer.py @@ -0,0 +1,30 @@ +"""Example demonstrating map operations with custom iteration naming.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import MapConfig +from aws_durable_execution_sdk_python.context import DurableContext +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> list[str]: + """Process orders using context.map() with custom iteration names.""" + orders = [ + {"id": "order-101", "amount": 25}, + {"id": "order-102", "amount": 50}, + {"id": "order-103", "amount": 75}, + ] + + return context.map( + inputs=orders, + func=lambda ctx, order, index, _: ctx.step( + lambda _: f"processed-{order['id']}-${order['amount']}", + name=f"process_{order['id']}", + ), + name="process_orders", + config=MapConfig( + max_concurrency=2, + item_namer=lambda order, index: f"order-{order['id']}", + ), + ).get_results() diff --git a/examples/src/parallel/parallel_with_named_branches.py b/examples/src/parallel/parallel_with_named_branches.py new file mode 100644 index 0000000..ce43358 --- /dev/null +++ b/examples/src/parallel/parallel_with_named_branches.py @@ -0,0 +1,51 @@ +"""Example demonstrating all parallel branch patterns.""" + +from typing import Any + +from aws_durable_execution_sdk_python.config import ParallelBranch, ParallelConfig +from aws_durable_execution_sdk_python.context import ( + DurableContext, + durable_parallel_branch, +) +from aws_durable_execution_sdk_python.execution import durable_execution + + +@durable_parallel_branch(name="fetch-orders") +def fetch_orders(ctx: DurableContext) -> str: + return ctx.step(lambda _: "orders-loaded", name="load_orders") + + +@durable_parallel_branch() +def fetch_preferences(ctx: DurableContext) -> str: + return ctx.step(lambda _: "prefs-loaded", name="load_prefs") + + +@durable_execution +def handler(_event: Any, context: DurableContext) -> list[str]: + """Execute parallel branches using all supported patterns.""" + + return context.parallel( + functions=[ + # 1. Named parallel branch with ParallelBranch + ParallelBranch( + func=lambda ctx: ctx.step( + lambda _: "user-data-loaded", name="load_user" + ), + name="fetch-user-data", + ), + # 2. Named parallel branch with decorator + fetch_orders(), + # 3. Unnamed parallel branch with decorator + fetch_preferences(), + # 4. Unnamed parallel branch with ParallelBranch + ParallelBranch( + func=lambda ctx: ctx.step( + lambda _: "metrics-loaded", name="load_metrics" + ), + ), + # 5. No wrapper, just a raw callable + lambda ctx: ctx.step(lambda _: "config-loaded", name="load_config"), + ], + name="load_all_data", + config=ParallelConfig(max_concurrency=3), + ).get_results() diff --git a/examples/template.yaml b/examples/template.yaml index 0a9dcb9..2854e72 100644 --- a/examples/template.yaml +++ b/examples/template.yaml @@ -941,6 +941,42 @@ "ExecutionTimeout": 300 } } + }, + "MapWithItemNamer": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "map_with_item_namer.handler", + "Description": "Map operation with custom item_namer for iteration naming", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } + }, + "ParallelWithNamedBranches": { + "Type": "AWS::Serverless::Function", + "Properties": { + "CodeUri": "build/", + "Handler": "parallel_with_named_branches.handler", + "Description": "Parallel operation with named branches using ParallelBranch", + "Role": { + "Fn::GetAtt": [ + "DurableFunctionRole", + "Arn" + ] + }, + "DurableConfig": { + "RetentionPeriodInDays": 7, + "ExecutionTimeout": 300 + } + } } } } \ No newline at end of file diff --git a/examples/test/map/test_map_with_item_namer.py b/examples/test/map/test_map_with_item_namer.py new file mode 100644 index 0000000..11997d8 --- /dev/null +++ b/examples/test/map/test_map_with_item_namer.py @@ -0,0 +1,39 @@ +"""Tests for map_with_item_namer example.""" + +import pytest +from src.map import map_with_item_namer +from test.conftest import deserialize_operation_payload + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, +) + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=map_with_item_namer.handler, + lambda_function_name="map with item namer", +) +def test_map_with_item_namer(durable_runner): + """Test map example with custom item_namer for iteration naming.""" + with durable_runner: + result = durable_runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == [ + "processed-order-101-$25", + "processed-order-102-$50", + "processed-order-103-$75", + ] + + # Get the map operation + map_op = result.get_context("process_orders") + assert map_op is not None + assert map_op.status is OperationStatus.SUCCEEDED + + # Verify custom iteration names from item_namer + assert len(map_op.child_operations) == 3 + child_names = {op.name for op in map_op.child_operations} + expected_names = {"order-order-101", "order-order-102", "order-order-103"} + assert child_names == expected_names diff --git a/examples/test/parallel/test_parallel_with_named_branches.py b/examples/test/parallel/test_parallel_with_named_branches.py new file mode 100644 index 0000000..6b801b8 --- /dev/null +++ b/examples/test/parallel/test_parallel_with_named_branches.py @@ -0,0 +1,57 @@ +"""Tests for parallel_with_named_branches example.""" + +import pytest +from src.parallel import parallel_with_named_branches +from test.conftest import deserialize_operation_payload + +from aws_durable_execution_sdk_python.execution import InvocationStatus +from aws_durable_execution_sdk_python.lambda_service import ( + OperationStatus, + OperationType, +) + + +@pytest.mark.example +@pytest.mark.durable_execution( + handler=parallel_with_named_branches.handler, + lambda_function_name="parallel with named branches", +) +def test_parallel_with_named_branches(durable_runner): + """Test parallel example with all branch patterns.""" + with durable_runner: + result = durable_runner.run(input="test", timeout=10) + + assert result.status is InvocationStatus.SUCCEEDED + assert deserialize_operation_payload(result.result) == [ + "user-data-loaded", + "orders-loaded", + "prefs-loaded", + "metrics-loaded", + "config-loaded", + ] + + # Get the parallel operation + parallel_op = result.get_context("load_all_data") + assert parallel_op is not None + assert parallel_op.status is OperationStatus.SUCCEEDED + + # Verify branch names: named branches have custom names, unnamed use defaults + assert len(parallel_op.child_operations) == 5 + + child_names = [op.name for op in parallel_op.child_operations] + + # 1. Named ParallelBranch + assert child_names[0] == "fetch-user-data" + # 2. Named decorator + assert child_names[1] == "fetch-orders" + # 3. Unnamed decorator (None name falls back to index-based default) + assert child_names[2] == "parallel-branch-2" + # 4. Unnamed ParallelBranch (None name falls back to index-based default) + assert child_names[3] == "parallel-branch-3" + # 5. Raw callable (no ParallelBranch wrapper, index-based default) + assert child_names[4] == "parallel-branch-4" + + # Verify all children succeeded + for child in parallel_op.child_operations: + assert child.operation_type == OperationType.CONTEXT + assert child.status is OperationStatus.SUCCEEDED diff --git a/src/aws_durable_execution_sdk_python/__init__.py b/src/aws_durable_execution_sdk_python/__init__.py index 23a85cd..d82abc9 100644 --- a/src/aws_durable_execution_sdk_python/__init__.py +++ b/src/aws_durable_execution_sdk_python/__init__.py @@ -7,8 +7,10 @@ # Helper decorators - commonly used for step functions # Concurrency from aws_durable_execution_sdk_python.concurrency.models import BatchResult +from aws_durable_execution_sdk_python.config import ParallelBranch from aws_durable_execution_sdk_python.context import ( DurableContext, + durable_parallel_branch, durable_step, durable_wait_for_callback, durable_with_child_context, @@ -27,15 +29,18 @@ # Essential context types - passed to user functions from aws_durable_execution_sdk_python.types import StepContext + __all__ = [ "BatchResult", "DurableContext", "DurableExecutionsError", "InvocationError", + "ParallelBranch", "StepContext", "ValidationError", "__version__", "durable_execution", + "durable_parallel_branch", "durable_step", "durable_wait_for_callback", "durable_with_child_context", diff --git a/src/aws_durable_execution_sdk_python/concurrency/executor.py b/src/aws_durable_execution_sdk_python/concurrency/executor.py index 24c7657..3a7ab13 100644 --- a/src/aws_durable_execution_sdk_python/concurrency/executor.py +++ b/src/aws_durable_execution_sdk_python/concurrency/executor.py @@ -194,6 +194,14 @@ def execute_item( """Execute a single executable in a child context and return the result.""" raise NotImplementedError + def get_iteration_name(self, index: int) -> str: + """Get the display name for an iteration/branch at the given index. + + Subclasses can override this to provide custom naming (e.g., from item_namer + or branch names). The default returns "{name_prefix}{index}". + """ + return f"{self.name_prefix}{index}" + def execute( self, execution_state: ExecutionState, executor_context: DurableContext ) -> BatchResult[ResultType]: @@ -410,7 +418,7 @@ def _execute_item_in_child_context( operation_id: str = executor_context._create_step_id_for_logical_step( # noqa: SLF001 executable.index ) - name: str = f"{self.name_prefix}{executable.index}" + name: str = self.get_iteration_name(executable.index) is_virtual: bool = self.nesting_type is NestingType.FLAT child_context: DurableContext = executor_context.create_child_context( diff --git a/src/aws_durable_execution_sdk_python/config.py b/src/aws_durable_execution_sdk_python/config.py index 980786d..e8c0eb4 100644 --- a/src/aws_durable_execution_sdk_python/config.py +++ b/src/aws_durable_execution_sdk_python/config.py @@ -9,6 +9,7 @@ from aws_durable_execution_sdk_python.exceptions import ValidationError + P = TypeVar("P") # Payload type R = TypeVar("R") # Result type T = TypeVar("T") @@ -245,6 +246,41 @@ class ParallelConfig: nesting_type: NestingType = NestingType.NESTED +@dataclass(frozen=True) +class ParallelBranch(Generic[T]): + """A named branch for parallel execution. + + Use this to provide custom names for parallel branches, improving + observability in execution history. + + Type Parameters: + T: The return type of the branch function. + + Args: + func: The callable to execute in this branch. Receives a DurableContext. + name: Optional custom name for this branch. When provided, replaces + the default "parallel-branch-{index}" naming in execution history. + This affects observability but not replay determinism. + + Example: + context.parallel( + functions=[ + ParallelBranch(func=lambda ctx: fetch_user(ctx), name="fetch-user-data"), + ParallelBranch(func=lambda ctx: fetch_orders(ctx), name="fetch-order-history"), + ], + name="load-data", + config=ParallelConfig(max_concurrency=2), + ) + """ + + func: Callable + name: str | None = None + + def __call__(self, *args, **kwargs): + """Delegate to the wrapped function, making ParallelBranch itself callable.""" + return self.func(*args, **kwargs) + + class StepSemantics(Enum): AT_MOST_ONCE_PER_RETRY = "AT_MOST_ONCE_PER_RETRY" AT_LEAST_ONCE_PER_RETRY = "AT_LEAST_ONCE_PER_RETRY" @@ -354,12 +390,15 @@ class ItemBatcher(Generic[T]): @dataclass(frozen=True) -class MapConfig: +class MapConfig(Generic[T]): """Configuration options for map operations over collections. This class configures how map operations process collections of items, including concurrency, batching, completion criteria, and serialization. + Type Parameters: + T: The type of items being processed in the map operation. + Args: max_concurrency: Maximum number of items to process concurrently. If None, no limit is imposed and all items are processed concurrently. @@ -402,6 +441,12 @@ class MapConfig: - NESTED: Each item runs in its own isolated context (default) - FLAT: All items share the same parent context + item_namer: Optional callable to generate custom names for each map iteration. + When provided, replaces the default "map-item-{index}" naming scheme. + Receives the item and its index, and returns a string name for that iteration. + This affects observability (execution history names) but not replay determinism. + If None, uses the default naming: "map-item-{index}". + Example: # Process 5 items at a time, batch by count, require all to succeed config = MapConfig( @@ -409,6 +454,12 @@ class MapConfig: item_batcher=ItemBatcher(max_items_per_batch=10), completion_config=CompletionConfig.all_successful() ) + + # With custom iteration names + config = MapConfig( + max_concurrency=5, + item_namer=lambda item, index: f"process-order-{item.id}" + ) """ max_concurrency: int | None = None @@ -418,6 +469,7 @@ class MapConfig: item_serdes: SerDes | None = None summary_generator: SummaryGenerator | None = None nesting_type: NestingType = NestingType.NESTED + item_namer: Callable[[T, int], str] | None = None @dataclass(frozen=True) diff --git a/src/aws_durable_execution_sdk_python/context.py b/src/aws_durable_execution_sdk_python/context.py index bfded98..6691f2a 100644 --- a/src/aws_durable_execution_sdk_python/context.py +++ b/src/aws_durable_execution_sdk_python/context.py @@ -12,6 +12,7 @@ Duration, InvokeConfig, MapConfig, + ParallelBranch, ParallelConfig, StepConfig, WaitForCallbackConfig, @@ -55,6 +56,7 @@ WaitForConditionCheckContext, ) + if TYPE_CHECKING: from collections.abc import Callable, Sequence @@ -119,6 +121,52 @@ def function_with_arguments(child_context: DurableContext): return wrapper +def durable_parallel_branch( + name: str | None = None, +) -> Callable[ + [Callable[Concatenate[DurableContext, Params], T]], + Callable[Params, ParallelBranch[T]], +]: + """Wrap your callable into a named ParallelBranch for use with context.parallel(). + + This is a decorator factory β€” call it with an optional name to produce + the actual decorator. + + Args: + name: Optional custom name for this branch. When provided, replaces + the default "parallel-branch-{index}" naming in execution history. + If None, the function's __name__ is used. + + Example: + @durable_parallel_branch(name="fetch-user-data") + def fetch_user(ctx: DurableContext, user_id: str) -> dict: + return ctx.step(lambda _: {"id": user_id, "name": "Jane"}, name="load_user") + + @durable_parallel_branch(name="fetch-orders") + def fetch_orders(ctx: DurableContext, user_id: str) -> list: + return ctx.step(lambda _: ["order1", "order2"], name="load_orders") + + # Usage in a durable handler: + results = context.parallel( + functions=[fetch_user(user_id), fetch_orders(user_id)], + name="load-data", + ) + """ + + def decorator( + func: Callable[Concatenate[DurableContext, Params], T], + ) -> Callable[Params, ParallelBranch[T]]: + def wrapper(*args, **kwargs) -> ParallelBranch[T]: + def function_with_arguments(ctx: DurableContext) -> T: + return func(ctx, *args, **kwargs) + + return ParallelBranch(func=function_with_arguments, name=name) + + return wrapper + + return decorator + + def durable_wait_for_callback( func: Callable[Concatenate[str, WaitForCallbackContext, Params], T], ) -> Callable[Params, Callable[[str, WaitForCallbackContext], T]]: @@ -496,7 +544,7 @@ def map_in_child_context() -> BatchResult[R]: def parallel( self, - functions: Sequence[Callable[[DurableContext], T]], + functions: Sequence[Callable[[DurableContext], T] | ParallelBranch[T]], name: str | None = None, config: ParallelConfig | None = None, ) -> BatchResult[T]: diff --git a/src/aws_durable_execution_sdk_python/operation/map.py b/src/aws_durable_execution_sdk_python/operation/map.py index 8e9fb6a..f201efc 100644 --- a/src/aws_durable_execution_sdk_python/operation/map.py +++ b/src/aws_durable_execution_sdk_python/operation/map.py @@ -15,6 +15,7 @@ from aws_durable_execution_sdk_python.config import MapConfig, NestingType from aws_durable_execution_sdk_python.lambda_service import OperationSubType + if TYPE_CHECKING: from aws_durable_execution_sdk_python.context import DurableContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -47,6 +48,7 @@ def __init__( summary_generator: SummaryGenerator | None = None, item_serdes: SerDes | None = None, nesting_type: NestingType = NestingType.NESTED, + item_namer: Callable[[T, int], str] | None = None, ): super().__init__( executables=executables, @@ -61,13 +63,14 @@ def __init__( nesting_type=nesting_type, ) self.items = items + self._item_namer = item_namer @classmethod def from_items( cls, items: Sequence[T], func: Callable, - config: MapConfig, + config: MapConfig[T], ) -> MapExecutor[T, R]: """Create MapExecutor from items and a callable.""" executables: list[Executable[Callable]] = [ @@ -86,8 +89,15 @@ def from_items( summary_generator=config.summary_generator, item_serdes=config.item_serdes, nesting_type=config.nesting_type, + item_namer=config.item_namer, ) + def get_iteration_name(self, index: int) -> str: + """Return custom item name if item_namer is provided, otherwise default.""" + if self._item_namer is not None: + return self._item_namer(self.items[index], index) + return super().get_iteration_name(index) + def execute_item(self, child_context, executable: Executable[Callable]) -> R: logger.debug("πŸ—ΊοΈ Processing map item: %s", executable.index) item = self.items[executable.index] diff --git a/src/aws_durable_execution_sdk_python/operation/parallel.py b/src/aws_durable_execution_sdk_python/operation/parallel.py index 4d7094a..76fc16f 100644 --- a/src/aws_durable_execution_sdk_python/operation/parallel.py +++ b/src/aws_durable_execution_sdk_python/operation/parallel.py @@ -9,9 +9,14 @@ from aws_durable_execution_sdk_python.concurrency.executor import ConcurrentExecutor from aws_durable_execution_sdk_python.concurrency.models import Executable -from aws_durable_execution_sdk_python.config import ParallelConfig, NestingType +from aws_durable_execution_sdk_python.config import ( + NestingType, + ParallelBranch, + ParallelConfig, +) from aws_durable_execution_sdk_python.lambda_service import OperationSubType + if TYPE_CHECKING: from aws_durable_execution_sdk_python.concurrency.models import BatchResult from aws_durable_execution_sdk_python.context import DurableContext @@ -56,13 +61,19 @@ def __init__( @classmethod def from_callables( cls, - callables: Sequence[Callable], + callables: Sequence[Callable | ParallelBranch], config: ParallelConfig, ) -> ParallelExecutor: - """Create ParallelExecutor from a sequence of callables.""" + """Create ParallelExecutor from a sequence of callables or ParallelBranch instances. + + Since ParallelBranch is callable, it is stored directly as the func in + each Executable. The get_iteration_name method inspects the func to + extract the branch name when available. + """ executables: list[Executable[Callable]] = [ Executable(index=i, func=func) for i, func in enumerate(callables) ] + return cls( executables=executables, max_concurrency=config.max_concurrency, @@ -76,6 +87,13 @@ def from_callables( nesting_type=config.nesting_type, ) + def get_iteration_name(self, index: int) -> str: + """Return custom branch name if the callable is a ParallelBranch with a name.""" + func = self.executables[index].func + if isinstance(func, ParallelBranch) and func.name is not None: + return func.name + return super().get_iteration_name(index) + def execute_item(self, child_context, executable: Executable[Callable]) -> R: # noqa: PLR6301 logger.debug("πŸ”€ Processing parallel branch: %s", executable.index) result: R = executable.func(child_context) @@ -84,7 +102,7 @@ def execute_item(self, child_context, executable: Executable[Callable]) -> R: # def parallel_handler( - callables: Sequence[Callable], + callables: Sequence[Callable | ParallelBranch], config: ParallelConfig | None, execution_state: ExecutionState, parallel_context: DurableContext, diff --git a/src/aws_durable_execution_sdk_python/types.py b/src/aws_durable_execution_sdk_python/types.py index 9181be9..90080b0 100644 --- a/src/aws_durable_execution_sdk_python/types.py +++ b/src/aws_durable_execution_sdk_python/types.py @@ -6,6 +6,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Generic, Protocol, TypeVar + if TYPE_CHECKING: from collections.abc import Callable, Mapping, Sequence @@ -15,6 +16,7 @@ ChildConfig, Duration, MapConfig, + ParallelBranch, ParallelConfig, StepConfig, ) @@ -124,7 +126,7 @@ def map( @abstractmethod def parallel( self, - functions: Sequence[Callable[[DurableContext], T]], + functions: Sequence[Callable[[DurableContext], T] | ParallelBranch[T]], name: str | None = None, config: ParallelConfig | None = None, ) -> BatchResult[T]: diff --git a/tests/context_test.py b/tests/context_test.py index af32a3e..0e2cf0e 100644 --- a/tests/context_test.py +++ b/tests/context_test.py @@ -14,6 +14,7 @@ Duration, InvokeConfig, MapConfig, + ParallelBranch, ParallelConfig, StepConfig, ) @@ -21,6 +22,7 @@ Callback, DurableContext, ExecutionContext, + durable_parallel_branch, ) from aws_durable_execution_sdk_python.exceptions import ( CallbackError, @@ -2160,3 +2162,116 @@ def test_should_propagate_outer_parent_id_when_virtual_is_nested_in_virtual(): # endregion Virtual-context identity tests + + +# region durable_parallel_branch + + +def test_durable_parallel_branch_returns_parallel_branch_with_name(): + """Test that the decorator produces a ParallelBranch with the given name.""" + + @durable_parallel_branch(name="fetch-user-data") + def fetch_user(ctx: DurableContext, user_id: str) -> dict: + return {"id": user_id} + + result = fetch_user("user-123") + + assert isinstance(result, ParallelBranch) + assert result.name == "fetch-user-data" + + +def test_durable_parallel_branch_with_no_name(): + """Test that when name is None, ParallelBranch.name is None.""" + + @durable_parallel_branch() + def fetch_orders(ctx: DurableContext) -> list: + return ["order1"] + + result = fetch_orders() + + assert isinstance(result, ParallelBranch) + assert result.name is None + + +def test_durable_parallel_branch_callable_delegates_to_func(): + """Test that calling the ParallelBranch delegates to the wrapped function.""" + + @durable_parallel_branch(name="my-branch") + def my_branch(ctx: DurableContext, value: int) -> int: + return value * 2 + + branch = my_branch(21) + mock_ctx = Mock(spec=DurableContext) + + result = branch(mock_ctx) + + assert result == 42 + + +def test_durable_parallel_branch_with_multiple_args_and_kwargs(): + """Test that positional and keyword arguments are correctly bound.""" + + @durable_parallel_branch(name="compute") + def compute(ctx: DurableContext, a: int, b: int, op: str = "add") -> str: + if op == "add": + return f"{a + b}" + return f"{a * b}" + + branch = compute(3, 4, op="mul") + mock_ctx = Mock(spec=DurableContext) + + result = branch(mock_ctx) + + assert result == "12" + + +def test_durable_parallel_branch_passes_context_as_first_arg(): + """Test that the DurableContext is passed as the first argument to the function.""" + received_ctx = None + + @durable_parallel_branch(name="capture-ctx") + def capture(ctx: DurableContext) -> str: + nonlocal received_ctx + received_ctx = ctx + return "done" + + branch = capture() + mock_ctx = Mock(spec=DurableContext) + branch(mock_ctx) + + assert received_ctx is mock_ctx + + +def test_durable_parallel_branch_multiple_invocations_are_independent(): + """Test that calling the wrapper multiple times produces independent branches.""" + + @durable_parallel_branch(name="greet") + def greet(ctx: DurableContext, name: str) -> str: + return f"hello {name}" + + branch_a = greet("Alice") + branch_b = greet("Bob") + + mock_ctx = Mock(spec=DurableContext) + + assert branch_a(mock_ctx) == "hello Alice" + assert branch_b(mock_ctx) == "hello Bob" + + +def test_durable_parallel_branch_is_compatible_with_parallel_functions_arg(): + """Test that the result can be used in a functions list alongside plain callables.""" + + @durable_parallel_branch(name="named-branch") + def named(ctx: DurableContext) -> str: + return "named" + + plain = lambda ctx: "plain" # noqa: E731 + + functions = [named(), plain] + + assert isinstance(functions[0], ParallelBranch) + assert callable(functions[0]) + assert callable(functions[1]) + + +# endregion durable_parallel_branch diff --git a/tests/operation/map_test.py b/tests/operation/map_test.py index b3c979d..c7a653f 100644 --- a/tests/operation/map_test.py +++ b/tests/operation/map_test.py @@ -1178,3 +1178,122 @@ def map_func(ctx, item, idx, items): assert result.total_count == 0 assert result.success_count == 0 assert result.failure_count == 0 + + +# region item_namer tests + + +def test_map_executor_get_iteration_name_default(): + """Without item_namer, iterations use default 'map-item-{index}' naming.""" + items = ["a", "b", "c"] + config = MapConfig(max_concurrency=2) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "map-item-0" + assert executor.get_iteration_name(1) == "map-item-1" + assert executor.get_iteration_name(2) == "map-item-2" + + +def test_map_executor_get_iteration_name_with_item_namer(): + """With item_namer, iterations use custom names.""" + items = [{"id": "order-1"}, {"id": "order-2"}, {"id": "order-3"}] + config = MapConfig( + max_concurrency=2, + item_namer=lambda item, index: f"process-{item['id']}", + ) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "process-order-1" + assert executor.get_iteration_name(1) == "process-order-2" + assert executor.get_iteration_name(2) == "process-order-3" + + +def test_map_executor_item_namer_receives_item_and_index(): + """item_namer receives both the item and its index.""" + items = ["alpha", "beta", "gamma"] + received_args: list[tuple] = [] + + def namer(item, index): + received_args.append((item, index)) + return f"item-{index}-{item}" + + config = MapConfig(item_namer=namer) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + executor.get_iteration_name(0) + executor.get_iteration_name(2) + + assert received_args == [("alpha", 0), ("gamma", 2)] + + +def test_map_executor_item_namer_uses_index(): + """item_namer can use the index to generate names.""" + items = [10, 20, 30] + config = MapConfig(item_namer=lambda item, index: f"step-{index + 1}") + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "step-1" + assert executor.get_iteration_name(1) == "step-2" + assert executor.get_iteration_name(2) == "step-3" + + +def test_map_executor_item_namer_none_falls_back_to_default(): + """Explicitly passing item_namer=None uses default naming.""" + items = ["x", "y"] + config = MapConfig(item_namer=None) + + executor = MapExecutor.from_items( + items=items, + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor.get_iteration_name(0) == "map-item-0" + assert executor.get_iteration_name(1) == "map-item-1" + + +def test_map_executor_from_items_passes_item_namer(): + """MapExecutor.from_items correctly passes item_namer from config.""" + namer = lambda item, index: f"custom-{index}" # noqa: E731 + config = MapConfig(item_namer=namer) + + executor = MapExecutor.from_items( + items=["a"], + func=lambda ctx, item, idx, items: item, + config=config, + ) + + assert executor._item_namer is namer + + +def test_map_config_generic_with_item_namer(): + """MapConfig can be parameterized with a type and use item_namer.""" + config: MapConfig[dict] = MapConfig( + item_namer=lambda item, index: f"item-{item['name']}", + ) + + assert config.item_namer is not None + assert config.item_namer({"name": "test"}, 0) == "item-test" + + +# endregion diff --git a/tests/operation/parallel_test.py b/tests/operation/parallel_test.py index 1922207..cf7c736 100644 --- a/tests/operation/parallel_test.py +++ b/tests/operation/parallel_test.py @@ -20,8 +20,8 @@ ) from aws_durable_execution_sdk_python.config import ( CompletionConfig, - ParallelConfig, NestingType, + ParallelConfig, ) from aws_durable_execution_sdk_python.context import DurableContext, ExecutionContext from aws_durable_execution_sdk_python.identifier import OperationIdentifier @@ -1118,3 +1118,123 @@ def create_id(self, i): assert parent_call[1]["serdes"] is custom_serdes assert isinstance(parent_call[1]["value"], BatchResult) assert parent_call[1]["value"] is result + + +# region ParallelBranch and branch naming tests + + +def test_parallel_branch_is_callable(): + """ParallelBranch instances are callable.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda x: x * 2, name="double") + assert callable(branch) + + +def test_parallel_branch_delegates_to_func(): + """Calling ParallelBranch delegates to the wrapped func.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda x, y: x + y, name="add") + assert branch(3, 4) == 7 + + +def test_parallel_branch_passes_kwargs(): + """ParallelBranch passes keyword arguments to func.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda ctx, flag=False: flag, name="test") + assert branch("ctx", flag=True) is True + + +def test_parallel_branch_frozen(): + """ParallelBranch is immutable (frozen dataclass).""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda: None, name="test") + with pytest.raises(AttributeError): + branch.name = "changed" # type: ignore[misc] + + +def test_parallel_executor_get_iteration_name_default(): + """Plain callables use default 'parallel-branch-{index}' naming.""" + callables = [lambda ctx: "a", lambda ctx: "b", lambda ctx: "c"] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(callables, config) + + assert executor.get_iteration_name(0) == "parallel-branch-0" + assert executor.get_iteration_name(1) == "parallel-branch-1" + assert executor.get_iteration_name(2) == "parallel-branch-2" + + +def test_parallel_executor_get_iteration_name_with_named_branches(): + """ParallelBranch with name uses the custom name.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branches = [ + ParallelBranch(func=lambda ctx: "user", name="fetch-user-data"), + ParallelBranch(func=lambda ctx: "orders", name="fetch-order-history"), + ] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(branches, config) + + assert executor.get_iteration_name(0) == "fetch-user-data" + assert executor.get_iteration_name(1) == "fetch-order-history" + + +def test_parallel_executor_get_iteration_name_mixed(): + """Mix of ParallelBranch (with/without name) and plain callables.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branches = [ + ParallelBranch(func=lambda ctx: "a", name="named-branch"), + lambda ctx: "b", + ParallelBranch(func=lambda ctx: "c"), + ] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(branches, config) + + assert executor.get_iteration_name(0) == "named-branch" + assert executor.get_iteration_name(1) == "parallel-branch-1" + assert executor.get_iteration_name(2) == "parallel-branch-2" + + +def test_parallel_executor_get_iteration_name_none_name(): + """ParallelBranch with name=None falls back to default naming.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branches = [ + ParallelBranch(func=lambda ctx: "x", name=None), + ] + config = ParallelConfig() + + executor = ParallelExecutor.from_callables(branches, config) + + assert executor.get_iteration_name(0) == "parallel-branch-0" + + +def test_parallel_branch_execute_item(): + """ParallelBranch works correctly in execute_item.""" + from aws_durable_execution_sdk_python.config import ParallelBranch + + branch = ParallelBranch(func=lambda ctx: f"result-{ctx}", name="my-branch") + executable = Executable(index=0, func=branch) + + executor = ParallelExecutor( + executables=[executable], + max_concurrency=None, + completion_config=CompletionConfig.all_successful(), + top_level_sub_type=OperationSubType.PARALLEL, + iteration_sub_type=OperationSubType.PARALLEL_BRANCH, + name_prefix="parallel-branch-", + serdes=None, + ) + + result = executor.execute_item("test-ctx", executable) + assert result == "result-test-ctx" + + +# endregion