diff --git a/mini_agent/agent.py b/mini_agent/agent.py index b7d7feab..6091a4c0 100644 --- a/mini_agent/agent.py +++ b/mini_agent/agent.py @@ -4,13 +4,13 @@ import json from pathlib import Path from time import perf_counter -from typing import Optional +from typing import AsyncGenerator, Optional import tiktoken from .llm import LLMClient from .logger import AgentLogger -from .schema import Message +from .schema import AgentEvent, Message, ToolCallEvent, ToolResultEvent from .tools.base import Tool, ToolResult from .utils import calculate_display_width @@ -518,6 +518,200 @@ async def run(self, cancel_event: Optional[asyncio.Event] = None) -> str: print(f"\n{Colors.BRIGHT_YELLOW}⚠️ {error_msg}{Colors.RESET}") return error_msg + async def run_streaming( + self, cancel_event: Optional[asyncio.Event] = None + ) -> AsyncGenerator[AgentEvent, None]: + """Stream agent execution as a series of structured events. + + This is the streaming counterpart to run(). All core logic (LLM calls, + tool execution, token tracking, summarization, cancellation) is + identical to run(); only the output format differs. + + Each yielded AgentEvent has a 'type' field that acts as a discriminator: + - "thinking" -> Agent is generating extended thinking (content/thinking set) + - "content" -> Agent is generating text output (content set) + - "tool_call" -> Agent requested a tool (tool_call set) + - "tool_result"-> Tool execution finished (tool_result set) + - "step_complete" -> Step finished (step set) + - "final" -> Agent finished (final_text set, total_tokens set) + + Example consumer (SSE): + async for event in agent.run_streaming(): + if event.type == "content": + yield f"data: {event.content}\\n\\n" + elif event.type == "final": + yield f"data: [DONE]\\n\\n" + + Args: + cancel_event: Optional asyncio.Event to signal cancellation. + When set, the agent stops at the next safe checkpoint. + + Yields: + AgentEvent objects describing each step of execution. 
+ """ + if cancel_event is not None: + self.cancel_event = cancel_event + + self.logger.start_new_run() + + step = 0 + final_text = "" + + while step < self.max_steps: + # Check cancellation + if self._check_cancelled(): + self._cleanup_incomplete_messages() + final_text = "Task cancelled by user." + yield AgentEvent(type="final", final_text=final_text) + return + + # Summarize if needed + await self._summarize_messages() + + tool_list = list(self.tools.values()) + + self.logger.log_request(messages=self.messages, tools=tool_list) + + try: + response = await self.llm.generate( + messages=self.messages, tools=tool_list + ) + except Exception as e: + from .retry import RetryExhaustedError + + if isinstance(e, RetryExhaustedError): + final_text = ( + f"LLM call failed after {e.attempts} retries. " + f"Last error: {str(e.last_exception)}" + ) + else: + final_text = f"LLM call failed: {str(e)}" + yield AgentEvent(type="final", final_text=final_text) + return + + if response.usage: + self.api_total_tokens = response.usage.total_tokens + + self.logger.log_response( + content=response.content, + thinking=response.thinking, + tool_calls=response.tool_calls, + finish_reason=response.finish_reason, + ) + + # Stream thinking + if response.thinking: + yield AgentEvent(type="thinking", thinking=response.thinking) + + # Stream content + if response.content: + yield AgentEvent(type="content", content=response.content) + + # Add assistant message to history + assistant_msg = Message( + role="assistant", + content=response.content, + thinking=response.thinking, + tool_calls=response.tool_calls, + ) + self.messages.append(assistant_msg) + + # No tool calls -> done + if not response.tool_calls: + yield AgentEvent( + type="final", + final_text=response.content, + total_tokens=self.api_total_tokens, + ) + return + + # Check cancellation before tools + if self._check_cancelled(): + self._cleanup_incomplete_messages() + final_text = "Task cancelled by user." 
+ yield AgentEvent(type="final", final_text=final_text) + return + + # Execute tools + for tool_call in response.tool_calls: + tc_id = tool_call.id + fn_name = tool_call.function.name + fn_args = tool_call.function.arguments + + yield AgentEvent( + type="tool_call", + tool_call=ToolCallEvent( + id=tc_id, + name=fn_name, + arguments=fn_args, + ), + ) + + # Tool lookup + if fn_name not in self.tools: + result = ToolResult( + success=False, + content="", + error=f"Unknown tool: {fn_name}", + ) + else: + try: + tool = self.tools[fn_name] + result = await tool.execute(**fn_args) + except Exception as e: + import traceback + + result = ToolResult( + success=False, + content="", + error=( + f"Tool execution failed: {type(e).__name__}: {str(e)}\n" + f"Traceback:\n{traceback.format_exc()}" + ), + ) + + self.logger.log_tool_result( + tool_name=fn_name, + arguments=fn_args, + result_success=result.success, + result_content=result.content if result.success else None, + result_error=result.error if not result.success else None, + ) + + yield AgentEvent( + type="tool_result", + tool_result=ToolResultEvent( + id=tc_id, + name=fn_name, + success=result.success, + content=result.content, + error=result.error, + ), + ) + + # Add tool result message + tool_msg = Message( + role="tool", + content=result.content if result.success else f"Error: {result.error}", + tool_call_id=tc_id, + name=fn_name, + ) + self.messages.append(tool_msg) + + # Check cancellation after each tool + if self._check_cancelled(): + self._cleanup_incomplete_messages() + final_text = "Task cancelled by user." + yield AgentEvent(type="final", final_text=final_text) + return + + yield AgentEvent(type="step_complete", step=step + 1) + step += 1 + + # Max steps + final_text = f"Task couldn't be completed after {self.max_steps} steps." 
class ToolCallEvent(BaseModel):
    """Payload describing a requested tool invocation, emitted before it runs."""

    id: str  # id of the originating tool_call
    name: str  # function name being invoked
    arguments: dict[str, Any]  # keyword arguments passed to the tool


