Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def test_initialization(self):
agent = reasoning_engines.AG2Agent(
model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME
)
assert agent._model_name == _TEST_MODEL
assert agent._runnable_name == _TEST_RUNNABLE_NAME
assert agent._project == _TEST_PROJECT
assert agent._location == _TEST_LOCATION
assert agent._runnable is None
assert agent._tmpl_attrs["model_name"] == _TEST_MODEL
assert agent._tmpl_attrs["runnable_name"] == _TEST_RUNNABLE_NAME
assert agent._tmpl_attrs["project"] == _TEST_PROJECT
assert agent._tmpl_attrs["location"] == _TEST_LOCATION
assert agent._tmpl_attrs["runnable"] is None

def test_initialization_with_tools(self, autogen_tools_mock):
tools = [
Expand All @@ -168,22 +168,22 @@ def test_initialization_with_tools(self, autogen_tools_mock):
tools=tools,
runnable_builder=lambda **kwargs: kwargs,
)
assert agent._runnable is None
assert agent._tools
assert not agent._ag2_tool_objects
assert agent._tmpl_attrs["runnable"] is None
assert agent._tmpl_attrs["tools"]
assert not agent._tmpl_attrs["ag2_tool_objects"]
agent.set_up()
assert agent._runnable is not None
assert agent._ag2_tool_objects
assert agent._tmpl_attrs["runnable"] is not None
assert agent._tmpl_attrs["ag2_tool_objects"]

def test_set_up(self):
agent = reasoning_engines.AG2Agent(
model=_TEST_MODEL,
runnable_name=_TEST_RUNNABLE_NAME,
runnable_builder=lambda **kwargs: kwargs,
)
assert agent._runnable is None
assert agent._tmpl_attrs["runnable"] is None
agent.set_up()
assert agent._runnable is not None
assert agent._tmpl_attrs["runnable"] is not None

def test_clone(self):
agent = reasoning_engines.AG2Agent(
Expand All @@ -192,26 +192,26 @@ def test_clone(self):
runnable_builder=lambda **kwargs: kwargs,
)
agent.set_up()
assert agent._runnable is not None
assert agent._tmpl_attrs["runnable"] is not None
agent_clone = agent.clone()
assert agent._runnable is not None
assert agent_clone._runnable is None
assert agent._tmpl_attrs["runnable"] is not None
assert agent_clone._tmpl_attrs["runnable"] is None
agent_clone.set_up()
assert agent_clone._runnable is not None
assert agent_clone._tmpl_attrs["runnable"] is not None

