Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
145 changes: 108 additions & 37 deletions google/genai/_replay_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import json
import os
import re
import threading
from typing import Any, Literal, Optional, Union, Iterator, AsyncIterator

import google.auth
Expand Down Expand Up @@ -316,6 +317,9 @@ def __init__(
self._mode = mode
self._replay_id = replay_id
self._private = private
self._lock = threading.Lock()
self._thread_local = threading.local()
self._used_interactions = set()

def initialize_replay_session(self, replay_id: str) -> None:
self._replay_id = replay_id
Expand All @@ -342,6 +346,8 @@ def _initialize_replay_session_if_not_loaded(self) -> None:
def _initialize_replay_session(self) -> None:
_debug_print('Test is using replay id: ' + self._replay_id)
self._replay_index = 0
with self._lock:
self._used_interactions.clear()
self._sdk_response_index = 0
replay_file_path = self._get_replay_file_path()
# This should not be triggered from the constructor.
Expand Down Expand Up @@ -427,24 +433,38 @@ def _record_interaction(
)
if self.replay_session is None:
raise ValueError('No replay session found.')
self.replay_session.interactions.append(
ReplayInteraction(request=request, response=response)
)
with self._lock:
self.replay_session.interactions.append(
ReplayInteraction(request=request, response=response)
)

def _match_request(
self,
http_request: HttpRequest,
interaction: ReplayInteraction,
) -> None:
silent: bool = False,
) -> bool:
_debug_print(f'http_request.url: {http_request.url}')
_debug_print(f'interaction.request.url: {interaction.request.url}')
assert http_request.url == interaction.request.url
assert http_request.headers == interaction.request.headers, (
'Request headers mismatch:\n'
f'Actual: {http_request.headers}\n'
f'Expected: {interaction.request.headers}'
)
assert http_request.method == interaction.request.method

if http_request.url != interaction.request.url:
if not silent:
assert http_request.url == interaction.request.url
return False

if http_request.headers != interaction.request.headers:
if not silent:
assert http_request.headers == interaction.request.headers, (
'Request headers mismatch:\n'
f'Actual: {http_request.headers}\n'
f'Expected: {interaction.request.headers}'
)
return False

if http_request.method != interaction.request.method:
if not silent:
assert http_request.method == interaction.request.method
return False

# Sanitize the request body, rewrite any fields that vary.
request_data_copy = copy.deepcopy(http_request.data)
Expand All @@ -458,30 +478,82 @@ def _match_request(
for segment in expected_request_body:
if not isinstance(segment, bytes):
_redact_request_body(segment)
assert _equals_ignore_key_case(actual_request_body, expected_request_body), (
'Request body mismatch:\n'
f'Actual: {actual_request_body}\n'
f'Expected: {expected_request_body}'
)

if not _equals_ignore_key_case(actual_request_body, expected_request_body):
if not silent:
assert _equals_ignore_key_case(
actual_request_body, expected_request_body
), (
'Request body mismatch:\n'
f'Actual: {actual_request_body}\n'
f'Expected: {expected_request_body}'
)
return False

return True

def _build_response_from_replay(self, http_request: HttpRequest) -> HttpResponse:
redact_http_request(http_request)

if self.replay_session is None:
raise ValueError('No replay session found.')
interaction = self.replay_session.interactions[self._replay_index]
# Replay is on the right side of the assert so the diff makes more sense.
self._match_request(http_request, interaction)
self._replay_index += 1
self._sdk_response_index = 0
errors.APIError.raise_for_response(interaction.response)

with self._lock:
found_interaction = None
found_index = -1
total_interactions = len(self.replay_session.interactions)

# Search for a matching interaction that hasn't been used yet.
# Start searching from _replay_index to maintain order when possible.
search_order = list(range(self._replay_index, total_interactions)) + list(
range(0, self._replay_index)
)

for i in search_order:
if i in self._used_interactions:
continue
interaction = self.replay_session.interactions[i]
if self._match_request(http_request, interaction, silent=True):
found_interaction = interaction
found_index = i
break

if not found_interaction:
# If no match found, trigger the normal failure by matching against
# the next expected interaction in the sequence.
target_index = (
self._replay_index
if self._replay_index < total_interactions
else 0
)
self._match_request(
http_request, self.replay_session.interactions[target_index]
)
# Should not reach here if _match_request properly asserts.
raise AssertionError(
'Request did not match any available interaction.'
)

self._used_interactions.add(found_index)
# Advance _replay_index if we just used it.
if found_index == self._replay_index:
while (
self._replay_index < total_interactions
and self._replay_index in self._used_interactions
):
self._replay_index += 1

self._thread_local.current_interaction = found_interaction
self._thread_local.sdk_response_index = 0

errors.APIError.raise_for_response(found_interaction.response)
http_response = HttpResponse(
headers=interaction.response.headers,
headers=found_interaction.response.headers,
response_stream=[
json.dumps(segment)
for segment in interaction.response.body_segments
for segment in found_interaction.response.body_segments
],
byte_stream=interaction.response.byte_segments,
byte_stream=found_interaction.response.byte_segments,
)
if http_response.response_stream == ['{}']:
http_response.response_stream = [""]
Expand All @@ -492,18 +564,19 @@ def _verify_response(self, response_model: BaseModel) -> None:
return
if not self.replay_session:
raise ValueError('No replay session found.')
# replay_index is advanced in _build_response_from_replay, so we need to -1.
interaction = self.replay_session.interactions[self._replay_index - 1]

interaction = getattr(self._thread_local, 'current_interaction', None)
if not interaction:
# Fallback to the old behavior if not set on thread-local (unlikely).
interaction = self.replay_session.interactions[self._replay_index - 1]

sdk_response_index = getattr(self._thread_local, 'sdk_response_index', 0)
if self._should_update_replay():
if isinstance(response_model, list):
response_model = response_model[0]
sdk_response_response = getattr(response_model, 'sdk_http_response', None)
if response_model and (
sdk_response_response is not None
):
headers = getattr(
sdk_response_response, 'headers', None
)
if response_model and (sdk_response_response is not None):
headers = getattr(sdk_response_response, 'headers', None)
if headers:
pop_undeterministic_headers(headers)
interaction.response.sdk_response_segments.append(
Expand All @@ -517,9 +590,7 @@ def _verify_response(self, response_model: BaseModel) -> None:
f'response_model: {response_model.model_dump(exclude_none=True)}'
)
actual = response_model.model_dump(exclude_none=True, mode='json')
expected = interaction.response.sdk_response_segments[
self._sdk_response_index
]
expected = interaction.response.sdk_response_segments[sdk_response_index]
# The sdk_http_response.body has format in the string, need to get rid of
# the format information before comparing.
if isinstance(expected, dict):
Expand All @@ -540,7 +611,7 @@ def _verify_response(self, response_model: BaseModel) -> None:
), f'SDK response mismatch:\nActual: {actual}\nExpected: {expected}'
else:
_debug_print(f'Expected SDK response mismatch:\nActual: {actual}\nExpected: {expected}')
self._sdk_response_index += 1
self._thread_local.sdk_response_index = sdk_response_index + 1

def _request(
self,
Expand Down
Loading