class ToolResultEvent(BaseModel):
    """Payload describing the outcome of a tool invocation, emitted after it runs."""

    id: str  # id of the originating tool_call
    name: str  # function name that was invoked
    success: bool  # True when the tool ran without error
    content: str = ""  # tool output (meaningful when success is True)
    error: str | None = None  # error description (set when success is False)


class AgentEvent(BaseModel):
    """One streaming event emitted by Agent.run_streaming().

    The 'type' field is the discriminator; only the fields relevant to that
    type are populated, everything else stays None.
    """

    # Discriminator:
    # "thinking" | "content" | "tool_call" | "tool_result" | "step_complete" | "final"
    type: str

    # Text / extended-thinking chunks ("content" / "thinking" events)
    content: str | None = None
    thinking: str | None = None

    # Tool events ("tool_call" / "tool_result")
    tool_call: ToolCallEvent | None = None
    tool_result: ToolResultEvent | None = None

    # Step metadata ("step_complete"); 1-based step number
    step: int | None = None

    # Final result text ("final")
    final_text: str | None = None

    # Token usage from the API, when known ("final")
    total_tokens: int | None = None
"""Tests for Agent.run_streaming()."""

import asyncio

import pytest

from mini_agent.agent import Agent
from mini_agent.schema import Message, TokenUsage


class MockLLMClient:
    """Minimal mock LLM client for testing.

    Returns canned content; tool calls (if any) are attached only to the
    first generate() call, so a run terminates on the second step.
    """

    def __init__(self, response_content: str = "Test response", tool_calls: list | None = None):
        self.response_content = response_content
        self.tool_calls = tool_calls or []
        self.call_count = 0

    async def generate(self, messages: list[Message], tools: list | None = None):
        self.call_count += 1

        class FakeResponse:
            # NOTE: reading `self` here is fine — it is a pure free variable
            # of the enclosing function, which class bodies can see.
            content = self.response_content
            thinking = None
            tool_calls = self.tool_calls if self.call_count == 1 else None
            finish_reason = "tool_use" if self.tool_calls and self.call_count == 1 else "stop"
            usage = TokenUsage(prompt_tokens=100, completion_tokens=50, total_tokens=150)

        return FakeResponse()


@pytest.mark.asyncio
async def test_run_streaming_yields_content_event():
    """run_streaming() yields a content event when the LLM returns text."""
    agent = Agent(
        llm_client=MockLLMClient(response_content="Hello world"),
        system_prompt="You are a helpful assistant.",
        tools=[],
        max_steps=5,
    )
    agent.add_user_message("Say hello")

    events = [e async for e in agent.run_streaming()]

    assert len(events) >= 1
    assert events[0].type == "content"
    assert events[0].content == "Hello world"
    assert events[-1].type == "final"
    assert events[-1].final_text == "Hello world"


@pytest.mark.asyncio
async def test_run_streaming_yields_final_event():
    """run_streaming() always ends with a final event."""
    agent = Agent(
        llm_client=MockLLMClient(response_content="Done"),
        system_prompt="You are a helpful assistant.",
        tools=[],
        max_steps=5,
    )
    agent.add_user_message("End the conversation")

    events = [e async for e in agent.run_streaming()]

    assert any(e.type == "final" for e in events)
    final_event = next(e for e in events if e.type == "final")
    assert final_event.final_text == "Done"
    assert final_event.total_tokens == 150


