diff --git a/python/packages/a2a/tests/test_a2a_agent.py b/python/packages/a2a/tests/test_a2a_agent.py index a0919cbda4..442960a7ee 100644 --- a/python/packages/a2a/tests/test_a2a_agent.py +++ b/python/packages/a2a/tests/test_a2a_agent.py @@ -1284,13 +1284,11 @@ async def test_streaming_artifact_update_event_does_not_duplicate_terminal_task_ final=True, ) - mock_a2a_client.responses.extend( - [ - (working_task, first_chunk), - (working_task, second_chunk), - (terminal_task, terminal_event), - ] - ) + mock_a2a_client.responses.extend([ + (working_task, first_chunk), + (working_task, second_chunk), + (terminal_task, terminal_event), + ]) stream = a2a_agent.run("Hello", stream=True) updates: list[AgentResponseUpdate] = [] @@ -1371,12 +1369,10 @@ async def test_streaming_terminal_task_only_emits_unstreamed_artifacts( final=True, ) - mock_a2a_client.responses.extend( - [ - (working_task, streamed_chunk), - (terminal_task, terminal_event), - ] - ) + mock_a2a_client.responses.extend([ + (working_task, streamed_chunk), + (terminal_task, terminal_event), + ]) stream = a2a_agent.run("Hello", stream=True) updates: list[AgentResponseUpdate] = [] diff --git a/python/packages/core/agent_framework/_workflows/_agent_executor.py b/python/packages/core/agent_framework/_workflows/_agent_executor.py index 986b086fab..2bcc6d355e 100644 --- a/python/packages/core/agent_framework/_workflows/_agent_executor.py +++ b/python/packages/core/agent_framework/_workflows/_agent_executor.py @@ -339,6 +339,13 @@ async def _run_agent(self, ctx: WorkflowContext[Never, AgentResponse]) -> AgentR ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) ) + if not self._cache: + logger.warning( + "AgentExecutor %s: Running agent with empty message cache. 
" + "This could lead to service error for some LLM providers.", + self.id, + ) + run_agent = cast(Callable[..., Awaitable[AgentResponse[Any]]], self._agent.run) response = await run_agent( self._cache, @@ -371,6 +378,13 @@ async def _run_agent_streaming(self, ctx: WorkflowContext[Never, AgentResponseUp ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}) ) + if not self._cache: + logger.warning( + "AgentExecutor %s: Running agent with empty message cache. " + "This could lead to service error for some LLM providers.", + self.id, + ) + updates: list[AgentResponseUpdate] = [] streamed_user_input_requests: list[Content] = [] run_agent_stream = cast(Callable[..., ResponseStream[AgentResponseUpdate, AgentResponse[Any]]], self._agent.run) diff --git a/python/packages/core/agent_framework/_workflows/_runner.py b/python/packages/core/agent_framework/_workflows/_runner.py index c548e76e53..d58d9b99a9 100644 --- a/python/packages/core/agent_framework/_workflows/_runner.py +++ b/python/packages/core/agent_framework/_workflows/_runner.py @@ -186,6 +186,12 @@ async def _deliver_message_inner(edge_runner: EdgeRunner, message: WorkflowMessa """Inner loop to deliver a single message through an edge runner.""" return await edge_runner.send_message(message, self._state, self._ctx) + async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None: + # Preserve message order per edge runner (and therefore per routed target path) + # while still allowing parallelism across different edge runners. 
+ for message in source_messages: + await _deliver_message_inner(edge_runner, message) + # Route all messages through normal workflow edges associated_edge_runners = self._edge_runner_map.get(source_executor_id, []) if not associated_edge_runners: @@ -193,12 +199,6 @@ async def _deliver_message_inner(edge_runner: EdgeRunner, message: WorkflowMessa logger.debug(f"No outgoing edges found for executor {source_executor_id}; dropping messages.") return - async def _deliver_messages_for_edge_runner(edge_runner: EdgeRunner) -> None: - # Preserve message order per edge runner (and therefore per routed target path) - # while still allowing parallelism across different edge runners. - for message in source_messages: - await _deliver_message_inner(edge_runner, message) - tasks = [_deliver_messages_for_edge_runner(edge_runner) for edge_runner in associated_edge_runners] await asyncio.gather(*tasks) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py index 44c651df5a..e71fdd3882 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_handoff.py @@ -38,17 +38,15 @@ from dataclasses import dataclass from typing import Any -from agent_framework import Agent, SupportsAgentRun +from agent_framework import Agent, AgentResponse, Message, SupportsAgentRun from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware, MiddlewareTermination from agent_framework._sessions import AgentSession from agent_framework._tools import FunctionTool, tool -from agent_framework._types import AgentResponse, Content, Message from agent_framework._workflows._agent_executor import AgentExecutor, AgentExecutorRequest from agent_framework._workflows._agent_utils import resolve_agent_id from agent_framework._workflows._checkpoint import CheckpointStorage from 
agent_framework._workflows._events import WorkflowEvent from agent_framework._workflows._request_info_mixin import response_handler -from agent_framework._workflows._typing_utils import is_chat_agent from agent_framework._workflows._workflow import Workflow from agent_framework._workflows._workflow_builder import WorkflowBuilder from agent_framework._workflows._workflow_context import WorkflowContext @@ -263,88 +261,6 @@ def _prepare_agent_with_handoffs( return cloned_agent - def _persist_pending_approval_function_calls(self) -> None: - """Persist pending approval function calls for stateless provider resumes. - - Handoff workflows force ``store=False`` and replay conversation state from ``_full_conversation``. - When a run pauses on function approval, ``AgentExecutor`` returns ``None`` and the assistant - function-call message is not returned as an ``AgentResponse``. Without persisting that call, the - next turn may submit only a function result, which responses-style APIs reject. - """ - pending_calls: list[Content] = [] - for request in self._pending_agent_requests.values(): - if request.type != "function_approval_request": - continue - function_call = getattr(request, "function_call", None) - if isinstance(function_call, Content) and function_call.type == "function_call": - pending_calls.append(function_call) - - if not pending_calls: - return - - self._full_conversation.append( - Message( - role="assistant", - contents=pending_calls, - author_name=self._agent.name, - ) - ) - - def _persist_missing_approved_function_results( - self, - *, - runtime_tool_messages: list[Message], - response_messages: list[Message], - ) -> None: - """Persist fallback function_result entries for approved calls when missing. - - In approval resumes, function invocation can execute approved tools without - always surfacing those tool outputs in the returned ``AgentResponse.messages``. - For stateless handoff replays, we must keep call/output pairs balanced. 
- """ - candidate_results: dict[str, Content] = {} - for message in runtime_tool_messages: - for content in message.contents: - if content.type == "function_result": - call_id = getattr(content, "call_id", None) - if isinstance(call_id, str) and call_id: - candidate_results[call_id] = content - continue - - if content.type != "function_approval_response" or not content.approved: - continue - - function_call = getattr(content, "function_call", None) - call_id = getattr(function_call, "call_id", None) or getattr(content, "id", None) - if isinstance(call_id, str) and call_id and call_id not in candidate_results: - # Fallback content for approved calls when runtime messages do not include - # a concrete function_result payload. - candidate_results[call_id] = Content.from_function_result( - call_id=call_id, - result='{"status":"approved"}', - ) - - if not candidate_results: - return - - observed_result_call_ids: set[str] = set() - for message in [*self._full_conversation, *response_messages]: - for content in message.contents: - if content.type == "function_result" and isinstance(content.call_id, str) and content.call_id: - observed_result_call_ids.add(content.call_id) - - missing_call_ids = sorted(set(candidate_results.keys()) - observed_result_call_ids) - if not missing_call_ids: - return - - self._full_conversation.append( - Message( - role="tool", - contents=[candidate_results[call_id] for call_id in missing_call_ids], - author_name=self._agent.name, - ) - ) - def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: """Produce a deep copy of the Agent while preserving runtime configuration.""" options = agent.default_options @@ -360,7 +276,6 @@ def _clone_chat_agent(self, agent: Agent[Any]) -> Agent[Any]: cloned_options = deepcopy(options) # Disable parallel tool calls to prevent the agent from invoking multiple handoff tools at once. 
cloned_options["allow_multiple_tool_calls"] = False - cloned_options["store"] = False cloned_options["tools"] = new_tools # restore the original tools, in case they are shared between agents @@ -426,45 +341,15 @@ def _handoff_tool() -> None: @override async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None: """Override to support handoff.""" - incoming_messages = list(self._cache) - cleaned_incoming_messages = clean_conversation_for_handoff(incoming_messages) - runtime_tool_messages = [ - message - for message in incoming_messages - if any( - content.type - in { - "function_result", - "function_approval_response", - } - for content in message.contents - ) - or message.role == "tool" - ] - # When the full conversation is empty, it means this is the first run. # Broadcast the initial cache to all other agents. Subsequent runs won't # need this since responses are broadcast after each agent run and user input. if self._is_start_agent and not self._full_conversation: - await self._broadcast_messages(cleaned_incoming_messages, ctx) - - # Persist only cleaned chat history between turns to avoid replaying stale tool calls. - self._full_conversation.extend(cleaned_incoming_messages) - - # Always run with full conversation context for request_info resumes. - # Keep runtime tool-control messages for this run only (e.g., approval responses). - self._cache = list(self._full_conversation) - self._cache.extend(runtime_tool_messages) - - # Handoff workflows are orchestrator-stateful and provider-stateless by design. - # If an existing session still has a service conversation id, clear it to avoid - # replaying stale unresolved tool calls across resumed turns. 
- if ( - is_chat_agent(self._agent) - and self._agent.default_options.get("store") is False - and self._session.service_session_id is not None - ): - self._session.service_session_id = None + await self._broadcast_messages(self._cache.copy(), ctx) + + # Full conversation maintains the chat history between agents across handoffs, + # excluding internal agent messages such as tool calls and results. + self._full_conversation.extend(self._cache.copy()) # Check termination condition before running the agent if await self._check_terminate_and_yield(ctx): @@ -483,36 +368,35 @@ async def _run_agent_and_emit(self, ctx: WorkflowContext[Any, Any]) -> None: # A function approval request is issued by the base AgentExecutor if response is None: - if is_chat_agent(self._agent) and self._agent.default_options.get("store") is False: - self._persist_pending_approval_function_calls() # Agent did not complete (e.g., waiting for user input); do not emit response logger.debug("AgentExecutor %s: Agent did not complete, awaiting user input", self.id) return - # Remove function call related content from the agent response for broadcast. - # This prevents replaying stale tool artifacts to other agents. + # Remove function call related content from the agent response for full conversation history cleaned_response = clean_conversation_for_handoff(response.messages) - # For internal tracking, preserve the full response (including function_calls) - # in _full_conversation so that Azure OpenAI can match function_calls with - # function_results when the workflow resumes after user approvals. - self._full_conversation.extend(response.messages) - self._persist_missing_approved_function_results( - runtime_tool_messages=runtime_tool_messages, - response_messages=response.messages, - ) + # Append the agent response to the full conversation history. This list removes + # function call related content such that the result stays consistent regardless + # of which agent yields the final output. 
+ self._full_conversation.extend(cleaned_response) # Broadcast only the cleaned response to other agents (without function_calls/results) await self._broadcast_messages(cleaned_response, ctx) # Check if a handoff was requested - if handoff_target := self._is_handoff_requested(response): + if is_handoff_requested := self._is_handoff_requested(response): + handoff_target, handoff_message = is_handoff_requested if handoff_target not in self._handoff_targets: raise ValueError( f"Agent '{resolve_agent_id(self._agent)}' attempted to handoff to unknown " f"target '{handoff_target}'. Valid targets are: {', '.join(self._handoff_targets)}" ) + # Add the handoff message to the cache so that the next invocation of the agent includes + # the tool call result. This is necessary because each tool call must have a corresponding + # tool result. + self._cache.append(handoff_message) + await ctx.send_message( AgentExecutorRequest(messages=[], should_respond=True), target_id=handoff_target, @@ -589,12 +473,25 @@ async def _broadcast_messages( # Since all agents are connected via fan-out, we can directly send the message await ctx.send_message(agent_executor_request) - def _is_handoff_requested(self, response: AgentResponse) -> str | None: + def _is_handoff_requested(self, response: AgentResponse) -> tuple[str, Message] | None: """Determine if the agent response includes a handoff request. If a handoff tool is invoked, the middleware will short-circuit execution and provide a synthetic result that includes the target agent ID. The message that contains the function result will be the last message in the response. + + Args: + response: The AgentResponse to inspect for handoff requests + + Returns: + A tuple of (target_agent_id, message) if a handoff is requested, or None if no handoff is requested + + Note: + The returned message is the full message that contains the handoff function result content. 
This is + needed to complete the agent's chat history due to the `_AutoHandoffMiddleware` short-circuiting + behavior, which prevents the handoff tool call and result from being included in the agent response + messages. By returning the full message, we can ensure the agent's chat history remains valid with + a function result for the handoff tool call. """ if not response.messages: return None @@ -617,7 +514,7 @@ def _is_handoff_requested(self, response: AgentResponse) -> str | None: if parsed_payload: handoff_target = parsed_payload.get(HANDOFF_FUNCTION_RESULT_KEY) if isinstance(handoff_target, str): - return handoff_target + return handoff_target, last_message else: continue @@ -1034,6 +931,25 @@ def build(self) -> Workflow: # Resolve agents (either from instances or factories) # The returned map keys are either executor IDs or factory names, which is need to resolve handoff configs resolved_agents = self._resolve_agents() + + # Validate that all agents have require_per_service_call_history_persistence enabled. + # Handoff workflows use middleware that short-circuits tool calls (MiddlewareTermination), + # which means the service never sees those tool results. Without per-service-call + # history persistence, local history providers would persist tool results that + # the service has no record of, causing call/result mismatches on subsequent turns. + agents_missing_flag = [ + resolve_agent_id(agent) + for agent in resolved_agents.values() + if not agent.require_per_service_call_history_persistence + ] + if agents_missing_flag: + raise ValueError( + f"Handoff workflows require all participant agents to have " + f"'require_per_service_call_history_persistence=True'. " + f"The following agents are missing this setting: {', '.join(agents_missing_flag)}. " + f"Set this flag when constructing each Agent to ensure local history stays " + f"consistent with the service across handoff tool-call short-circuits." 
+ ) # Resolve handoff configurations to use agent display names # The returned map keys are executor IDs resolved_handoffs = self._resolve_handoffs(resolved_agents) diff --git a/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py b/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py index ca8d5daee3..1ad4d01e82 100644 --- a/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py +++ b/python/packages/orchestrations/agent_framework_orchestrations/_orchestrator_helpers.py @@ -24,6 +24,11 @@ def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message] - Drops all non-text content from every message. - Drops messages with no remaining text content. - Preserves original roles and author names for retained text messages. + + Args: + conversation: Full conversation history, including tool-control content + Returns: + Cleaned conversation history with only text content, suitable for handoff routing """ cleaned: list[Message] = [] for msg in conversation: @@ -31,6 +36,8 @@ def clean_conversation_for_handoff(conversation: list[Message]) -> list[Message] # (function_call/function_result/approval payloads) is runtime-only and # must not be replayed in future model turns. text_parts = [content.text for content in msg.contents if content.type == "text" and content.text] + # TODO(@taochen): This is a simplified check that considers any non-text content as a tool call. + # We need to enhance this logic to specifically identify tool related contents. if not text_parts: continue diff --git a/python/packages/orchestrations/tests/test_handoff.py b/python/packages/orchestrations/tests/test_handoff.py index 0687708b20..225b408422 100644 --- a/python/packages/orchestrations/tests/test_handoff.py +++ b/python/packages/orchestrations/tests/test_handoff.py @@ -1,8 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. 
+import os import re from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence -from typing import Any, cast +from typing import Annotated, Any, cast from unittest.mock import AsyncMock, MagicMock import pytest @@ -16,13 +17,15 @@ Message, ResponseStream, WorkflowEvent, + WorkflowRunState, resolve_agent_id, tool, ) from agent_framework._clients import BaseChatClient from agent_framework._middleware import ChatMiddlewareLayer, FunctionInvocationContext, MiddlewareTermination from agent_framework._tools import FunctionInvocationLayer, FunctionTool -from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder +from agent_framework.orchestrations import HandoffAgentUserRequest, HandoffBuilder, HandoffSentEvent +from pytest import param from agent_framework_orchestrations._handoff import ( HANDOFF_FUNCTION_RESULT_KEY, @@ -34,6 +37,7 @@ from agent_framework_orchestrations._orchestrator_helpers import clean_conversation_for_handoff +# region unit tests class MockChatClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): """Mock chat client for testing handoff workflows.""" @@ -132,7 +136,12 @@ def __init__( handoff_to: The name of the agent to hand off to, or None for no handoff. This is hardcoded for testing purposes so that the agent always attempts to hand off. 
""" - super().__init__(client=MockChatClient(name=name, handoff_to=handoff_to), name=name, id=name) + super().__init__( + client=MockChatClient(name=name, handoff_to=handoff_to), + name=name, + id=name, + require_per_service_call_history_persistence=True, + ) class ContextAwareRefundClient(FunctionInvocationLayer[Any], ChatMiddlewareLayer[Any], BaseChatClient[Any]): @@ -255,6 +264,7 @@ async def test_resume_keeps_prior_user_context_for_same_agent() -> None: id="refund_agent", name="refund_agent", client=ContextAwareRefundClient(), + require_per_service_call_history_persistence=True, ) workflow = ( HandoffBuilder(participants=[refund_agent], termination_condition=lambda _: False) @@ -352,6 +362,7 @@ async def _get() -> ChatResponse: name="refund_agent", client=ApprovalReplayClient(), tools=[submit_refund_counted], + require_per_service_call_history_persistence=True, ) workflow = ( HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build() @@ -455,6 +466,7 @@ async def _get() -> ChatResponse: name="refund_agent", client=client, tools=[submit_refund], + require_per_service_call_history_persistence=True, ) workflow = ( HandoffBuilder(participants=[agent], termination_condition=lambda _: False).with_start_agent(agent).build() @@ -524,11 +536,13 @@ async def _get() -> ChatResponse: id="triage", name="triage", client=ReplaySafeHandoffClient(name="triage", handoff_sequence=["specialist", None]), + require_per_service_call_history_persistence=True, ) specialist = Agent( id="specialist", name="specialist", client=ReplaySafeHandoffClient(name="specialist", handoff_sequence=["triage"]), + require_per_service_call_history_persistence=True, ) workflow = ( @@ -652,11 +666,13 @@ async def _get() -> ChatResponse: name="refund_agent", client=refund_client, tools=[submit_refund], + require_per_service_call_history_persistence=True, ) order_agent = Agent( id="order_agent", name="order_agent", client=OrderReplayClient(), + 
require_per_service_call_history_persistence=True, ) workflow = ( HandoffBuilder(participants=[refund_agent, order_agent], termination_condition=lambda _: False) @@ -686,16 +702,6 @@ async def _get() -> ChatResponse: assert refund_client.resume_validated is True -def test_handoff_clone_disables_provider_side_storage() -> None: - """Handoff executors should force store=False to avoid stale provider call state.""" - triage = MockHandoffAgent(name="triage") - workflow = HandoffBuilder(participants=[triage]).with_start_agent(triage).build() - - executor = workflow.executors[resolve_agent_id(triage)] - assert isinstance(executor, HandoffAgentExecutor) - assert executor._agent.default_options.get("store") is False - - async def test_handoff_clone_preserves_per_service_call_history_persistence() -> None: """Handoff clones should keep per-service-call history persistence active for auto-handoff termination.""" triage_history = InMemoryHistoryProvider() @@ -711,6 +717,7 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() -> name="specialist", client=MockChatClient(name="specialist"), default_options={"tool_choice": "none"}, + require_per_service_call_history_persistence=True, ) workflow = ( @@ -738,21 +745,6 @@ async def test_handoff_clone_preserves_per_service_call_history_persistence() -> assert all(message.role != "tool" for message in stored_messages) -async def test_handoff_clears_stale_service_session_id_before_run() -> None: - """Stale service session IDs must be dropped before each handoff agent turn.""" - triage = MockHandoffAgent(name="triage", handoff_to="specialist") - specialist = MockHandoffAgent(name="specialist") - workflow = HandoffBuilder(participants=[triage, specialist]).with_start_agent(triage).build() - - triage_executor = workflow.executors[resolve_agent_id(triage)] - assert isinstance(triage_executor, HandoffAgentExecutor) - triage_executor._session.service_session_id = "resp_stale_value" - - await _drain(workflow.run("My 
order is damaged", stream=True)) - - assert triage_executor._session.service_session_id is None - - def test_clean_conversation_for_handoff_keeps_text_only_history() -> None: """Tool-control messages must be excluded from persisted handoff history.""" function_call = Content.from_function_call( @@ -791,52 +783,6 @@ def test_clean_conversation_for_handoff_keeps_text_only_history() -> None: ] -def test_persist_missing_approved_function_results_handles_runtime_and_fallback_outputs() -> None: - """Persisted history should retain approved call outputs across runtime shapes.""" - agent = MockHandoffAgent(name="triage") - executor = HandoffAgentExecutor(agent, handoffs=[]) - - call_with_runtime_result = "call-runtime-result" - call_with_approval_only = "call-approval-only" - - executor._full_conversation = [ - Message( - role="assistant", - contents=[ - Content.from_function_call(call_id=call_with_runtime_result, name="submit_refund", arguments={}), - Content.from_function_call(call_id=call_with_approval_only, name="submit_refund", arguments={}), - ], - ) - ] - - approval_response = Content.from_function_approval_response( - approved=True, - id=call_with_approval_only, - function_call=Content.from_function_call(call_id=call_with_approval_only, name="submit_refund", arguments={}), - ) - runtime_messages = [ - Message( - role="tool", - contents=[Content.from_function_result(call_id=call_with_runtime_result, result='{"submitted":true}')], - ), - Message(role="user", contents=[approval_response]), - ] - - executor._persist_missing_approved_function_results(runtime_tool_messages=runtime_messages, response_messages=[]) - - persisted_tool_messages = [message for message in executor._full_conversation if message.role == "tool"] - assert persisted_tool_messages - persisted_results = [ - content - for message in persisted_tool_messages - for content in message.contents - if content.type == "function_result" and content.call_id - ] - result_by_call_id = {content.call_id: 
content.result for content in persisted_results} - assert result_by_call_id[call_with_runtime_result] == '{"submitted":true}' - assert result_by_call_id[call_with_approval_only] == '{"status":"approved"}' - - async def test_autonomous_mode_yields_output_without_user_request(): """Ensure autonomous interaction mode yields output without requesting user input.""" triage = MockHandoffAgent(name="triage", handoff_to="specialist") @@ -979,7 +925,12 @@ async def _get() -> ChatResponse: return _get() - agent = Agent(id="order_agent", name="order_agent", client=FinalizingClient()) + agent = Agent( + id="order_agent", + name="order_agent", + client=FinalizingClient(), + require_per_service_call_history_persistence=True, + ) workflow = ( HandoffBuilder( participants=[agent], @@ -1061,6 +1012,7 @@ async def before_run(self, **kwargs: Any) -> None: name="test_agent", id="test_agent", context_providers=[context_provider], + require_per_service_call_history_persistence=True, ) # Verify the original agent has the context provider @@ -1104,8 +1056,8 @@ async def test_auto_handoff_middleware_intercepts_handoff_tool_call() -> None: middleware = _AutoHandoffMiddleware([HandoffConfiguration(target=target_id)]) @tool(name=get_handoff_tool_name(target_id), approval_mode="never_require") - def handoff_tool() -> str: - return "unreachable" + def handoff_tool() -> None: + pass context = FunctionInvocationContext(function=handoff_tool, arguments={}) call_next = AsyncMock() @@ -1136,6 +1088,20 @@ def regular_tool() -> str: assert context.result is None +def test_handoff_builder_rejects_agents_without_per_service_call_history_persistence() -> None: + """HandoffBuilder.build() should reject agents missing require_per_service_call_history_persistence.""" + agent_without_flag = Agent( + client=MockChatClient(name="no_flag"), + name="no_flag", + id="no_flag", + # require_per_service_call_history_persistence defaults to False + ) + agent_with_flag = MockHandoffAgent(name="has_flag") # 
MockHandoffAgent sets flag to True + + with pytest.raises(ValueError, match="require_per_service_call_history_persistence"): + HandoffBuilder(participants=[agent_without_flag, agent_with_flag]).with_start_agent(agent_with_flag).build() + + def test_handoff_builder_rejects_non_agent_supports_agent_run(): """Verify that participants() rejects SupportsAgentRun implementations that are not Agent instances.""" from agent_framework import AgentResponse, AgentSession, SupportsAgentRun @@ -1160,3 +1126,246 @@ def get_session(self, *, service_session_id, **kwargs): with pytest.raises(TypeError, match="Participants must be Agent instances"): HandoffBuilder().participants([fake]) + + +# endregion + +# region integration tests + + +try: + from agent_framework.foundry import FoundryChatClient + from azure.identity import AzureCliCredential + + _has_foundry_deps = True +except ImportError: + _has_foundry_deps = False + +skip_if_foundry_integration_tests_disabled = pytest.mark.skipif( + not _has_foundry_deps or os.getenv("FOUNDRY_PROJECT_ENDPOINT", "") == "" or os.getenv("FOUNDRY_MODEL", "") == "", + reason="No real FOUNDRY_PROJECT_ENDPOINT or FOUNDRY_MODEL provided; skipping integration tests.", +) + + +@pytest.mark.integration +@skip_if_foundry_integration_tests_disabled +@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")]) +async def test_simple_handoff_workflow(store: bool) -> None: + """Test a simple handoff workflow with two agents.""" + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + triage_agent = Agent( + client=client, + instructions=( + "You are frontline support triage. Route customer issues to the appropriate specialist agents " + "based on the problem described." 
+ ), + name="triage_agent", + default_options={"store": store}, + require_per_service_call_history_persistence=True, + ) + + refund_agent = Agent( + client=client, + instructions="You process refund requests. Ask user the ID of the order they want refunded.", + name="refund_agent", + default_options={"store": store}, + require_per_service_call_history_persistence=True, + ) + + workflow = ( + HandoffBuilder( + participants=[triage_agent, refund_agent], + termination_condition=lambda conversation: ( + # We terminate after triage hands off to refund to test handoff works + len(conversation) > 0 and conversation[-1].author_name == refund_agent.name + ), + ) + .with_start_agent(triage_agent) + .build() + ) + + workflow_result = await workflow.run("I want to get a refund") + # The workflow should end in IDLE state rather than IDLE_WITH_PENDING_REQUESTS + # because the termination condition is met right after the refund agent's response. + assert workflow_result.get_final_state() == WorkflowRunState.IDLE + # Output should contain responses from both agents and a final full conversation from between them. 
+ assert len(workflow_result.get_outputs()) == 3 + # There will be exactly one handoff request + handoff_event = [event for event in workflow_result if event.type == "handoff_sent"] + assert len(handoff_event) == 1 + assert isinstance(handoff_event[0].data, HandoffSentEvent) + assert handoff_event[0].data.source == triage_agent.name + assert handoff_event[0].data.target == refund_agent.name + + +@pytest.mark.integration +@skip_if_foundry_integration_tests_disabled +@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")]) +async def test_simple_handoff_workflow_with_request_and_response(store: bool) -> None: + """Test a simple handoff workflow with two agents where the second agent makes a request after handoff.""" + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + triage_agent = Agent( + client=client, + instructions=( + "You are frontline support triage. Route customer issues to the appropriate specialist agents " + "based on the problem described." + ), + name="triage_agent", + default_options={"store": store}, + require_per_service_call_history_persistence=True, + ) + + refund_agent = Agent( + client=client, + instructions="You process refund requests. Ask user the ID of the order they want refunded.", + name="refund_agent", + default_options={"store": store}, + require_per_service_call_history_persistence=True, + ) + + workflow = ( + HandoffBuilder( + participants=[triage_agent, refund_agent], + termination_condition=lambda conversation: ( + # We terminate after the refund agent request user input and the user provides + # a response. There will be two user messages in the conversation at that point + # - the original user message and the follow-up message in response to the refund + # agent's request. 
+ len([message for message in conversation if message.role == "user"]) == 2 + ), + ) + .with_start_agent(triage_agent) + .build() + ) + + workflow_result = await workflow.run("I want to get a refund") + # The workflow should end in IDLE_WITH_PENDING_REQUESTS state rather than IDLE + # because the user has not yet responded to the refund agent's request. + assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + # There will be exactly one handoff request + handoff_event = [event for event in workflow_result if event.type == "handoff_sent"] + assert len(handoff_event) == 1 + assert isinstance(handoff_event[0].data, HandoffSentEvent) + assert handoff_event[0].data.source == triage_agent.name + assert handoff_event[0].data.target == refund_agent.name + # There should be exactly one request for information from the refund agent after handoff + request_events = [event for event in workflow_result if event.type == "request_info"] + assert len(request_events) == 1 + assert isinstance(request_events[0].data, HandoffAgentUserRequest) + # Provide the user's response to the refund agent's request to allow the workflow to complete. + workflow_result = await workflow.run( + responses={ + request_events[0].request_id: HandoffAgentUserRequest.create_response("My order number is 12345"), + }, + ) + + # The workflow should now end in IDLE state since the termination condition + # is met after the user's response to the refund agent's request. + assert workflow_result.get_final_state() == WorkflowRunState.IDLE + + +@tool(approval_mode="always_require") +def process_refund(order_number: Annotated[str, "Order number to process refund for"]) -> str: + """Simulated function to process a refund for a given order number.""" + return f"Refund processed successfully for order {order_number}." 
+ + +@pytest.mark.integration +@skip_if_foundry_integration_tests_disabled +@pytest.mark.parametrize("store", [param(False, id="store=False"), param(True, id="store=True")]) +async def test_simple_handoff_workflow_with_approval_request(store: bool) -> None: + """Test a simple handoff workflow with two agents where the second agent requires tool approval after handoff.""" + client = FoundryChatClient( + project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"], + model=os.environ["FOUNDRY_MODEL"], + credential=AzureCliCredential(), + ) + + triage_agent = Agent( + client=client, + instructions=( + "You are frontline support triage. Route customer issues to the appropriate specialist agents " + "based on the problem described." + ), + name="triage_agent", + default_options={"store": store}, + require_per_service_call_history_persistence=True, + ) + + refund_agent = Agent( + client=client, + instructions="You process refund requests. Ask user the ID of the order they want refunded.", + name="refund_agent", + default_options={"store": store}, + tools=[process_refund], + require_per_service_call_history_persistence=True, + ) + + # This workflow will be terminated manually + workflow = ( + HandoffBuilder( + participants=[triage_agent, refund_agent], + ) + .with_start_agent(triage_agent) + .build() + ) + + workflow_result = await workflow.run("I want to get a refund") + # The workflow should end in IDLE_WITH_PENDING_REQUESTS state rather than IDLE + # because the user has not yet responded to the refund agent's request. 
+ assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + # There will be exactly one handoff request + handoff_event = [event for event in workflow_result if event.type == "handoff_sent"] + assert len(handoff_event) == 1 + assert isinstance(handoff_event[0].data, HandoffSentEvent) + assert handoff_event[0].data.source == triage_agent.name + assert handoff_event[0].data.target == refund_agent.name + # There should be exactly one request for information from the refund agent after handoff + request_events = [event for event in workflow_result if event.type == "request_info"] + assert len(request_events) == 1 + assert isinstance(request_events[0].data, HandoffAgentUserRequest) + # Provide the user's response to the refund agent's request to allow the workflow to complete. + workflow_result = await workflow.run( + responses={ + request_events[0].request_id: HandoffAgentUserRequest.create_response("My order number is 12345"), + }, + ) + + # The workflow should now end in IDLE_WITH_PENDING_REQUESTS state since the refund agent will ask for + # approval to process the refund after receiving the user's response. + assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + + # There should be exactly one request for tool approval from the refund agent. + request_events = [event for event in workflow_result if event.type == "request_info"] + assert len(request_events) == 1 + assert isinstance(request_events[0].data, Content) and request_events[0].data.type == "function_approval_request" + + # Provide the user's response to the refund agent's request to allow the workflow to complete. 
+ workflow_result = await workflow.run( + responses={request_events[0].request_id: request_events[0].data.to_function_approval_response(approved=True)} + ) + + # The refund agent will process the refund after receiving approval, but since there is no termination condition, + # the workflow will end in IDLE_WITH_PENDING_REQUESTS state waiting for further user input. + assert workflow_result.get_final_state() == WorkflowRunState.IDLE_WITH_PENDING_REQUESTS + # There should be exactly one request for information from the refund agent after processing the refund, + # which is the follow-up question asking if there is anything else they can help with. + request_events = [event for event in workflow_result if event.type == "request_info"] + assert len(request_events) == 1 + assert isinstance(request_events[0].data, HandoffAgentUserRequest) + workflow_result = await workflow.run(responses={request_events[0].request_id: HandoffAgentUserRequest.terminate()}) + + assert workflow_result.get_final_state() == WorkflowRunState.IDLE + + +# endregion diff --git a/python/samples/03-workflows/agents/handoff_workflow_as_agent.py b/python/samples/03-workflows/agents/handoff_workflow_as_agent.py index c7d06b535a..f392692af9 100644 --- a/python/samples/03-workflows/agents/handoff_workflow_as_agent.py +++ b/python/samples/03-workflows/agents/handoff_workflow_as_agent.py @@ -80,6 +80,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent "based on the problem described." 
), name="triage_agent", + require_per_service_call_history_persistence=True, ) # Refund specialist: Handles refund requests @@ -89,6 +90,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent name="refund_agent", # In a real application, an agent can have multiple tools; here we keep it simple tools=[process_refund], + require_per_service_call_history_persistence=True, ) # Order/shipping specialist: Resolves delivery issues @@ -98,6 +100,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent name="order_agent", # In a real application, an agent can have multiple tools; here we keep it simple tools=[check_order_status], + require_per_service_call_history_persistence=True, ) # Return specialist: Handles return requests @@ -107,6 +110,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent name="return_agent", # In a real application, an agent can have multiple tools; here we keep it simple tools=[process_return], + require_per_service_call_history_persistence=True, ) return triage_agent, refund_agent, order_agent, return_agent diff --git a/python/samples/03-workflows/orchestrations/README.md b/python/samples/03-workflows/orchestrations/README.md index 1f5f43c00f..0c4406c247 100644 --- a/python/samples/03-workflows/orchestrations/README.md +++ b/python/samples/03-workflows/orchestrations/README.md @@ -85,6 +85,8 @@ from agent_framework.orchestrations import ( **Handoff workflow tip**: Handoff workflows maintain the full conversation history including any `Message.additional_properties` emitted by your agents. This ensures routing metadata remains intact across all agent transitions. For specialist-to-specialist handoffs, use `.add_handoff(source, targets)` to configure which agents can route to which others with a fluent, type-safe API. 
+**Handoff `require_per_service_call_history_persistence`**: All agents in a handoff workflow **must** set `require_per_service_call_history_persistence=True`. `HandoffBuilder.build()` will raise a `ValueError` if any participant is missing this flag. This is required because handoff middleware short-circuits tool calls via `MiddlewareTermination`, and without per-service-call history persistence, local history would store tool results the service never received, causing mismatches on subsequent turns. + **Sequential orchestration note**: Sequential orchestration uses a few small adapter nodes for plumbing: - `input-conversation` normalizes input to `list[Message]` - `to-conversation:` converts agent responses into the shared conversation diff --git a/python/samples/03-workflows/orchestrations/handoff_autonomous.py b/python/samples/03-workflows/orchestrations/handoff_autonomous.py index 355a782f9d..7d192f5748 100644 --- a/python/samples/03-workflows/orchestrations/handoff_autonomous.py +++ b/python/samples/03-workflows/orchestrations/handoff_autonomous.py @@ -53,6 +53,7 @@ def create_agents( "Assign the two tasks to the appropriate specialists, one after the other." ), name="coordinator", + require_per_service_call_history_persistence=True, ) research_agent = Agent( @@ -66,6 +67,7 @@ def create_agents( "coordinator. Keep each individual response focused on one aspect." ), name="research_agent", + require_per_service_call_history_persistence=True, ) summary_agent = Agent( @@ -75,6 +77,7 @@ def create_agents( "control to the coordinator." 
), name="summary_agent", + require_per_service_call_history_persistence=True, ) return coordinator, research_agent, summary_agent diff --git a/python/samples/03-workflows/orchestrations/handoff_simple.py b/python/samples/03-workflows/orchestrations/handoff_simple.py index b804c5e63a..a95377f92c 100644 --- a/python/samples/03-workflows/orchestrations/handoff_simple.py +++ b/python/samples/03-workflows/orchestrations/handoff_simple.py @@ -77,6 +77,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent "based on the problem described." ), name="triage_agent", + require_per_service_call_history_persistence=True, ) # Refund specialist: Handles refund requests @@ -86,6 +87,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent name="refund_agent", # In a real application, an agent can have multiple tools; here we keep it simple tools=[process_refund], + require_per_service_call_history_persistence=True, ) # Order/shipping specialist: Resolves delivery issues @@ -95,6 +97,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent name="order_agent", # In a real application, an agent can have multiple tools; here we keep it simple tools=[check_order_status], + require_per_service_call_history_persistence=True, ) # Return specialist: Handles return requests @@ -104,6 +107,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent, Agent name="return_agent", # In a real application, an agent can have multiple tools; here we keep it simple tools=[process_return], + require_per_service_call_history_persistence=True, ) return triage_agent, refund_agent, order_agent, return_agent diff --git a/python/samples/03-workflows/orchestrations/handoff_with_code_interpreter_file.py b/python/samples/03-workflows/orchestrations/handoff_with_code_interpreter_file.py index ced94109a2..e3b461eec6 100644 --- 
a/python/samples/03-workflows/orchestrations/handoff_with_code_interpreter_file.py +++ b/python/samples/03-workflows/orchestrations/handoff_with_code_interpreter_file.py @@ -105,6 +105,7 @@ async def main() -> None: "When the user asks to create or generate files, hand off to code_specialist " "by calling handoff_to_code_specialist." ), + require_per_service_call_history_persistence=True, ) code_interpreter_tool = client.get_code_interpreter_tool() @@ -117,6 +118,7 @@ async def main() -> None: "and create files when requested. Always save files to /mnt/data/ directory." ), tools=[code_interpreter_tool], + require_per_service_call_history_persistence=True, ) workflow = ( diff --git a/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py b/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py index 4b65561532..b27fa6ce6f 100644 --- a/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py +++ b/python/samples/03-workflows/orchestrations/handoff_with_tool_approval_checkpoint_resume.py @@ -71,6 +71,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent]: "if they need refund help or order tracking. Use handoff_to_refund_agent or " "handoff_to_order_agent to transfer them." ), + require_per_service_call_history_persistence=True, ) refund = Agent( @@ -83,6 +84,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent]: "to record the request before continuing." ), tools=[submit_refund], + require_per_service_call_history_persistence=True, ) order = Agent( @@ -92,6 +94,7 @@ def create_agents(client: FoundryChatClient) -> tuple[Agent, Agent, Agent]: "You are an order tracking specialist. Help customers track their orders. " "Ask for order numbers and provide shipping updates." 
), + require_per_service_call_history_persistence=True, ) return triage, refund, order diff --git a/python/samples/05-end-to-end/ag_ui_workflow_handoff/README.md b/python/samples/05-end-to-end/ag_ui_workflow_handoff/README.md index 099c1b6c14..d0ab934ea3 100644 --- a/python/samples/05-end-to-end/ag_ui_workflow_handoff/README.md +++ b/python/samples/05-end-to-end/ag_ui_workflow_handoff/README.md @@ -16,6 +16,10 @@ It includes: The backend uses Azure OpenAI responses and supports intent-driven, non-linear handoff routing. +This demo keeps workflow state per `thread_id`. When the assistant ends a case with `Case complete.`, the UI blocks +later top-level input on that same thread and asks the user to start a new case explicitly instead of resuming a +terminated workflow. + ## Folder Layout - `backend/server.py` - FastAPI + AG-UI endpoint + Handoff workflow @@ -81,6 +85,28 @@ VITE_BACKEND_URL=http://127.0.0.1:8891 npm run dev 7. When replacement is requested, wait for the `submit_replacement` reviewer interrupt and approve/reject it. 8. If you asked for refund-only, the flow should close without replacement/shipping prompts. 9. Confirm the case snapshot updates and workflow completion. +10. After the case closes, another top-level message on the same thread is rejected with a notice. +11. Click **Start New Case** to begin a fresh thread. + +## Important: `require_per_service_call_history_persistence` + +All agents participating in a handoff workflow **must** be constructed with +`require_per_service_call_history_persistence=True`. The `HandoffBuilder` will +raise a `ValueError` at build time if any participant is missing this flag. + +**Why this is required:** Handoff workflows use middleware that short-circuits +tool calls via `MiddlewareTermination` when a handoff tool is invoked. Without +per-service-call history persistence, local history providers would persist tool +results that the service never received, causing call/result mismatches on +subsequent turns. 
+ +```python +agent = Agent( + client=client, + name="my_agent", + require_per_service_call_history_persistence=True, # Required for handoff +) +``` ## What This Validates diff --git a/python/samples/05-end-to-end/ag_ui_workflow_handoff/backend/server.py b/python/samples/05-end-to-end/ag_ui_workflow_handoff/backend/server.py index 02329e8e16..c33774b65b 100644 --- a/python/samples/05-end-to-end/ag_ui_workflow_handoff/backend/server.py +++ b/python/samples/05-end-to-end/ag_ui_workflow_handoff/backend/server.py @@ -17,12 +17,17 @@ import logging.handlers import os import random +from collections.abc import AsyncGenerator +from typing import Any import uvicorn from agent_framework import ( Agent, Message, Workflow, + WorkflowBuilder, + WorkflowContext, + executor, tool, ) from agent_framework.ag_ui import AgentFrameworkWorkflow, add_agent_framework_fastapi_endpoint @@ -101,6 +106,7 @@ def create_agents() -> tuple[Agent, Agent, Agent]: "4. If the issue is fully resolved, send a concise wrap-up that ends with exactly: Case complete." 
), client=client, + require_per_service_call_history_persistence=True, ) refund = Agent( @@ -126,6 +132,7 @@ def create_agents() -> tuple[Agent, Agent, Agent]: ), client=client, tools=[lookup_order_details, submit_refund], + require_per_service_call_history_persistence=True, ) order = Agent( @@ -149,19 +156,25 @@ def create_agents() -> tuple[Agent, Agent, Agent]: ), client=client, tools=[lookup_order_details, submit_replacement], + require_per_service_call_history_persistence=True, ) return triage, refund, order +def is_case_complete_text(text: str) -> bool: + """Return True when a message ends with the explicit demo completion marker.""" + + return text.strip().lower().endswith("case complete.") + + def _termination_condition(conversation: list[Message]) -> bool: """Stop when any assistant emits an explicit completion marker.""" for message in reversed(conversation): if message.role != "assistant": continue - text = (message.text or "").strip().lower() - if text.endswith("case complete."): + if is_case_complete_text(message.text or ""): return True return False @@ -215,6 +228,71 @@ def create_handoff_workflow() -> Workflow: return builder.with_start_agent(triage).build() +def create_closed_case_notice_workflow() -> Workflow: + """Build a tiny workflow that explains why a completed case cannot continue.""" + + @executor(id="closed_case_notice") + async def closed_case_notice(message: Message | None, ctx: WorkflowContext[None, str]) -> None: + del message + await ctx.yield_output( + "Your case is complete, but you're trying to do something new. Please start a new thread." 
+ ) + + return WorkflowBuilder(start_executor=closed_case_notice).build() + + +class DemoHandoffWorkflow(AgentFrameworkWorkflow): + """Workflow wrapper that blocks new top-level input on completed demo threads.""" + + def __init__(self) -> None: + super().__init__( + workflow_factory=lambda _thread_id: create_handoff_workflow(), + name="ag_ui_handoff_workflow_demo", + description="Dynamic handoff workflow demo with tool approvals and request_info resumes.", + ) + self._completed_threads: set[str] = set() + self._closed_case_notice_runner = AgentFrameworkWorkflow(workflow=create_closed_case_notice_workflow()) + + async def run(self, input_data: dict[str, Any]) -> AsyncGenerator[Any]: + """Intercept completed threads and return a helpful notice instead of resuming them.""" + + thread_id = self._thread_id_from_input(input_data) + has_messages = isinstance(input_data.get("messages"), list) and len(input_data.get("messages", [])) > 0 + has_resume = input_data.get("resume") is not None + + if thread_id in self._completed_threads and has_messages and not has_resume: + async for event in self._closed_case_notice_runner.run(input_data): + yield event + return + + message_text_by_id: dict[str, str] = {} + case_completed_this_run = False + + async for event in super().run(input_data): + event_type = getattr(event, "type", None) + if event_type == "TEXT_MESSAGE_START": + message_id = getattr(event, "message_id", None) + if isinstance(message_id, str): + message_text_by_id[message_id] = "" + elif event_type == "TEXT_MESSAGE_CONTENT": + message_id = getattr(event, "message_id", None) + delta = getattr(event, "delta", None) + if isinstance(message_id, str) and isinstance(delta, str): + message_text_by_id[message_id] = f"{message_text_by_id.get(message_id, '')}{delta}" + elif event_type == "TEXT_MESSAGE_END": + message_id = getattr(event, "message_id", None) + if isinstance(message_id, str): + final_text = message_text_by_id.pop(message_id, "") + if 
is_case_complete_text(final_text): + case_completed_this_run = True + + yield event + + if case_completed_this_run: + self._completed_threads.add(thread_id) + self.clear_thread_workflow(thread_id) + + def create_app() -> FastAPI: """Create and configure the FastAPI application.""" @@ -231,11 +309,7 @@ def create_app() -> FastAPI: allow_headers=["*"], ) - demo_workflow = AgentFrameworkWorkflow( - workflow_factory=lambda _thread_id: create_handoff_workflow(), - name="ag_ui_handoff_workflow_demo", - description="Dynamic handoff workflow demo with tool approvals and request_info resumes.", - ) + demo_workflow = DemoHandoffWorkflow() add_agent_framework_fastapi_endpoint( app=app, diff --git a/python/samples/05-end-to-end/ag_ui_workflow_handoff/frontend/src/App.tsx b/python/samples/05-end-to-end/ag_ui_workflow_handoff/frontend/src/App.tsx index 4f45d51064..d04e450f5a 100644 --- a/python/samples/05-end-to-end/ag_ui_workflow_handoff/frontend/src/App.tsx +++ b/python/samples/05-end-to-end/ag_ui_workflow_handoff/frontend/src/App.tsx @@ -54,6 +54,16 @@ const STARTER_PROMPTS = [ "Help me with a damaged-order refund and replacement.", ]; +const DEFAULT_CASE_SNAPSHOT: CaseSnapshot = { + orderId: "Not captured", + refundAmount: "Not captured", + refundApproved: "pending", + shippingPreference: "Not selected", +}; + +const CLOSED_CASE_NOTICE = + "This case is already complete. 
Start a new case to open a fresh thread for a new request."; + function randomId(): string { if (typeof crypto !== "undefined" && typeof crypto.randomUUID === "function") { return crypto.randomUUID(); @@ -213,6 +223,10 @@ function normalizeTextForDedupe(text: string): string { return text.replace(/\s+/g, " ").trim(); } +function isCaseCompleteText(text: string): boolean { + return text.trim().toLowerCase().endsWith("case complete."); +} + function normalizeShippingPreference(text: string): string | null { const normalized = text.trim().toLowerCase(); if (normalized.length === 0) { @@ -263,24 +277,21 @@ export default function App(): JSX.Element { const assistantMessageIndexRef = useRef>({}); const activeRunIdRef = useRef(null); const pendingUsageRef = useRef(null); + const caseClosedRef = useRef(false); const [messages, setMessages] = useState([]); const [requestInfoById, setRequestInfoById] = useState>({}); const [pendingInterrupts, setPendingInterrupts] = useState([]); const [activeAgent, setActiveAgent] = useState("triage_agent"); const [visitedAgents, setVisitedAgents] = useState>(new Set(["triage_agent"])); - const [caseSnapshot, setCaseSnapshot] = useState({ - orderId: "Not captured", - refundAmount: "Not captured", - refundApproved: "pending", - shippingPreference: "Not selected", - }); + const [caseSnapshot, setCaseSnapshot] = useState(DEFAULT_CASE_SNAPSHOT); const [statusText, setStatusText] = useState("Ready"); const [isRunning, setIsRunning] = useState(false); const [inputText, setInputText] = useState(""); const [isApprovalModalOpen, setIsApprovalModalOpen] = useState(false); const [latestUsage, setLatestUsage] = useState(null); const [usageHistory, setUsageHistory] = useState([]); + const [isCaseClosed, setIsCaseClosed] = useState(false); const currentInterrupt = pendingInterrupts[0]; const currentInterruptKind = currentInterrupt ? 
interruptKind(currentInterrupt) : "unknown"; @@ -288,6 +299,7 @@ export default function App(): JSX.Element { const interruptPrompt = currentInterrupt ? extractPromptFromInterrupt(currentInterrupt, currentRequestInfo) : "No pending interrupt."; + const canStartFreshCase = !currentInterrupt && isCaseClosed; const functionCall = currentInterrupt ? extractFunctionCallFromInterrupt(currentInterrupt) : null; const functionArguments = useMemo(() => parseFunctionArguments(functionCall), [functionCall]); @@ -304,6 +316,34 @@ export default function App(): JSX.Element { setMessages((prev) => [...prev, message]); }; + const pushSystemNotice = (text: string): void => { + setMessages((prev) => { + if (prev.length > 0 && prev[prev.length - 1]?.role === "system" && prev[prev.length - 1]?.text === text) { + return prev; + } + return [...prev, { id: randomId(), role: "system", text }]; + }); + }; + + const resetConversationState = (): void => { + threadIdRef.current = randomId(); + assistantMessageIndexRef.current = {}; + activeRunIdRef.current = null; + pendingUsageRef.current = null; + caseClosedRef.current = false; + + setMessages([]); + setRequestInfoById({}); + setPendingInterrupts([]); + setActiveAgent("triage_agent"); + setVisitedAgents(new Set(["triage_agent"])); + setCaseSnapshot(DEFAULT_CASE_SNAPSHOT); + setStatusText("Ready"); + setInputText(""); + setIsApprovalModalOpen(false); + setIsCaseClosed(false); + }; + const rebuildAssistantMessageIndex = (items: DisplayMessage[]): void => { const next: Record = {}; items.forEach((item, index) => { @@ -364,6 +404,10 @@ export default function App(): JSX.Element { } const candidate = prev[index]; if (candidate.role === "user" || candidate.text.trim().length > 0) { + if (candidate.role === "assistant" && isCaseCompleteText(candidate.text)) { + caseClosedRef.current = true; + setIsCaseClosed(true); + } return prev; } const next = prev.filter((item) => item.id !== messageId); @@ -565,7 +609,9 @@ export default function App(): 
JSX.Element { } setPendingInterrupts(interruptPayload); - setStatusText(interruptPayload.length > 0 ? "Waiting for input" : "Run complete"); + setStatusText( + interruptPayload.length > 0 ? "Waiting for input" : caseClosedRef.current ? "Case complete" : "Run complete" + ); setIsRunning(false); break; } @@ -652,6 +698,12 @@ export default function App(): JSX.Element { }; const startNewTurn = async (text: string): Promise => { + if (caseClosedRef.current && pendingInterrupts.length === 0) { + pushSystemNotice(CLOSED_CASE_NOTICE); + setStatusText("Case complete"); + return; + } + pushMessage({ id: randomId(), role: "user", text }); await runWithPayload({ @@ -873,7 +925,20 @@ export default function App(): JSX.Element {

Pending Action

- {!currentInterrupt &&

No interrupt pending. Start with one of the prompts below.

} + {!currentInterrupt && ( +
+

+ {isCaseClosed + ? "This case is closed. New top-level messages on this thread are blocked until you start a new case." + : "No interrupt pending. Start with one of the prompts below."} +

+ {canStartFreshCase && ( + + )} +
+ )} {currentInterrupt && (
@@ -907,7 +972,7 @@ export default function App(): JSX.Element {
)} - {!currentInterrupt && ( + {!currentInterrupt && !isCaseClosed && (
{STARTER_PROMPTS.map((prompt) => (