Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 82 additions & 37 deletions src/strands/multiagent/swarm.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from .._async import run_async
from ..agent import Agent
from ..agent.base import AgentBase
from ..agent.state import AgentState
from ..hooks.events import (
AfterMultiAgentInvocationEvent,
Expand Down Expand Up @@ -65,7 +66,7 @@ class SwarmNode:
"""Represents a node (e.g. Agent) in the swarm."""

node_id: str
executor: Agent
executor: AgentBase
swarm: Optional["Swarm"] = None
_initial_messages: Messages = field(default_factory=list, init=False)
_initial_state: AgentState = field(default_factory=AgentState, init=False)
Expand All @@ -74,9 +75,14 @@ class SwarmNode:
def __post_init__(self) -> None:
"""Capture initial executor state after initialization."""
# Deep copy the initial messages and state to preserve them
self._initial_messages = copy.deepcopy(self.executor.messages)
self._initial_state = AgentState(self.executor.state.get())
self._initial_model_state = copy.deepcopy(self.executor._model_state)
if hasattr(self.executor, "messages"):
self._initial_messages = copy.deepcopy(self.executor.messages)

if hasattr(self.executor, "state") and hasattr(self.executor.state, "get"):
self._initial_state = AgentState(self.executor.state.get())

if hasattr(self.executor, "_model_state"):
self._initial_model_state = copy.deepcopy(self.executor._model_state)

def __hash__(self) -> int:
"""Return hash for SwarmNode based on node_id."""
Expand All @@ -101,17 +107,26 @@ def reset_executor_state(self) -> None:

If Swarm is resuming from an interrupt, we reset the executor state from the interrupt context.
"""
if self.swarm and self.swarm._interrupt_state.activated:
# Handle interrupt state restoration (Agent-specific)
if self.swarm and self.swarm._interrupt_state.activated and isinstance(self.executor, Agent):
if self.node_id not in self.swarm._interrupt_state.context:
return
context = self.swarm._interrupt_state.context[self.node_id]
self.executor.messages = context["messages"]
self.executor.state = AgentState(context["state"])
self.executor._interrupt_state = _InterruptState.from_dict(context["interrupt_state"])
self.executor._model_state = context.get("model_state", {})
return

self.executor.messages = copy.deepcopy(self._initial_messages)
self.executor.state = AgentState(self._initial_state.get())
self.executor._model_state = copy.deepcopy(self._initial_model_state)
# Reset to initial state (works with any AgentBase that has these attributes)
if hasattr(self.executor, "messages"):
self.executor.messages = copy.deepcopy(self._initial_messages)

if hasattr(self.executor, "state"):
self.executor.state = AgentState(self._initial_state.get())

if hasattr(self.executor, "_model_state"):
self.executor._model_state = copy.deepcopy(self._initial_model_state)


@dataclass
Expand Down Expand Up @@ -236,9 +251,9 @@ class Swarm(MultiAgentBase):

def __init__(
self,
nodes: list[Agent],
nodes: list[AgentBase],
*,
entry_point: Agent | None = None,
entry_point: AgentBase | None = None,
max_handoffs: int = 20,
max_iterations: int = 20,
execution_timeout: float = 900.0,
Expand Down Expand Up @@ -301,6 +316,7 @@ def __init__(

self._resume_from_session = False

self._handoff_capable_nodes: set[str] = set()
self._setup_swarm(nodes)
self._inject_swarm_tools()
run_async(lambda: self.hooks.invoke_callbacks_async(MultiAgentInitializedEvent(self)))
Expand Down Expand Up @@ -462,33 +478,35 @@ async def _stream_with_timeout(
except asyncio.TimeoutError as err:
raise Exception(timeout_message) from err

def _setup_swarm(self, nodes: list[Agent]) -> None:
def _setup_swarm(self, nodes: list[AgentBase]) -> None:
"""Initialize swarm configuration."""
# Validate nodes before setup
self._validate_swarm(nodes)

# Validate agents have names and create SwarmNode objects
for i, node in enumerate(nodes):
if not node.name:
# Only access name if it exists (AgentBase protocol doesn't guarantee it)
node_name = getattr(node, "name", None)
if not node_name:
node_id = f"node_{i}"
node.name = node_id
logger.debug("node_id=<%s> | agent has no name, dynamically generating one", node_id)

node_id = str(node.name)
logger.debug("node_id=<%s> | agent has no name, using generated id", node_id)
else:
node_id = str(node_name)

# Ensure node IDs are unique
if node_id in self.nodes:
raise ValueError(f"Node ID '{node_id}' is not unique. Each agent must have a unique name.")

self.nodes[node_id] = SwarmNode(node_id, node, swarm=self)

# Validate entry point if specified
# Validate entry point if specified (use identity-based lookup to handle nameless AgentBase)
if self.entry_point is not None:
entry_point_node_id = str(self.entry_point.name)
if (
entry_point_node_id not in self.nodes
or self.nodes[entry_point_node_id].executor is not self.entry_point
):
entry_node = None
for swarm_node in self.nodes.values():
if swarm_node.executor is self.entry_point:
entry_node = swarm_node
break
if entry_node is None:
available_agents = [
f"{node_id} ({type(node.executor).__name__})" for node_id, node in self.nodes.items()
]
Expand All @@ -504,7 +522,7 @@ def _setup_swarm(self, nodes: list[Agent]) -> None:
first_node = next(iter(self.nodes.keys()))
logger.debug("entry_point=<%s> | using first node as entry point", first_node)

def _validate_swarm(self, nodes: list[Agent]) -> None:
def _validate_swarm(self, nodes: list[AgentBase]) -> None:
"""Validate swarm structure and nodes."""
# Check for duplicate object instances
seen_instances = set()
Expand All @@ -513,18 +531,31 @@ def _validate_swarm(self, nodes: list[Agent]) -> None:
raise ValueError("Duplicate node instance detected. Each node must have a unique object instance.")
seen_instances.add(id(node))

# Check for session persistence
if node._session_manager is not None:
# Check for session persistence (only Agent has _session_manager attribute)
if isinstance(node, Agent) and node._session_manager is not None:
raise ValueError("Session persistence is not supported for Swarm agents yet.")

def _inject_swarm_tools(self) -> None:
"""Add swarm coordination tools to each agent."""
"""Add swarm coordination tools to each agent.

Note: Only Agent instances can receive swarm tools. AgentBase implementations
without tool_registry will not have handoff capabilities.
"""
# Create tool functions with proper closures
swarm_tools = [
self._create_handoff_tool(),
]

injected_count = 0
for node in self.nodes.values():
# Only Agent (not generic AgentBase) has tool_registry attribute
if not isinstance(node.executor, Agent):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how do you determine handoffs for non-Agent AgentBase instances then?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. Looking at the implementation: non-Agent AgentBase instances currently cannot initiate handoffs because _inject_swarm_tools (line 534-577) skips tool injection for nodes that aren't isinstance(node.executor, Agent). They can only be handed to by Agent nodes that do have the handoff tool.

This is a significant capability gap that should be clearly communicated. The _build_node_input method (line 706) also tells all nodes they have "access to swarm coordination tools" regardless of whether tools were injected, which could confuse LLM-backed AgentBase implementations. I've left a separate comment on that.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They can't. Only native Agent instances get the handoff_to_agent tool injected since AgentBase doesn't guarantee a tool_registry. AgentBase nodes can only be handed to by Agent nodes. This is intentional for the #1720 use case of integrating existing agents from other frameworks as handoff targets.

I also updated the prompt text to be conditional now. Agent nodes see "You have access to swarm coordination tools..." while AgentBase nodes just see "If you complete your task, the swarm will consider the task complete." so the LLM isn't told about tools it doesn't have.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: Non-Agent nodes silently lose key Swarm capabilities with no user feedback.

When an AgentBase (non-Agent) node is skipped for tool injection, the user isn't informed at Swarm creation time that these nodes won't be able to initiate handoffs. The debug log on line 549 is only visible at DEBUG level. A user might expect their custom AgentBase to participate in handoffs and be surprised when it doesn't.

Suggestion:

  1. Add a logger.warning (not just debug) when non-Agent nodes are skipped, or at minimum a clear log after setup summarizing which nodes have handoff capabilities and which don't.
  2. Consider adding a note in the Swarm.__init__ docstring that AgentBase nodes without tool_registry will not have handoff capabilities and can only serve as entry/exit points or handoff targets.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed. _build_node_input prompt is now conditional on whether the node actually received handoff tools (tracked via a _handoff_capable_nodes set). Kept the tool injection skip at debug level since it's expected behavior, not an error.

logger.debug(
"node_id=<%s> | skipping tool injection for non-Agent node",
node.node_id,
)
continue

# Check for existing tools with conflicting names
existing_tools = node.executor.tool_registry.registry
conflicting_tools = []
Expand All @@ -540,11 +571,14 @@ def _inject_swarm_tools(self) -> None:

# Use the agent's tool registry to process and register the tools
node.executor.tool_registry.process_tools(swarm_tools)
self._handoff_capable_nodes.add(node.node_id)
injected_count += 1

logger.debug(
"tool_count=<%d>, node_count=<%d> | injected coordination tools into agents",
"tool_count=<%d>, node_count=<%d>, injected_count=<%d> | injected coordination tools",
len(swarm_tools),
len(self.nodes),
injected_count,
)

def _create_handoff_tool(self) -> Callable[..., Any]:
Expand Down Expand Up @@ -673,10 +707,13 @@ def _build_node_input(self, target_node: SwarmNode) -> str:
context_text += "\n"
context_text += "\n"

context_text += (
"You have access to swarm coordination tools if you need help from other agents. "
"If you don't hand off to another agent, the swarm will consider the task complete."
)
if target_node.node_id in self._handoff_capable_nodes:
context_text += (
"You have access to swarm coordination tools if you need help from other agents. "
"If you don't hand off to another agent, the swarm will consider the task complete."
)
else:
context_text += "If you complete your task, the swarm will consider the task complete."

return context_text

Expand All @@ -696,13 +733,19 @@ def _activate_interrupt(self, node: SwarmNode, interrupts: list[Interrupt]) -> M
logger.debug("node=<%s> | node interrupted", node.node_id)
self.state.completion_status = Status.INTERRUPTED

# Only Agent (not generic AgentBase) has _interrupt_state, state, and messages attributes
self._interrupt_state.context[node.node_id] = {
"activated": node.executor._interrupt_state.activated,
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
"model_state": node.executor._model_state,
"activated": isinstance(node.executor, Agent) and node.executor._interrupt_state.activated,
}
if isinstance(node.executor, Agent):
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Issue: Silent degradation for interrupt state on non-Agent nodes.

When an interrupt occurs on a non-Agent AgentBase node, no executor context is saved (the isinstance(node.executor, Agent) guard prevents it). This means if a user tries to resume from an interrupt on a non-Agent node, the resume behavior is undefined — reset_executor_state will try to use self.swarm._interrupt_state.context[self.node_id] (line 112) but that key won't exist, causing a KeyError.

Suggestion: Either:

  1. Raise a clear error if an interrupt is activated on a non-Agent node explaining that interrupt/resume is not supported for AgentBase implementations, or
  2. Save a minimal context for non-Agent nodes (e.g., just marking the interrupt state without Agent-specific attributes), or
  3. Add a guard in reset_executor_state to handle the case where no context exists for non-Agent nodes during interrupt resume.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed two things:

  1. _activate_interrupt now saves a base context entry for all nodes (matching Graph's pattern), with Agent-specific fields added conditionally.
  2. reset_executor_state has a guard for missing context keys to prevent KeyError on resume.

self._interrupt_state.context[node.node_id].update(
{
"interrupt_state": node.executor._interrupt_state.to_dict(),
"state": node.executor.state.get(),
"messages": node.executor.messages,
"model_state": node.executor._model_state,
}
)

self._interrupt_state.interrupts.update({interrupt.id: interrupt for interrupt in interrupts})
self._interrupt_state.activate()
Expand Down Expand Up @@ -1042,5 +1085,7 @@ def _from_dict(self, payload: dict[str, Any]) -> None:

def _initial_node(self) -> SwarmNode:
if self.entry_point:
return self.nodes[str(self.entry_point.name)]
for node in self.nodes.values():
if node.executor is self.entry_point:
return node
return next(iter(self.nodes.values())) # First SwarmNode
Loading