diff --git a/google/genai/_api_client.py b/google/genai/_api_client.py index bb9865cc4..2f1537014 100644 --- a/google/genai/_api_client.py +++ b/google/genai/_api_client.py @@ -21,6 +21,7 @@ import asyncio from collections.abc import Generator +import contextlib import copy from dataclasses import dataclass import inspect @@ -890,6 +891,16 @@ def __del__(self, _warnings: Any = warnings) -> None: return self._aiohttp_session # type: ignore[return-value] + async def _reset_aiohttp_session(self) -> None: + """Closes the internal aiohttp session so a fresh one can be created.""" + if self._aiohttp_session is None or self._http_options.aiohttp_client: + self._aiohttp_session = None + return + + with contextlib.suppress(Exception): + await self._aiohttp_session.close() + self._aiohttp_session = None + @staticmethod def _ensure_httpx_ssl_ctx( options: HttpOptions, @@ -1388,6 +1399,9 @@ async def _async_request_once( self._async_client_session_request_args = ( self._ensure_aiohttp_ssl_ctx(self._http_options) ) + # Reset the current session before retrying so the mTLS path does not + # reuse a broken AsyncAuthorizedSession or leak an unclosed one. + await self._reset_aiohttp_session() # Instantiate a new session with the updated SSL context. self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment] response = await self._aiohttp_session.request( # type: ignore[union-attr] @@ -1465,6 +1479,9 @@ async def _async_request_once( self._async_client_session_request_args = ( self._ensure_aiohttp_ssl_ctx(self._http_options) ) + # Reset the current session before retrying so the mTLS path does not + # reuse a broken AsyncAuthorizedSession or leak an unclosed one. + await self._reset_aiohttp_session() # Instantiate a new session with the updated SSL context. self._aiohttp_session = await self._get_aiohttp_session() # type: ignore[assignment] response = await self._aiohttp_session.request( # type: ignore[union-attr] diff --git a/google/genai/models.py b/google/genai/models.py index b05985cc9..44ddb5ff3 100644 --- a/google/genai/models.py +++ b/google/genai/models.py @@ -3429,10 +3429,15 @@ def _Part_to_mldev( ) if getv(from_object, ['function_response']) is not None: + function_response = getv(from_object, ['function_response']) + if isinstance(function_response, dict): + function_response = types.FunctionResponse.model_validate( + function_response + ) setv( to_object, ['functionResponse'], - getv(from_object, ['function_response']), + function_response.model_dump(by_alias=True, exclude_none=True, mode='json'), ) if getv(from_object, ['inline_data']) is not None: @@ -3500,10 +3505,15 @@ def _Part_to_vertex( setv(to_object, ['functionCall'], getv(from_object, ['function_call'])) if getv(from_object, ['function_response']) is not None: + function_response = getv(from_object, ['function_response']) + if isinstance(function_response, dict): + function_response = types.FunctionResponse.model_validate( + function_response + ) setv( to_object, ['functionResponse'], - getv(from_object, ['function_response']), + function_response.model_dump(by_alias=True, exclude_none=True, mode='json'), ) if getv(from_object, ['inline_data']) is not None: diff --git a/google/genai/tests/client/test_retries.py b/google/genai/tests/client/test_retries.py index 49af5264b..1791eaa20 100644 --- a/google/genai/tests/client/test_retries.py +++ b/google/genai/tests/client/test_retries.py @@ -17,6 +17,7 @@ import asyncio from collections.abc import Sequence +import contextlib import datetime from unittest import mock import pytest @@ -1361,3 +1362,60 @@ async def run(): mock_request.assert_called() asyncio.run(run()) + + +@requires_aiohttp +@mock.patch.object(AsyncAuthorizedSession, 'close', autospec=True) +@mock.patch.object(AsyncAuthorizedSession, 'request', autospec=True) +def test_aiohttp_retries_client_connector_error_recreates_mtls_session( + mock_request, mock_close +): + api_client.has_aiohttp = True + + async def run(): + connector_error = aiohttp.ClientConnectorError( + connection_key=aiohttp.client_reqrep.ConnectionKey( + 'localhost', 80, False, True, None, None, None + ), + os_error=OSError, + ) + res200 = await _aiohttp_async_response(200) + mock_auth_res200 = mock.Mock(spec=AsyncAuthorizedSessionResponse) + mock_auth_res200._response = res200 + mock_request.side_effect = [connector_error, mock_auth_res200] + + sessions = [] + original_get_session = api_client.BaseApiClient._get_aiohttp_session + + async def tracking_get_session(self): + session = await original_get_session(self) + if session not in sessions: + sessions.append(session) + return session + + client = api_client.BaseApiClient( + vertexai=True, + project='test_project', + location='global', + ) + + with mock.patch.object( + api_client.BaseApiClient, + '_get_aiohttp_session', + new=tracking_get_session, + ): + with mock.patch( + 'google.auth.transport.mtls.should_use_client_cert', return_value=True + ): + with _patch_auth_default(): + response = await client.async_request( + http_method='GET', path='path', request_dict={} + ) + + assert response.headers['status-code'] == '200' + assert len(sessions) == 2 + mock_close.assert_awaited_once_with(sessions[0]) + with contextlib.suppress(Exception): + await client.aclose() + + asyncio.run(run()) diff --git a/google/genai/tests/types/test_types.py b/google/genai/tests/types/test_types.py index 099ff2a2d..4e636950b 100644 --- a/google/genai/tests/types/test_types.py +++ b/google/genai/tests/types/test_types.py @@ -22,6 +22,7 @@ import PIL.Image import pydantic import pytest +from ... import models from ... import types _is_mcp_imported = False @@ -115,6 +116,28 @@ def test_factory_method_part_from_function_response_with_multi_modal_parts(): assert isinstance(my_part, SubPart) +def test_part_to_mldev_serializes_function_response_display_name(): + my_part = SubPart.from_function_response( + name='get_image', + response={'image_ref': {'$ref': 'instrument.jpg'}}, + parts=[ + { + 'inline_data': { + 'data': b'123', + 'mime_type': 'image/jpeg', + 'display_name': 'instrument.jpg', + } + } + ], + ) + + serialized = models._Part_to_mldev(my_part) + + assert serialized['functionResponse']['parts'][0]['inlineData'][ + 'displayName' + ] == 'instrument.jpg' + + def test_factory_method_function_response_part_from_bytes(): my_part = SubFunctionResponsePart.from_bytes( data=b'123', mime_type='image/png'