From 47e32d6385dd02268ba52b2c9148a000ae5fba29 Mon Sep 17 00:00:00 2001 From: Jason Dai Date: Thu, 16 Apr 2026 17:21:50 -0700 Subject: [PATCH] chore: GenAI Client - Add replay tests for 17 RubricMetrics in evals SDK PiperOrigin-RevId: 900984771 --- google/genai/_replay_api_client.py | 145 +++++++++++++++++++++-------- 1 file changed, 108 insertions(+), 37 deletions(-) diff --git a/google/genai/_replay_api_client.py b/google/genai/_replay_api_client.py index 263cf2ac8..2add5eb35 100644 --- a/google/genai/_replay_api_client.py +++ b/google/genai/_replay_api_client.py @@ -24,6 +24,7 @@ import json import os import re +import threading from typing import Any, Literal, Optional, Union, Iterator, AsyncIterator import google.auth @@ -316,6 +317,9 @@ def __init__( self._mode = mode self._replay_id = replay_id self._private = private + self._lock = threading.Lock() + self._thread_local = threading.local() + self._used_interactions = set() def initialize_replay_session(self, replay_id: str) -> None: self._replay_id = replay_id @@ -342,6 +346,8 @@ def _initialize_replay_session_if_not_loaded(self) -> None: def _initialize_replay_session(self) -> None: _debug_print('Test is using replay id: ' + self._replay_id) self._replay_index = 0 + with self._lock: + self._used_interactions.clear() self._sdk_response_index = 0 replay_file_path = self._get_replay_file_path() # This should not be triggered from the constructor. @@ -427,24 +433,38 @@ def _record_interaction( ) if self.replay_session is None: raise ValueError('No replay session found.') - self.replay_session.interactions.append( - ReplayInteraction(request=request, response=response) - ) + with self._lock: + self.replay_session.interactions.append( + ReplayInteraction(request=request, response=response) + ) def _match_request( self, http_request: HttpRequest, interaction: ReplayInteraction, - ) -> None: + silent: bool = False, + ) -> bool: _debug_print(f'http_request.url: {http_request.url}') _debug_print(f'interaction.request.url: {interaction.request.url}') - assert http_request.url == interaction.request.url - assert http_request.headers == interaction.request.headers, ( - 'Request headers mismatch:\n' - f'Actual: {http_request.headers}\n' - f'Expected: {interaction.request.headers}' - ) - assert http_request.method == interaction.request.method + + if http_request.url != interaction.request.url: + if not silent: + assert http_request.url == interaction.request.url + return False + + if http_request.headers != interaction.request.headers: + if not silent: + assert http_request.headers == interaction.request.headers, ( + 'Request headers mismatch:\n' + f'Actual: {http_request.headers}\n' + f'Expected: {interaction.request.headers}' + ) + return False + + if http_request.method != interaction.request.method: + if not silent: + assert http_request.method == interaction.request.method + return False # Sanitize the request body, rewrite any fields that vary. request_data_copy = copy.deepcopy(http_request.data) @@ -458,30 +478,82 @@ def _match_request( for segment in expected_request_body: if not isinstance(segment, bytes): _redact_request_body(segment) - assert _equals_ignore_key_case(actual_request_body, expected_request_body), ( - 'Request body mismatch:\n' - f'Actual: {actual_request_body}\n' - f'Expected: {expected_request_body}' - ) + + if not _equals_ignore_key_case(actual_request_body, expected_request_body): + if not silent: + assert _equals_ignore_key_case( + actual_request_body, expected_request_body + ), ( + 'Request body mismatch:\n' + f'Actual: {actual_request_body}\n' + f'Expected: {expected_request_body}' + ) + return False + + return True def _build_response_from_replay(self, http_request: HttpRequest) -> HttpResponse: redact_http_request(http_request) if self.replay_session is None: raise ValueError('No replay session found.') - interaction = self.replay_session.interactions[self._replay_index] - # Replay is on the right side of the assert so the diff makes more sense. - self._match_request(http_request, interaction) - self._replay_index += 1 - self._sdk_response_index = 0 - errors.APIError.raise_for_response(interaction.response) + + with self._lock: + found_interaction = None + found_index = -1 + total_interactions = len(self.replay_session.interactions) + + # Search for a matching interaction that hasn't been used yet. + # Start searching from _replay_index to maintain order when possible. + search_order = list(range(self._replay_index, total_interactions)) + list( + range(0, self._replay_index) + ) + + for i in search_order: + if i in self._used_interactions: + continue + interaction = self.replay_session.interactions[i] + if self._match_request(http_request, interaction, silent=True): + found_interaction = interaction + found_index = i + break + + if not found_interaction: + # If no match found, trigger the normal failure by matching against + # the next expected interaction in the sequence. + target_index = ( + self._replay_index + if self._replay_index < total_interactions + else 0 + ) + self._match_request( + http_request, self.replay_session.interactions[target_index] + ) + # Should not reach here if _match_request properly asserts. + raise AssertionError( + 'Request did not match any available interaction.' + ) + + self._used_interactions.add(found_index) + # Advance _replay_index if we just used it. + if found_index == self._replay_index: + while ( + self._replay_index < total_interactions + and self._replay_index in self._used_interactions + ): + self._replay_index += 1 + + self._thread_local.current_interaction = found_interaction + self._thread_local.sdk_response_index = 0 + + errors.APIError.raise_for_response(found_interaction.response) http_response = HttpResponse( - headers=interaction.response.headers, + headers=found_interaction.response.headers, response_stream=[ json.dumps(segment) - for segment in interaction.response.body_segments + for segment in found_interaction.response.body_segments ], - byte_stream=interaction.response.byte_segments, + byte_stream=found_interaction.response.byte_segments, ) if http_response.response_stream == ['{}']: http_response.response_stream = [""] @@ -492,18 +564,19 @@ def _verify_response(self, response_model: BaseModel) -> None: return if not self.replay_session: raise ValueError('No replay session found.') - # replay_index is advanced in _build_response_from_replay, so we need to -1. - interaction = self.replay_session.interactions[self._replay_index - 1] + + interaction = getattr(self._thread_local, 'current_interaction', None) + if not interaction: + # Fallback to the old behavior if not set on thread-local (unlikely). + interaction = self.replay_session.interactions[self._replay_index - 1] + + sdk_response_index = getattr(self._thread_local, 'sdk_response_index', 0) if self._should_update_replay(): if isinstance(response_model, list): response_model = response_model[0] sdk_response_response = getattr(response_model, 'sdk_http_response', None) - if response_model and ( - sdk_response_response is not None - ): - headers = getattr( - sdk_response_response, 'headers', None - ) + if response_model and (sdk_response_response is not None): + headers = getattr(sdk_response_response, 'headers', None) if headers: pop_undeterministic_headers(headers) interaction.response.sdk_response_segments.append( @@ -517,9 +590,7 @@ def _verify_response(self, response_model: BaseModel) -> None: f'response_model: {response_model.model_dump(exclude_none=True)}' ) actual = response_model.model_dump(exclude_none=True, mode='json') - expected = interaction.response.sdk_response_segments[ - self._sdk_response_index - ] + expected = interaction.response.sdk_response_segments[sdk_response_index] # The sdk_http_response.body has format in the string, need to get rid of # the format information before comparing. if isinstance(expected, dict): @@ -540,7 +611,7 @@ def _verify_response(self, response_model: BaseModel) -> None: ), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}' else: _debug_print(f'Expected SDK response mismatch:\nActual: {actual}\nExpected: {expected}') - self._sdk_response_index += 1 + self._thread_local.sdk_response_index = sdk_response_index + 1 def _request( self,