@pytest.mark.asyncio
async def test_run_streaming_event_types():
    """run_streaming() yields the expected event types in order."""
    from mini_agent.schema import FunctionCall, ToolCall

    tool_calls = [
        ToolCall(
            id="call_1",
            type="function",
            function=FunctionCall(name="get_weather", arguments={"city": "Beijing"}),
        )
    ]

    agent = Agent(
        llm_client=MockLLMClient(response_content="Let me check the weather.", tool_calls=tool_calls),
        system_prompt="You are a helpful assistant.",
        tools=[],
        max_steps=5,
    )
    agent.add_user_message("What's the weather in Beijing?")

    events = [e async for e in agent.run_streaming()]

    event_types = [e.type for e in events]

    # First: content from the first LLM response
    assert "content" in event_types
    # Then: tool_call
    assert "tool_call" in event_types
    # Then: tool_result (unknown-tool error path — no tools are registered)
    assert "tool_result" in event_types
    # Then: step_complete
    assert "step_complete" in event_types
    # Finally: final
    assert "final" in event_types

    # Verify ordering
    content_idx = event_types.index("content")
    tool_call_idx = event_types.index("tool_call")
    tool_result_idx = event_types.index("tool_result")
    step_complete_idx = event_types.index("step_complete")

    assert content_idx < tool_call_idx < tool_result_idx < step_complete_idx


@pytest.mark.asyncio
async def test_run_streaming_tool_call_event():
    """run_streaming() includes correct tool_call info in the event."""
    from mini_agent.schema import FunctionCall, ToolCall

    tool_calls = [
        ToolCall(
            id="call_abc",
            type="function",
            function=FunctionCall(name="get_stock_price", arguments={"symbol": "AAPL"}),
        )
    ]

    agent = Agent(
        llm_client=MockLLMClient(response_content="Getting stock price.", tool_calls=tool_calls),
        system_prompt="You are a helpful assistant.",
        tools=[],
        max_steps=5,
    )
    agent.add_user_message("Get AAPL price")

    events = [e async for e in agent.run_streaming()]

    tool_call_event = next(e for e in events if e.type == "tool_call")
    assert tool_call_event.tool_call.id == "call_abc"
    assert tool_call_event.tool_call.name == "get_stock_price"
    assert tool_call_event.tool_call.arguments == {"symbol": "AAPL"}


@pytest.mark.asyncio
async def test_run_streaming_unknown_tool_yields_error_result():
    """run_streaming() handles unknown tools gracefully."""
    from mini_agent.schema import FunctionCall, ToolCall

    tool_calls = [
        ToolCall(
            id="call_xyz",
            type="function",
            function=FunctionCall(name="nonexistent_tool", arguments={}),
        )
    ]

    agent = Agent(
        llm_client=MockLLMClient(response_content="Calling unknown tool.", tool_calls=tool_calls),
        system_prompt="You are a helpful assistant.",
        tools=[],  # No tools registered
        max_steps=5,
    )
    agent.add_user_message("Use a tool that doesn't exist")

    events = [e async for e in agent.run_streaming()]

    tool_result_event = next(e for e in events if e.type == "tool_result")
    assert tool_result_event.tool_result.success is False
    assert "Unknown tool" in tool_result_event.tool_result.error


@pytest.mark.asyncio
async def test_run_streaming_cancel_via_event():
    """run_streaming() respects cancel_event to interrupt execution."""
    agent = Agent(
        llm_client=MockLLMClient(response_content="Should not reach here"),
        system_prompt="You are a helpful assistant.",
        tools=[],
        max_steps=5,
    )
    agent.add_user_message("Long task")

    cancel_event = asyncio.Event()
    cancel_event.set()  # Cancel immediately

    events = [e async for e in agent.run_streaming(cancel_event=cancel_event)]

    assert any(e.type == "final" for e in events)
    final_event = next(e for e in events if e.type == "final")
    assert "cancelled" in final_event.final_text.lower()


@pytest.mark.asyncio
async def test_run_streaming_max_steps_yields_final():
    """run_streaming() yields a final max-steps message when the limit is hit."""
    from mini_agent.schema import FunctionCall, ToolCall

    # Infinite loop: the LLM keeps requesting the same tool on every step.
    tool_calls = [
        ToolCall(
            id="call_loop",
            type="function",
            function=FunctionCall(name="dummy_tool", arguments={}),
        )
    ]

    class LoopingLLM(MockLLMClient):
        async def generate(self, messages, tools=None):
            self.call_count += 1
            # BUGFIX: bind the value to a function local first. A name that is
            # *assigned* in a class body (e.g. `tool_calls = tool_calls if ...`)
            # is looked up with LOAD_NAME and does NOT see the enclosing
            # function's `tool_calls`, which raised NameError here.
            pending_calls = tool_calls if self.call_count <= 3 else None

            class FakeResponse:
                content = f"Step {self.call_count}"
                thinking = None
                tool_calls = pending_calls
                finish_reason = "tool_use" if pending_calls else "stop"
                usage = TokenUsage(prompt_tokens=10, completion_tokens=5, total_tokens=15)

            return FakeResponse()

    agent = Agent(
        llm_client=LoopingLLM(),
        system_prompt="You are a helpful assistant.",
        tools=[],  # No actual tools - each call gets "unknown tool" but the loop continues
        max_steps=3,
    )
    agent.add_user_message("Keep calling tools forever")

    events = [e async for e in agent.run_streaming()]

    final_event = next(e for e in events if e.type == "final")
    assert "couldn't be completed" in final_event.final_text.lower()