198 changes: 196 additions & 2 deletions mini_agent/agent.py
@@ -4,13 +4,13 @@
import json
from pathlib import Path
from time import perf_counter
from typing import Optional
from typing import AsyncGenerator, Optional

import tiktoken

from .llm import LLMClient
from .logger import AgentLogger
from .schema import Message
from .schema import AgentEvent, Message, ToolCallEvent, ToolResultEvent
from .tools.base import Tool, ToolResult
from .utils import calculate_display_width

@@ -518,6 +518,200 @@ async def run(self, cancel_event: Optional[asyncio.Event] = None) -> str:
print(f"\n{Colors.BRIGHT_YELLOW}⚠️ {error_msg}{Colors.RESET}")
return error_msg

async def run_streaming(
self, cancel_event: Optional[asyncio.Event] = None
) -> AsyncGenerator[AgentEvent, None]:
"""Stream agent execution as a series of structured events.

This is the streaming counterpart to run(). All core logic (LLM calls,
tool execution, token tracking, summarization, cancellation) is
identical to run(); only the output format differs.

Each yielded AgentEvent has a 'type' field that acts as a discriminator:
- "thinking" -> Agent produced extended thinking (thinking set)
- "content" -> Agent produced text output (content set)
- "tool_call" -> Agent requested a tool (tool_call set)
- "tool_result" -> Tool execution finished (tool_result set)
- "step_complete" -> Step finished (step set)
- "final" -> Agent finished (final_text set; total_tokens set on normal completion)

Example consumer (SSE):
async for event in agent.run_streaming():
if event.type == "content":
yield f"data: {event.content}\\n\\n"
elif event.type == "final":
yield f"data: [DONE]\\n\\n"

Args:
cancel_event: Optional asyncio.Event to signal cancellation.
When set, the agent stops at the next safe checkpoint.

Yields:
AgentEvent objects describing each step of execution.
"""
if cancel_event is not None:
self.cancel_event = cancel_event

self.logger.start_new_run()

step = 0
final_text = ""

while step < self.max_steps:
# Check cancellation
if self._check_cancelled():
self._cleanup_incomplete_messages()
final_text = "Task cancelled by user."
yield AgentEvent(type="final", final_text=final_text)
return

# Summarize if needed
await self._summarize_messages()

tool_list = list(self.tools.values())

self.logger.log_request(messages=self.messages, tools=tool_list)

try:
response = await self.llm.generate(
messages=self.messages, tools=tool_list
)
except Exception as e:
from .retry import RetryExhaustedError

if isinstance(e, RetryExhaustedError):
final_text = (
f"LLM call failed after {e.attempts} retries. "
f"Last error: {str(e.last_exception)}"
)
else:
final_text = f"LLM call failed: {str(e)}"
yield AgentEvent(type="final", final_text=final_text)
return

if response.usage:
self.api_total_tokens = response.usage.total_tokens

self.logger.log_response(
content=response.content,
thinking=response.thinking,
tool_calls=response.tool_calls,
finish_reason=response.finish_reason,
)

# Stream thinking
if response.thinking:
yield AgentEvent(type="thinking", thinking=response.thinking)

# Stream content
if response.content:
yield AgentEvent(type="content", content=response.content)

# Add assistant message to history
assistant_msg = Message(
role="assistant",
content=response.content,
thinking=response.thinking,
tool_calls=response.tool_calls,
)
self.messages.append(assistant_msg)

# No tool calls -> done
if not response.tool_calls:
yield AgentEvent(
type="final",
final_text=response.content,
total_tokens=self.api_total_tokens,
)
return

# Check cancellation before tools
if self._check_cancelled():
self._cleanup_incomplete_messages()
final_text = "Task cancelled by user."
yield AgentEvent(type="final", final_text=final_text)
return

# Execute tools
for tool_call in response.tool_calls:
tc_id = tool_call.id
fn_name = tool_call.function.name
fn_args = tool_call.function.arguments

yield AgentEvent(
type="tool_call",
tool_call=ToolCallEvent(
id=tc_id,
name=fn_name,
arguments=fn_args,
),
)

# Tool lookup
if fn_name not in self.tools:
result = ToolResult(
success=False,
content="",
error=f"Unknown tool: {fn_name}",
)
else:
try:
tool = self.tools[fn_name]
result = await tool.execute(**fn_args)
except Exception as e:
import traceback

result = ToolResult(
success=False,
content="",
error=(
f"Tool execution failed: {type(e).__name__}: {str(e)}\n"
f"Traceback:\n{traceback.format_exc()}"
),
)

self.logger.log_tool_result(
tool_name=fn_name,
arguments=fn_args,
result_success=result.success,
result_content=result.content if result.success else None,
result_error=result.error if not result.success else None,
)

yield AgentEvent(
type="tool_result",
tool_result=ToolResultEvent(
id=tc_id,
name=fn_name,
success=result.success,
content=result.content,
error=result.error,
),
)

# Add tool result message
tool_msg = Message(
role="tool",
content=result.content if result.success else f"Error: {result.error}",
tool_call_id=tc_id,
name=fn_name,
)
self.messages.append(tool_msg)

# Check cancellation after each tool
if self._check_cancelled():
self._cleanup_incomplete_messages()
final_text = "Task cancelled by user."
yield AgentEvent(type="final", final_text=final_text)
return

yield AgentEvent(type="step_complete", step=step + 1)
step += 1

# Max steps
final_text = f"Task couldn't be completed after {self.max_steps} steps."
yield AgentEvent(type="final", final_text=final_text)

def get_history(self) -> list[Message]:
"""Get message history."""
return self.messages.copy()
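
Review note: below is a minimal sketch of a consumer exercising the full event protocol added in this diff. The agent's construction is project-specific and not shown here, so the `agent` argument is left untyped; everything else uses only fields defined by AgentEvent, ToolCallEvent, and ToolResultEvent.

import asyncio

async def consume(agent) -> None:
    # `agent` is an instance of the agent class this diff extends;
    # building one (LLM client, tools, logger) is out of scope here.
    cancel = asyncio.Event()
    async for event in agent.run_streaming(cancel_event=cancel):
        if event.type == "thinking":
            print(f"[thinking] {event.thinking}")
        elif event.type == "content":
            print(event.content, end="", flush=True)
        elif event.type == "tool_call":
            tc = event.tool_call
            print(f"\n[tool] {tc.name} {tc.arguments}")
        elif event.type == "tool_result":
            tr = event.tool_result
            print(f"[tool] {tr.name}: {'ok' if tr.success else tr.error}")
        elif event.type == "step_complete":
            print(f"\n-- step {event.step} complete --")
        elif event.type == "final":
            print(f"\n{event.final_text} (total tokens: {event.total_tokens})")

The same dispatch loop maps directly onto an SSE or WebSocket handler, as the docstring example in run_streaming() suggests.
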
6 changes: 6 additions & 0 deletions mini_agent/schema/__init__.py
@@ -1,19 +1,25 @@
"""Schema definitions for Mini-Agent."""

from .schema import (
AgentEvent,
FunctionCall,
LLMProvider,
LLMResponse,
Message,
TokenUsage,
ToolCall,
ToolCallEvent,
ToolResultEvent,
)

__all__ = [
"AgentEvent",
"FunctionCall",
"LLMProvider",
"LLMResponse",
"Message",
"TokenUsage",
"ToolCall",
"ToolCallEvent",
"ToolResultEvent",
]
42 changes: 42 additions & 0 deletions mini_agent/schema/schema.py
@@ -53,3 +53,45 @@ class LLMResponse(BaseModel):
tool_calls: list[ToolCall] | None = None
finish_reason: str
usage: TokenUsage | None = None # Token usage from API response


class ToolCallEvent(BaseModel):
"""A tool call event emitted before tool execution."""

id: str # tool_call id
name: str # function name
arguments: dict[str, Any] # function arguments


class ToolResultEvent(BaseModel):
"""A tool result event emitted after tool execution."""

id: str # tool_call id
name: str # function name
success: bool
content: str = ""
error: str | None = None


class AgentEvent(BaseModel):
"""Streaming event emitted by Agent.run_streaming()."""

# Event type discriminator
type: str # "thinking" | "content" | "tool_call" | "tool_result" | "step_complete" | "final"

# Text and thinking chunks
content: str | None = None
thinking: str | None = None

# Tool events
tool_call: ToolCallEvent | None = None
tool_result: ToolResultEvent | None = None

# Step metadata
step: int | None = None

# Final result
final_text: str | None = None

# Token usage
total_tokens: int | None = None
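
Review note: a small sketch of how these event models could be serialized for the SSE consumer described in the run_streaming() docstring. It assumes pydantic v2 (model_dump_json with exclude_none); on pydantic v1 the equivalent would be .json(exclude_none=True). The id, tool name, and arguments below are hypothetical.

from mini_agent.schema import AgentEvent, ToolCallEvent

event = AgentEvent(
    type="tool_call",
    tool_call=ToolCallEvent(
        id="call_123",                    # hypothetical tool_call id
        name="read_file",                 # hypothetical tool name
        arguments={"path": "README.md"},  # hypothetical arguments
    ),
)

# exclude_none keeps the frame compact: only the fields relevant to
# this event type appear on the wire.
frame = f"data: {event.model_dump_json(exclude_none=True)}\n\n"
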