diff --git a/google/genai/chats.py b/google/genai/chats.py index 3d5e181e9..3aec1ffb9 100644 --- a/google/genai/chats.py +++ b/google/genai/chats.py @@ -252,7 +252,7 @@ def send_message( response = self._modules.generate_content( model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, + config=config if config is not None else self._config, ) model_output = ( [response.candidates[0].content] @@ -310,7 +310,7 @@ def send_message_stream( for chunk in self._modules.generate_content_stream( model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, + config=config if config is not None else self._config, ): if not _validate_response(chunk): is_valid = False @@ -414,7 +414,7 @@ async def send_message( response = await self._modules.generate_content( model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, + config=config if config is not None else self._config, ) model_output = ( [response.candidates[0].content] @@ -473,7 +473,7 @@ async def async_generator(): # type: ignore[no-untyped-def] async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined] model=self._model, contents=self._curated_history + [input_content], # type: ignore[arg-type] - config=config if config else self._config, + config=config if config is not None else self._config, ): if not _validate_response(chunk): is_valid = False diff --git a/google/genai/tests/chats/test_send_message.py b/google/genai/tests/chats/test_send_message.py index cee6bfe3c..0c4039d7e 100644 --- a/google/genai/tests/chats/test_send_message.py +++ b/google/genai/tests/chats/test_send_message.py @@ -16,13 +16,17 @@ import json import os import sys +from unittest import mock from pydantic import BaseModel from pydantic import ValidationError import pytest from .. import pytest_helper +from ... import chats +from ... import client as client_module from ... import errors +from ... import models from ... import types try: @@ -49,6 +53,30 @@ MODEL_NAME = 'gemini-2.5-flash' + +@pytest.fixture +def mock_api_client(): + api_client = mock.MagicMock(spec=client_module.ApiClient) + api_client.api_key = 'TEST_API_KEY' + api_client._host = lambda: 'test_host' + api_client._http_options = {'headers': {}} + api_client.vertexai = False + return api_client + + +def _valid_chat_response() -> types.GenerateContentResponse: + return types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=types.Content( + role='model', + parts=[types.Part.from_text(text='mock response')], + ), + finish_reason=types.FinishReason.STOP, + ) + ] + ) + def divide_intergers_with_customized_math_rule( numerator: int, denominator: int ) -> int: @@ -741,6 +769,106 @@ async def test_async_stream_function_calling(client): } +def test_send_message_preserves_zero_temperature_override(mock_api_client): + seen_configs = [] + + def mock_generate_content(*, model, contents, config): + seen_configs.append(config) + return _valid_chat_response() + + with mock.patch.object( + models.Models, 'generate_content', side_effect=mock_generate_content + ): + chat = chats.Chats(modules=models.Models(mock_api_client)).create( + model=MODEL_NAME, config={'temperature': 1.0} + ) + chat.send_message('first turn', config={'temperature': 0.0}) + chat.send_message('second turn', config={'temperature': 0.0}) + + assert seen_configs == [{'temperature': 0.0}, {'temperature': 0.0}] + + +def test_send_message_stream_preserves_zero_temperature_override( + mock_api_client, +): + seen_configs = [] + response = _valid_chat_response() + + def mock_generate_content_stream(*, model, contents, config): + seen_configs.append(config) + return [response] + + with mock.patch.object( + models.Models, + 'generate_content_stream', + side_effect=mock_generate_content_stream, + ): + chat = chats.Chats(modules=models.Models(mock_api_client)).create( + model=MODEL_NAME, config={'temperature': 1.0} + ) + list(chat.send_message_stream('first turn', config={'temperature': 0.0})) + list(chat.send_message_stream('second turn', config={'temperature': 0.0})) + + assert seen_configs == [{'temperature': 0.0}, {'temperature': 0.0}] + + +@pytest.mark.asyncio +async def test_async_send_message_preserves_zero_temperature_override( + mock_api_client, +): + seen_configs = [] + + async def mock_generate_content(*, model, contents, config): + seen_configs.append(config) + return _valid_chat_response() + + with mock.patch.object( + models.AsyncModels, 'generate_content', side_effect=mock_generate_content + ): + chat = chats.AsyncChats(modules=models.AsyncModels(mock_api_client)).create( + model=MODEL_NAME, config={'temperature': 1.0} + ) + await chat.send_message('first turn', config={'temperature': 0.0}) + await chat.send_message('second turn', config={'temperature': 0.0}) + + assert seen_configs == [{'temperature': 0.0}, {'temperature': 0.0}] + + +@pytest.mark.asyncio +async def test_async_send_message_stream_preserves_zero_temperature_override( + mock_api_client, +): + seen_configs = [] + response = _valid_chat_response() + + async def mock_generate_content_stream(*, model, contents, config): + seen_configs.append(config) + + async def iterator(): + yield response + + return iterator() + + with mock.patch.object( + models.AsyncModels, + 'generate_content_stream', + side_effect=mock_generate_content_stream, + ): + chat = chats.AsyncChats(modules=models.AsyncModels(mock_api_client)).create( + model=MODEL_NAME, config={'temperature': 1.0} + ) + async for _ in await chat.send_message_stream( + 'first turn', config={'temperature': 0.0} + ): + pass + async for _ in await chat.send_message_stream( + 'second turn', config={'temperature': 0.0} + ): + pass + + assert seen_configs == [{'temperature': 0.0}, {'temperature': 0.0}] + + @pytest.mark.asyncio async def test_async_stream_send_2_messages(client): chat = client.aio.chats.create(model=MODEL_NAME)