From 25c8e92df1d0ed99c23ab4db7eaadd3457fcc119 Mon Sep 17 00:00:00 2001 From: mukunda katta Date: Sun, 19 Apr 2026 17:31:39 -0700 Subject: [PATCH] fix(live): keep receive open across turns --- google/genai/live.py | 3 -- google/genai/tests/live/test_live_response.py | 37 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) diff --git a/google/genai/live.py b/google/genai/live.py index b35e918ce..526f38ed7 100644 --- a/google/genai/live.py +++ b/google/genai/live.py @@ -454,9 +454,6 @@ async def receive(self) -> AsyncIterator[types.LiveServerMessage]: """ # TODO(b/365983264) Handle intermittent issues for the user. while result := await self._receive(): - if result.server_content and result.server_content.turn_complete: - yield result - break yield result async def start_stream( diff --git a/google/genai/tests/live/test_live_response.py b/google/genai/tests/live/test_live_response.py index ab93b24c6..54643a4aa 100644 --- a/google/genai/tests/live/test_live_response.py +++ b/google/genai/tests/live/test_live_response.py @@ -161,3 +161,40 @@ async def test_receive_server_content_with_turn_reason(mock_websocket, vertexai) assert result.server_content.turn_complete is True assert result.server_content.turn_complete_reason == types.TurnCompleteReason.NEED_MORE_INPUT assert result.server_content.waiting_for_input is True + + +@pytest.mark.parametrize('vertexai', [True, False]) +@pytest.mark.asyncio +async def test_receive_continues_after_turn_complete(mock_websocket, vertexai): + session = live.AsyncSession( + api_client=mock_api_client(vertexai=vertexai), websocket=mock_websocket + ) + + first_turn_complete = types.LiveServerMessage( + server_content=types.LiveServerContent(turn_complete=True) + ) + second_turn_message = types.LiveServerMessage( + server_content=types.LiveServerContent( + model_turn=types.Content(parts=[types.Part(text='second turn')]) + ) + ) + second_turn_complete = types.LiveServerMessage( + server_content=types.LiveServerContent(turn_complete=True) + ) + + session._receive = AsyncMock( + side_effect=[ + first_turn_complete, + second_turn_message, + second_turn_complete, + None, + ] + ) + + messages = [message async for message in session.receive()] + + assert messages == [ + first_turn_complete, + second_turn_message, + second_turn_complete, + ]