From 70e007c5eb56a6334831d6ace78a1d3ef463f471 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Fri, 17 Apr 2026 06:10:29 -0700 Subject: [PATCH] fix: Refactor AG2Agent and ADK templates to use environment variables for project/location. FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/6596 from googleapis:release-please--branches--main b82c8bd94986c311f843f9a81597404ceea4a319 PiperOrigin-RevId: 901257696 --- .../test_reasoning_engine_templates_ag2.py | 44 +++--- vertexai/agent_engines/templates/adk.py | 4 +- vertexai/agent_engines/templates/ag2.py | 59 ++++---- .../reasoning_engines/templates/adk.py | 5 +- .../reasoning_engines/templates/ag2.py | 141 +++++++++--------- 5 files changed, 131 insertions(+), 122 deletions(-) diff --git a/tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py b/tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py index 62145f6c94..42feac90e1 100644 --- a/tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py +++ b/tests/unit/vertex_ag2/test_reasoning_engine_templates_ag2.py @@ -150,11 +150,11 @@ def test_initialization(self): agent = reasoning_engines.AG2Agent( model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME ) - assert agent._model_name == _TEST_MODEL - assert agent._runnable_name == _TEST_RUNNABLE_NAME - assert agent._project == _TEST_PROJECT - assert agent._location == _TEST_LOCATION - assert agent._runnable is None + assert agent._tmpl_attrs["model_name"] == _TEST_MODEL + assert agent._tmpl_attrs["runnable_name"] == _TEST_RUNNABLE_NAME + assert agent._tmpl_attrs["project"] == _TEST_PROJECT + assert agent._tmpl_attrs["location"] == _TEST_LOCATION + assert agent._tmpl_attrs["runnable"] is None def test_initialization_with_tools(self, autogen_tools_mock): tools = [ @@ -168,12 +168,12 @@ def test_initialization_with_tools(self, autogen_tools_mock): tools=tools, runnable_builder=lambda **kwargs: kwargs, ) - assert agent._runnable is None - assert agent._tools - assert not agent._ag2_tool_objects + assert agent._tmpl_attrs["runnable"] is None + assert agent._tmpl_attrs["tools"] + assert not agent._tmpl_attrs["ag2_tool_objects"] agent.set_up() - assert agent._runnable is not None - assert agent._ag2_tool_objects + assert agent._tmpl_attrs["runnable"] is not None + assert agent._tmpl_attrs["ag2_tool_objects"] def test_set_up(self): agent = reasoning_engines.AG2Agent( @@ -181,9 +181,9 @@ def test_set_up(self): runnable_name=_TEST_RUNNABLE_NAME, runnable_builder=lambda **kwargs: kwargs, ) - assert agent._runnable is None + assert agent._tmpl_attrs["runnable"] is None agent.set_up() - assert agent._runnable is not None + assert agent._tmpl_attrs["runnable"] is not None def test_clone(self): agent = reasoning_engines.AG2Agent( @@ -192,26 +192,26 @@ def test_clone(self): runnable_builder=lambda **kwargs: kwargs, ) agent.set_up() - assert agent._runnable is not None + assert agent._tmpl_attrs["runnable"] is not None agent_clone = agent.clone() - assert agent._runnable is not None - assert agent_clone._runnable is None + assert agent._tmpl_attrs["runnable"] is not None + assert agent_clone._tmpl_attrs["runnable"] is None agent_clone.set_up() - assert agent_clone._runnable is not None + assert agent_clone._tmpl_attrs["runnable"] is not None def test_query(self, dataclasses_asdict_mock): agent = reasoning_engines.AG2Agent( model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME, ) - agent._runnable = mock.Mock() + agent._tmpl_attrs["runnable"] = mock.Mock() mocks = mock.Mock() - mocks.attach_mock(mock=agent._runnable, attribute="run") + mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="run") agent.query(input="test query") mocks.assert_has_calls( [ mock.call.run.run( - {"content": "test query"}, + message={"content": "test query"}, user_input=False, tools=[], max_turns=None, @@ -233,10 +233,10 @@ def test_enable_tracing( runnable_name=_TEST_RUNNABLE_NAME, enable_tracing=True, ) - assert agent._instrumentor is None + assert agent._tmpl_attrs["instrumentor"] is None # TODO(b/384730642): Re-enable this test once the parent issue is fixed. # agent.set_up() - # assert agent._instrumentor is not None + # assert agent._tmpl_attrs["instrumentor"] is not None # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text @pytest.mark.usefixtures("caplog") @@ -246,7 +246,7 @@ def test_enable_tracing_warning(self, caplog, autogen_instrumentor_none_mock): runnable_name=_TEST_RUNNABLE_NAME, enable_tracing=True, ) - assert agent._instrumentor is None + assert agent._tmpl_attrs["instrumentor"] is None # TODO(b/384730642): Re-enable this test once the parent issue is fixed. # agent.set_up() # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 2c26b6266b..3e47737a39 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -811,10 +811,10 @@ def set_up(self): # to disable bound token sharing. os.environ["GOOGLE_API_PREVENT_AGENT_TOKEN_SHARING_FOR_GCP_SERVICES"] = "false" # --- END BOUND TOKEN PATCH --- - project = self._tmpl_attrs.get("project") + project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project") if project: os.environ["GOOGLE_CLOUD_PROJECT"] = project - location = self._tmpl_attrs.get("location") + location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location") if location: if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location diff --git a/vertexai/agent_engines/templates/ag2.py b/vertexai/agent_engines/templates/ag2.py index a7261a7d5f..910575d780 100644 --- a/vertexai/agent_engines/templates/ag2.py +++ b/vertexai/agent_engines/templates/ag2.py @@ -23,6 +23,8 @@ Sequence, Union, ) +import os +import copy if TYPE_CHECKING: try: @@ -351,26 +353,10 @@ def __init__( "instrumentor": None, "instrumentor_builder": instrumentor_builder, "enable_tracing": enable_tracing, + "provided_llm_config": copy.deepcopy(llm_config), + "provided_runnable_kwargs": copy.deepcopy(runnable_kwargs), } - self._tmpl_attrs["llm_config"] = llm_config or { - "config_list": [ - { - "project_id": self._tmpl_attrs.get("project"), - "location": self._tmpl_attrs.get("location"), - "model": self._tmpl_attrs.get("model_name"), - "api_type": self._tmpl_attrs.get("api_type"), - } - ] - } - self._tmpl_attrs["runnable_kwargs"] = _prepare_runnable_kwargs( - runnable_kwargs=runnable_kwargs, - llm_config=self._tmpl_attrs.get("llm_config"), - system_instruction=self._tmpl_attrs.get("system_instruction"), - runnable_name=self._tmpl_attrs.get("runnable_name"), - ) if tools: - # We validate tools at initialization for actionable feedback before - # they are deployed. _validate_tools(tools) self._tmpl_attrs["tools"] = tools @@ -378,18 +364,38 @@ def set_up(self): """Sets up the agent for execution of queries at runtime. It initializes the runnable, binds the runnable with tools. - - This method should not be called for an object that being passed to - the ReasoningEngine service for deployment, as it initializes clients - that can not be serialized. + Project and Location are sourced from environment variables. """ + project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project") + location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location") + + llm_config = { + "config_list": [ + { + "project_id": project, + "location": location, + "model": self._tmpl_attrs.get("model_name"), + "api_type": self._tmpl_attrs.get("api_type"), + } + ] + } + if self._tmpl_attrs.get("provided_llm_config"): + llm_config = self._tmpl_attrs.get("provided_llm_config") + + runnable_kwargs = _prepare_runnable_kwargs( + runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"), + llm_config=llm_config, + system_instruction=self._tmpl_attrs.get("system_instruction"), + runnable_name=self._tmpl_attrs.get("runnable_name"), + ) + if self._tmpl_attrs.get("enable_tracing"): instrumentor_builder = ( self._tmpl_attrs.get("instrumentor_builder") or _default_instrumentor_builder ) self._tmpl_attrs["instrumentor"] = instrumentor_builder( - project_id=self._tmpl_attrs.get("project") + project_id=project, ) # Set up tools. @@ -408,21 +414,20 @@ def set_up(self): self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder ) self._tmpl_attrs["runnable"] = runnable_builder( - **self._tmpl_attrs.get("runnable_kwargs") + **runnable_kwargs ) def clone(self) -> "AG2Agent": """Returns a clone of the AG2Agent.""" - import copy return AG2Agent( model=self._tmpl_attrs.get("model_name"), api_type=self._tmpl_attrs.get("api_type"), - llm_config=copy.deepcopy(self._tmpl_attrs.get("llm_config")), + llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")), system_instruction=self._tmpl_attrs.get("system_instruction"), runnable_name=self._tmpl_attrs.get("runnable_name"), tools=copy.deepcopy(self._tmpl_attrs.get("tools")), - runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("runnable_kwargs")), + runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("provided_runnable_kwargs")), runnable_builder=self._tmpl_attrs.get("runnable_builder"), enable_tracing=self._tmpl_attrs.get("enable_tracing"), instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"), diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 4a3a62e3bb..03e4da4727 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -725,9 +725,8 @@ def set_up(self): from google.adk.memory.in_memory_memory_service import InMemoryMemoryService os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" - project = self._tmpl_attrs.get("project") - os.environ["GOOGLE_CLOUD_PROJECT"] = project - location = self._tmpl_attrs.get("location") + project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project") + location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location") if location: if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location diff --git a/vertexai/preview/reasoning_engines/templates/ag2.py b/vertexai/preview/reasoning_engines/templates/ag2.py index 3e194449f9..e7730b0401 100644 --- a/vertexai/preview/reasoning_engines/templates/ag2.py +++ b/vertexai/preview/reasoning_engines/templates/ag2.py @@ -23,6 +23,8 @@ Sequence, Union, ) +import os +import copy if TYPE_CHECKING: try: @@ -250,53 +252,55 @@ def __init__( """ from google.cloud.aiplatform import initializer - # Set up llm config. - self._project = initializer.global_config.project - self._location = initializer.global_config.location - self._model_name = model or "gemini-1.0-pro-001" - self._api_type = api_type or "google" - self._llm_config = llm_config or { - "config_list": [ - { - "project_id": self._project, - "location": self._location, - "model": self._model_name, - "api_type": self._api_type, - } - ] + self._tmpl_attrs: dict[str, Any] = { + "project": initializer.global_config.project, + "location": initializer.global_config.location, + "model_name": model, + "api_type": api_type or "google", + "system_instruction": system_instruction, + "runnable_name": runnable_name, + "tools": [], + "ag2_tool_objects": [], + "runnable": None, + "runnable_builder": runnable_builder, + "instrumentor": None, + "enable_tracing": enable_tracing, + "provided_llm_config": copy.deepcopy(llm_config), + "provided_runnable_kwargs": copy.deepcopy(runnable_kwargs), } - self._system_instruction = system_instruction - self._runnable_name = runnable_name - self._runnable_kwargs = _prepare_runnable_kwargs( - runnable_kwargs=runnable_kwargs, - llm_config=self._llm_config, - system_instruction=self._system_instruction, - runnable_name=self._runnable_name, - ) - - self._tools = [] if tools: - # We validate tools at initialization for actionable feedback before - # they are deployed. _validate_tools(tools) - self._tools = tools - self._ag2_tool_objects = [] - self._runnable = None - self._runnable_builder = runnable_builder - - self._instrumentor = None - self._enable_tracing = enable_tracing + self._tmpl_attrs["tools"] = tools def set_up(self): """Sets up the agent for execution of queries at runtime. It initializes the runnable, binds the runnable with tools. - - This method should not be called for an object that being passed to - the ReasoningEngine service for deployment, as it initializes clients - that can not be serialized. + Project and Location are sourced from environment variables. """ - if self._enable_tracing: + project = os.environ.get("GOOGLE_CLOUD_PROJECT") or self._tmpl_attrs.get("project") + location = os.environ.get("GOOGLE_CLOUD_LOCATION") or self._tmpl_attrs.get("location") + + llm_config = { + "config_list": [ + { + "project_id": project, + "location": location, + "model": self._tmpl_attrs.get("model_name"), + "api_type": self._tmpl_attrs.get("api_type"), + } + ] + } + if self._tmpl_attrs.get("provided_llm_config"): + llm_config = self._tmpl_attrs.get("provided_llm_config") + + runnable_kwargs = _prepare_runnable_kwargs( + runnable_kwargs=self._tmpl_attrs.get("provided_runnable_kwargs"), + llm_config=llm_config, + system_instruction=self._tmpl_attrs.get("system_instruction"), + runnable_name=self._tmpl_attrs.get("runnable_name"), + ) + if self._tmpl_attrs.get("enable_tracing"): from vertexai.reasoning_engines import _utils cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn() @@ -317,9 +321,9 @@ def set_up(self): credentials, _ = google.auth.default() span_exporter = cloud_trace_exporter.CloudTraceSpanExporter( - project_id=self._project, + project_id=project, client=cloud_trace_v2.TraceServiceClient( - credentials=credentials.with_quota_project(self._project), + credentials=credentials.with_quota_project(project), ), ) span_processor: SpanProcessor = ( @@ -381,34 +385,35 @@ def set_up(self): ) # Set up tools. - if self._tools and not self._ag2_tool_objects: + tools = self._tmpl_attrs.get("tools") + ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects") + if tools and not ag2_tool_objects: from vertexai.reasoning_engines import _utils autogen_tools = _utils._import_autogen_tools_or_warn() if autogen_tools: - for tool in self._tools: - self._ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool)) + for tool in tools: + ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool)) - # Set up runnable. - runnable_builder = self._runnable_builder or _default_runnable_builder - self._runnable = runnable_builder( - **self._runnable_kwargs, + runnable_builder = ( + self._tmpl_attrs.get("runnable_builder") or _default_runnable_builder + ) + self._tmpl_attrs["runnable"] = runnable_builder( + **runnable_kwargs ) def clone(self) -> "AG2Agent": """Returns a clone of the AG2Agent.""" - import copy - return AG2Agent( - model=self._model_name, - api_type=self._api_type, - llm_config=copy.deepcopy(self._llm_config), - system_instruction=self._system_instruction, - runnable_name=self._runnable_name, - tools=copy.deepcopy(self._tools), - runnable_kwargs=copy.deepcopy(self._runnable_kwargs), - runnable_builder=self._runnable_builder, - enable_tracing=self._enable_tracing, + model=self._tmpl_attrs.get("model_name"), + api_type=self._tmpl_attrs.get("api_type"), + llm_config=copy.deepcopy(self._tmpl_attrs.get("provided_llm_config")), + system_instruction=self._tmpl_attrs.get("system_instruction"), + runnable_name=self._tmpl_attrs.get("runnable_name"), + tools=copy.deepcopy(self._tmpl_attrs.get("tools")), + runnable_kwargs=copy.deepcopy(self._tmpl_attrs.get("provided_runnable_kwargs")), + runnable_builder=self._tmpl_attrs.get("runnable_builder"), + enable_tracing=self._tmpl_attrs.get("enable_tracing"), ) def query( @@ -456,21 +461,21 @@ def query( ) kwargs.pop("user_input") - if not self._runnable: + if not self._tmpl_attrs.get("runnable"): self.set_up() + response = self._tmpl_attrs.get("runnable").run( + message=input, + user_input=False, + tools=self._tmpl_attrs.get("ag2_tool_objects"), + max_turns=max_turns, + **kwargs, + ) + from vertexai.reasoning_engines import _utils # `.run()` will return a `ChatResult` object, which is a dataclass. # We need to convert it to a JSON-serializable object. # More details of `ChatResult` can be found in # https://docs.ag2.ai/docs/api-reference/autogen/ChatResult. - return _utils.dataclass_to_dict( - self._runnable.run( - input, - user_input=False, - tools=self._ag2_tool_objects, - max_turns=max_turns, - **kwargs, - ) - ) + return _utils.dataclass_to_dict(response)