Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions google/genai/chats.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def send_message(
response = self._modules.generate_content(
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
config=config if config is not None else self._config,
)
model_output = (
[response.candidates[0].content]
Expand Down Expand Up @@ -310,7 +310,7 @@ def send_message_stream(
for chunk in self._modules.generate_content_stream(
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
config=config if config is not None else self._config,
):
if not _validate_response(chunk):
is_valid = False
Expand Down Expand Up @@ -414,7 +414,7 @@ async def send_message(
response = await self._modules.generate_content(
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
config=config if config is not None else self._config,
)
model_output = (
[response.candidates[0].content]
Expand Down Expand Up @@ -473,7 +473,7 @@ async def async_generator(): # type: ignore[no-untyped-def]
async for chunk in await self._modules.generate_content_stream( # type: ignore[attr-defined]
model=self._model,
contents=self._curated_history + [input_content], # type: ignore[arg-type]
config=config if config else self._config,
config=config if config is not None else self._config,
):
if not _validate_response(chunk):
is_valid = False
Expand Down
128 changes: 128 additions & 0 deletions google/genai/tests/chats/test_send_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,17 @@
import json
import os
import sys
from unittest import mock

from pydantic import BaseModel
from pydantic import ValidationError
import pytest

from .. import pytest_helper
from ... import chats
from ... import client as client_module
from ... import errors
from ... import models
from ... import types

try:
Expand All @@ -49,6 +53,30 @@

MODEL_NAME = 'gemini-2.5-flash'


@pytest.fixture
def mock_api_client():
  """Return a MagicMock standing in for client_module.ApiClient.

  Carries just enough surface (api_key, host callable, http options,
  vertexai flag) for models.Models / models.AsyncModels construction.
  """
  fake_client = mock.MagicMock(spec=client_module.ApiClient)
  fake_client.configure_mock(
      api_key='TEST_API_KEY',
      vertexai=False,
      _host=lambda: 'test_host',
      _http_options={'headers': {}},
  )
  return fake_client


def _valid_chat_response() -> types.GenerateContentResponse:
  """Build a minimal valid response: one STOP-finished text candidate."""
  text_part = types.Part.from_text(text='mock response')
  model_content = types.Content(role='model', parts=[text_part])
  candidate = types.Candidate(
      content=model_content,
      finish_reason=types.FinishReason.STOP,
  )
  return types.GenerateContentResponse(candidates=[candidate])

def divide_intergers_with_customized_math_rule(
numerator: int, denominator: int
) -> int:
Expand Down Expand Up @@ -741,6 +769,106 @@ async def test_async_stream_function_calling(client):
}


def test_send_message_preserves_zero_temperature_override(mock_api_client):
  """A falsy-but-set per-call config (temperature 0.0) must not silently
  fall back to the chat session's default config."""
  recorded = []

  def fake_generate_content(*, model, contents, config):
    recorded.append(config)
    return _valid_chat_response()

  patcher = mock.patch.object(
      models.Models, 'generate_content', side_effect=fake_generate_content
  )
  with patcher:
    chat = chats.Chats(modules=models.Models(mock_api_client)).create(
        model=MODEL_NAME, config={'temperature': 1.0}
    )
    for prompt in ('first turn', 'second turn'):
      chat.send_message(prompt, config={'temperature': 0.0})

  # Both turns must have passed the zero-temperature override through.
  assert recorded == [{'temperature': 0.0}] * 2


def test_send_message_stream_preserves_zero_temperature_override(
    mock_api_client,
):
  """Streaming path: temperature 0.0 per call must win over the session
  default rather than being discarded as falsy."""
  recorded = []
  canned_response = _valid_chat_response()

  def fake_generate_content_stream(*, model, contents, config):
    recorded.append(config)
    return iter([canned_response])

  with mock.patch.object(
      models.Models,
      'generate_content_stream',
      side_effect=fake_generate_content_stream,
  ):
    chat = chats.Chats(modules=models.Models(mock_api_client)).create(
        model=MODEL_NAME, config={'temperature': 1.0}
    )
    for prompt in ('first turn', 'second turn'):
      # Drain the stream so the mocked call actually executes.
      for _ in chat.send_message_stream(prompt, config={'temperature': 0.0}):
        pass

  assert recorded == [{'temperature': 0.0}] * 2


@pytest.mark.asyncio
async def test_async_send_message_preserves_zero_temperature_override(
    mock_api_client,
):
  """Async path: a per-call config of temperature 0.0 must override the
  session config instead of falling back to it."""
  recorded = []

  async def fake_generate_content(*, model, contents, config):
    recorded.append(config)
    return _valid_chat_response()

  with mock.patch.object(
      models.AsyncModels, 'generate_content', side_effect=fake_generate_content
  ):
    session = chats.AsyncChats(
        modules=models.AsyncModels(mock_api_client)
    ).create(model=MODEL_NAME, config={'temperature': 1.0})
    for prompt in ('first turn', 'second turn'):
      await session.send_message(prompt, config={'temperature': 0.0})

  assert recorded == [{'temperature': 0.0}] * 2


@pytest.mark.asyncio
async def test_async_send_message_stream_preserves_zero_temperature_override(
    mock_api_client,
):
  """Async streaming path: temperature 0.0 per call must reach the model
  call rather than being replaced by the session default."""
  recorded = []
  canned_response = _valid_chat_response()

  async def fake_generate_content_stream(*, model, contents, config):
    recorded.append(config)

    # send_message_stream awaits the call, then iterates the result,
    # so hand back a fresh async generator per invocation.
    async def one_chunk():
      yield canned_response

    return one_chunk()

  with mock.patch.object(
      models.AsyncModels,
      'generate_content_stream',
      side_effect=fake_generate_content_stream,
  ):
    session = chats.AsyncChats(
        modules=models.AsyncModels(mock_api_client)
    ).create(model=MODEL_NAME, config={'temperature': 1.0})
    for prompt in ('first turn', 'second turn'):
      # Drain each stream so the mocked call is actually invoked.
      async for _ in await session.send_message_stream(
          prompt, config={'temperature': 0.0}
      ):
        pass

  assert recorded == [{'temperature': 0.0}] * 2


@pytest.mark.asyncio
async def test_async_stream_send_2_messages(client):
chat = client.aio.chats.create(model=MODEL_NAME)
Expand Down