def test_query(self, dataclasses_asdict_mock):
agent = reasoning_engines.AG2Agent(
model=_TEST_MODEL,
runnable_name=_TEST_RUNNABLE_NAME,
)
agent._runnable = mock.Mock()
agent._tmpl_attrs["runnable"] = mock.Mock()
mocks = mock.Mock()
mocks.attach_mock(mock=agent._runnable, attribute="run")
mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="run")
agent.query(input="test query")
mocks.assert_has_calls(
[
mock.call.run.run(
{"content": "test query"},
message={"content": "test query"},
user_input=False,
tools=[],
max_turns=None,
Expand All @@ -233,10 +233,10 @@ def test_enable_tracing(
runnable_name=_TEST_RUNNABLE_NAME,
enable_tracing=True,
)
assert agent._instrumentor is None
assert agent._tmpl_attrs["instrumentor"] is None
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
# agent.set_up()
# assert agent._instrumentor is not None
# assert agent._tmpl_attrs["instrumentor"] is not None
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text

@pytest.mark.usefixtures("caplog")
Expand All @@ -246,7 +246,7 @@ def test_enable_tracing_warning(self, caplog, autogen_instrumentor_none_mock):
runnable_name=_TEST_RUNNABLE_NAME,
enable_tracing=True,
)
assert agent._instrumentor is None
assert agent._tmpl_attrs["instrumentor"] is None
# TODO(b/384730642): Re-enable this test once the parent issue is fixed.
# agent.set_up()
# assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text
Expand Down
4 changes: 2 additions & 2 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -811,10 +811,10 @@ def set_up(self):
# to disable bound token sharing.
os.environ["GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"] = "false"
# --- END BOUND TOKEN PATCH ---
project = self._tmpl_attrs.get("project")
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
if project:
os.environ["GOOGLE_CLOUD_PROJECT"] = project
location = self._tmpl_attrs.get("location")
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
if location:
if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location
Expand Down
59 changes: 32 additions & 27 deletions vertexai/agent_engines/templates/ag2.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
Sequence,
Union,
)
import os
import copy

if TYPE_CHECKING:
try:
Expand Down Expand Up @@ -351,45 +353,49 @@ def __init__(
"instrumentor": None,
"instrumentor_builder": instrumentor_builder,
"enable_tracing": enable_tracing,
"provided_llm_config": copy.deepcopy(llm_config),
"provided_runnable_kwargs": copy.deepcopy(runnable_kwargs),
}
self._tmpl_attrs["llm_config"] = llm_config or {
"config_list": [
{
"project_id": self._tmpl_attrs.get("project"),
"location": self._tmpl_attrs.get("location"),
"model": self._tmpl_attrs.get("model_name"),
"api_type": self._tmpl_attrs.get("api_type"),
}
]
}
self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs(
runnable_kwargs=runnable_kwargs,
llm_config=self._tmpl_attrs.get("llm_config"),
system_instruction=self._tmpl_attrs.get("system_instruction"),
runnable_name=self._tmpl_attrs.get("runnable_name"),
)
if tools:
# We validate tools at initialization for actionable feedback before
# they are deployed.
_validate_tools(tools)
self._tmpl_attrs["tools"] = tools

def set_up(self):
"""Sets up the agent for execution of queries at runtime.

It initializes the runnable, binds the runnable with tools.

This method should not be called for an object that being passed to
the ReasoningEngine service for deployment, as it initializes clients
that can not be serialized.
Project and Location are sourced from environment variables.
"""
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")

llm_config = {
"config_list": [
{
"project_id": project,
"location": location,
"model": self._tmpl_attrs.get("model_name"),
"api_type": self._tmpl_attrs.get("api_type"),
}
]
}
if self._tmpl_attrs.get("provided_llm_config"):
llm_config = self._tmpl_attrs.get("provided_llm_config")

runnable_kwargs = _prepare_runnable_kwargs(
runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"),
llm_config=llm_config,
system_instruction=self._tmpl_attrs.get("system_instruction"),
runnable_name=self._tmpl_attrs.get("runnable_name"),
)

if self._tmpl_attrs.get("enable_tracing"):
instrumentor_builder = (
self._tmpl_attrs.get("instrumentor_builder")
or _default_instrumentor_builder
)
self._tmpl_attrs["instrumentor"] = instrumentor_builder(
project_id=self._tmpl_attrs.get("project")
project_id=project,
)

# Set up tools.
Expand All @@ -408,21 +414,20 @@ def set_up(self):
self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder
)
self._tmpl_attrs["runnable"] = runnable_builder(
**self._tmpl_attrs.get("runnable_kwargs")
**runnable_kwargs
)

def clone(self) -> "AG2Agent":
"""Returns a clone of the AG2Agent."""
import copy

return AG2Agent(
model=self._tmpl_attrs.get("model_name"),
api_type=self._tmpl_attrs.get("api_type"),
llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")),
llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")),
system_instruction=self._tmpl_attrs.get("system_instruction"),
runnable_name=self._tmpl_attrs.get("runnable_name"),
tools=copy.deepcopy(self._tmpl_attrs.get("tools")),
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")),
runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("provided_runnable_kwargs")),
runnable_builder=self._tmpl_attrs.get("runnable_builder"),
enable_tracing=self._tmpl_attrs.get("enable_tracing"),
instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"),
Expand Down
5 changes: 2 additions & 3 deletions vertexai/preview/reasoning_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -725,9 +725,8 @@ def set_up(self):
from google.adk.memory.in_memory_memory_service import InMemoryMemoryService

os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1"
project = self._tmpl_attrs.get("project")
os.environ["GOOGLE_CLOUD_PROJECT"] = project
location = self._tmpl_attrs.get("location")
project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project")
location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location")
if location:
if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ:
os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location
Expand Down
Loading
Loading