From 12d2f9ae35bdc3428de18c3d2947a3ee8c7b651e Mon Sep 17 00:00:00 2001 From: noahpodgurski Date: Mon, 6 Apr 2026 14:53:07 -0400 Subject: [PATCH 1/5] feat: implement exponential backoff retry strategy for transient errors --- pyproject.toml | 1 + src/mlpa/core/completions.py | 186 +++++++++++------- src/mlpa/core/utils.py | 136 ++++++++++++- src/tests/integration/test_user_signup_cap.py | 9 + src/tests/unit/test_completions.py | 73 +++++-- 5 files changed, 305 insertions(+), 100 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 154d940..48d22a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ dependencies = [ "sentry-sdk[fastapi]==2.42.0", "sqlalchemy==2.0.44", "tabulate==0.9.0", + "tenacity==8.5.0", "tiktoken==0.11.0", "uvicorn==0.35.0", "loguru==0.7.3", diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 618014e..9c1deba 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -3,6 +3,7 @@ import httpx from fastapi import HTTPException +from tenacity import retry, stop_after_attempt, wait_exponential from mlpa.core.classes import AuthorizedChatRequest, LitellmRoutingSnapshot from mlpa.core.config import ( @@ -26,7 +27,10 @@ get_or_create_user, is_context_window_error, is_rate_limit_error, + litellm_request, + log_litellm_retry_attempt, raise_and_log, + should_retry_on_litellm_error, ) _RATE_LIMIT_REJECTION: dict[int, tuple[PrometheusRejectionReason, str]] = { @@ -38,6 +42,38 @@ } +@retry( + wait=wait_exponential(multiplier=1, min=1, max=4), + stop=stop_after_attempt(5), + retry=lambda state: ( + should_retry_on_litellm_error(state.outcome.exception()) + if state.outcome.failed + else False + ), + before_sleep=log_litellm_retry_attempt, + reraise=True, +) +async def _call_litellm_with_retry( + client: httpx.AsyncClient, + method: str, + url: str, + headers: dict, + json: dict, + timeout: float, + stream: bool = False, +): + """Helper to make LiteLLM calls with retry logic.""" + return await litellm_request( + client, + method, + url, + headers, + json=json, + timeout=timeout, + stream=stream, + ) + + def _parse_rate_limit_error(error_text: str, user: str) -> int | None: """ Parse error response to detect budget or rate limit errors. @@ -206,53 +242,16 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest): ) try: client = get_http_client() - async with client.stream( - "POST", - LITELLM_COMPLETIONS_URL, + response = await _call_litellm_with_retry( + client=client, + method="POST", + url=LITELLM_COMPLETIONS_URL, headers=LITELLM_COMPLETION_AUTH_HEADERS, json=body, timeout=env.STREAMING_TIMEOUT_SECONDS, - ) as response: - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - # Read the error response content for streaming responses - error_text_str = "" - try: - error_bytes = await e.response.aread() - error_text_str = error_bytes.decode("utf-8") if error_bytes else "" - except Exception: - pass - - if e.response.status_code in {400, 429}: - # Check for budget or rate limit errors - error_code = _parse_rate_limit_error( - error_text_str, authorized_chat_request.user - ) - if error_code in _RATE_LIMIT_REJECTION: - reason, _ = _RATE_LIMIT_REJECTION[error_code] - _record_rejection(authorized_chat_request, reason) - yield f'data: {{"error": {error_code}}}\n\n'.encode() - return - - # Context window exceeded: detect by error text or upstream 413 - if e.response.status_code == 413 or is_context_window_error( - error_text_str - ): - logger.warning( - f"Context window exceeded for user {authorized_chat_request.user}" - ) - _record_rejection( - authorized_chat_request, - PrometheusRejectionReason.PAYLOAD_TOO_LARGE, - ) - yield f'data: {{"error": {ERROR_CODE_REQUEST_TOO_LARGE}}}\n\n'.encode() - return - - # For other errors or if we couldn't parse the error - yield raise_and_log(e, True) - return - + stream=True, + ) + try: litellm_routing_snapshot = parse_litellm_routing_headers(response.headers) async for chunk in response.aiter_bytes(): @@ -342,8 +341,44 @@ async def stream_completion(authorized_chat_request: AuthorizedChatRequest): completion_tokens, ) result = PrometheusResult.SUCCESS + finally: + await response.aclose() except httpx.HTTPStatusError as e: if not streaming_started: + # Read the error response content for streaming responses + error_text_str = "" + try: + error_bytes = await e.response.aread() + error_text_str = error_bytes.decode("utf-8") if error_bytes else "" + except Exception: + pass + finally: + await e.response.aclose() + + if e.response.status_code in {400, 429}: + # Check for budget or rate limit errors + error_code = _parse_rate_limit_error( + error_text_str, authorized_chat_request.user + ) + if error_code in _RATE_LIMIT_REJECTION: + reason, _ = _RATE_LIMIT_REJECTION[error_code] + _record_rejection(authorized_chat_request, reason) + yield f'data: {{"error": {error_code}}}\n\n'.encode() + return + + # Context window exceeded: detect by error text or upstream 413 + if e.response.status_code == 413 or is_context_window_error(error_text_str): + logger.warning( + f"Context window exceeded for user {authorized_chat_request.user}" + ) + _record_rejection( + authorized_chat_request, + PrometheusRejectionReason.PAYLOAD_TOO_LARGE, + ) + yield f'data: {{"error": {ERROR_CODE_REQUEST_TOO_LARGE}}}\n\n'.encode() + return + + # For other errors or if we couldn't parse the error yield raise_and_log(e, True) else: logger.error(f"Upstream service returned an error: {e}") @@ -381,40 +416,14 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): ) try: client = get_http_client() - response = await client.post( - LITELLM_COMPLETIONS_URL, + response = await _call_litellm_with_retry( + client=client, + method="POST", + url=LITELLM_COMPLETIONS_URL, headers=LITELLM_COMPLETION_AUTH_HEADERS, json=body, + timeout=env.STREAMING_TIMEOUT_SECONDS, ) - try: - response.raise_for_status() - except httpx.HTTPStatusError as e: - error_text = e.response.text - if e.response.status_code in {400, 429}: - error_code = _parse_rate_limit_error( - error_text, authorized_chat_request.user - ) - if error_code in _RATE_LIMIT_REJECTION: - reason, retry_after = _RATE_LIMIT_REJECTION[error_code] - _record_rejection(authorized_chat_request, reason) - raise HTTPException( - status_code=429, - detail={"error": error_code}, - headers={"Retry-After": retry_after}, - ) - # Context window exceeded: detect by error text or upstream 413 - if e.response.status_code == 413 or is_context_window_error(error_text): - logger.warning( - f"Context window exceeded for user {authorized_chat_request.user}" - ) - _record_rejection( - authorized_chat_request, PrometheusRejectionReason.PAYLOAD_TOO_LARGE - ) - raise HTTPException( - status_code=413, - detail={"error": ERROR_CODE_REQUEST_TOO_LARGE}, - ) - raise_and_log(e) litellm_routing_snapshot = parse_litellm_routing_headers(response.headers) data = response.json() usage = data.get("usage", {}) @@ -474,6 +483,33 @@ async def get_completion(authorized_chat_request: AuthorizedChatRequest): ) result = PrometheusResult.SUCCESS return data + except httpx.HTTPStatusError as e: + error_text = e.response.text + if e.response.status_code in {400, 429}: + error_code = _parse_rate_limit_error( + error_text, authorized_chat_request.user + ) + if error_code in _RATE_LIMIT_REJECTION: + reason, retry_after = _RATE_LIMIT_REJECTION[error_code] + _record_rejection(authorized_chat_request, reason) + raise HTTPException( + status_code=429, + detail={"error": error_code}, + headers={"Retry-After": retry_after}, + ) + # Context window exceeded: detect by error text or upstream 413 + if e.response.status_code == 413 or is_context_window_error(error_text): + logger.warning( + f"Context window exceeded for user {authorized_chat_request.user}" + ) + _record_rejection( + authorized_chat_request, PrometheusRejectionReason.PAYLOAD_TOO_LARGE + ) + raise HTTPException( + status_code=413, + detail={"error": ERROR_CODE_REQUEST_TOO_LARGE}, + ) + raise_and_log(e) except HTTPException: raise except Exception as e: diff --git a/src/mlpa/core/utils.py b/src/mlpa/core/utils.py index dd72fba..60ea960 100644 --- a/src/mlpa/core/utils.py +++ b/src/mlpa/core/utils.py @@ -4,9 +4,11 @@ import time from functools import lru_cache +import httpx from fastapi import HTTPException from fxa.oauth import Client from jwtoxide import DecodingKey, ValidationOptions, decode, encode +from tenacity import retry, stop_after_attempt, wait_exponential from mlpa.core.classes import AssertionAuth, AttestationAuth from mlpa.core.config import ( @@ -19,6 +21,121 @@ from mlpa.core.pg_services.services import litellm_pg +def should_retry_on_litellm_error(exception: Exception) -> bool: + if isinstance(exception, httpx.HTTPStatusError): + status_code = exception.response.status_code + if status_code == 429: + return is_litellm_upstream_rate_limit(exception.response.text) + return status_code in {502, 503, 504} + return isinstance( + exception, + ( + httpx.TimeoutException, + httpx.ConnectError, + httpx.RemoteProtocolError, + ), + ) + + +def is_litellm_upstream_rate_limit(error_text: str) -> bool: + """Detect upstream LiteLLM throttling errors for retry.""" + if not error_text: + return False + return ( + "litellm.RateLimitError" in error_text + or '"status": "RESOURCE_EXHAUSTED"' in error_text + or '"type":"throttling_error"' in error_text + ) + + +def log_litellm_retry_attempt(retry_state) -> None: + exception = retry_state.outcome.exception() + if isinstance(exception, httpx.HTTPStatusError): + logger.warning( + "Retrying LiteLLM request: attempt " + f"{retry_state.attempt_number}, status_code=" + f"{exception.response.status_code}, " + f"next wait {retry_state.next_action.sleep}s" + ) + else: + logger.warning( + "Retrying LiteLLM request: attempt " + f"{retry_state.attempt_number}, error_type=" + f"{type(exception).__name__}, " + f"next wait {retry_state.next_action.sleep}s" + ) + + +async def litellm_request( + client: httpx.AsyncClient, + method: str, + url: str, + headers: dict, + params: dict | None = None, + json: dict | None = None, + timeout: float | None = None, + stream: bool = False, +): + if stream: + request = client.build_request(method, url, headers=headers, json=json) + if timeout is not None: + request.extensions["timeout"] = { + "connect": timeout, + "read": timeout, + "write": timeout, + "pool": timeout, + } + response = await client.send(request, stream=True) + else: + response = await client.request( + method, url, headers=headers, params=params, json=json, timeout=timeout + ) + + status_code = getattr(response, "status_code", None) + if isinstance(status_code, int) and status_code >= 400: + try: + await response.aread() + except (AttributeError, TypeError): + pass + if stream: + await response.aclose() + response.raise_for_status() + return response + + +@retry( + wait=wait_exponential(multiplier=1, min=1, max=4), + stop=stop_after_attempt(5), + retry=lambda state: ( + should_retry_on_litellm_error(state.outcome.exception()) + if state.outcome.failed + else False + ), + before_sleep=log_litellm_retry_attempt, + reraise=True, +) +async def litellm_request_with_retry( + client: httpx.AsyncClient, + method: str, + url: str, + headers: dict, + params: dict | None = None, + json: dict | None = None, + timeout: float | None = None, + stream: bool = False, +): + return await litellm_request( + client, + method, + url, + headers, + params=params, + json=json, + timeout=timeout, + stream=stream, + ) + + async def get_or_create_user(user_id: str): """Returns user info from LiteLLM, creating the user if they don't exist. Args: @@ -38,10 +155,13 @@ async def get_or_create_user(user_id: str): claimed_new_identity = False try: params = {"end_user_id": user_id} - response = await client.get( + + response = await litellm_request_with_retry( + client, + "GET", f"{env.LITELLM_API_BASE}/customer/info", - params=params, headers=LITELLM_MASTER_AUTH_HEADERS, + params=params, ) user = response.json() @@ -61,15 +181,19 @@ async def get_or_create_user(user_id: str): ) claimed_new_identity = newly_claimed - await client.post( + await litellm_request_with_retry( + client, + "POST", f"{env.LITELLM_API_BASE}/customer/new", - json={"user_id": user_id, "budget_id": budget_id}, headers=LITELLM_MASTER_AUTH_HEADERS, + json={"user_id": user_id, "budget_id": budget_id}, ) - response = await client.get( + response = await litellm_request_with_retry( + client, + "GET", f"{env.LITELLM_API_BASE}/customer/info", - params=params, headers=LITELLM_MASTER_AUTH_HEADERS, + params=params, ) created_user = response.json() diff --git a/src/tests/integration/test_user_signup_cap.py b/src/tests/integration/test_user_signup_cap.py index 314468f..a308ba5 100644 --- a/src/tests/integration/test_user_signup_cap.py +++ b/src/tests/integration/test_user_signup_cap.py @@ -59,6 +59,15 @@ async def post(self, url: str, json=None, headers=None): return _FakeResponse({}) + async def request( + self, method: str, url: str, params=None, json=None, headers=None + ): + if method.upper() == "GET": + return await self.get(url, params=params, headers=headers) + if method.upper() == "POST": + return await self.post(url, json=json, headers=headers) + return _FakeResponse({}) + def test_managed_cap_rejects_new_identities_and_s2s_bypasses( mocked_client_integration, mocker diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index 6c69aad..0d72100 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -7,7 +7,11 @@ from pytest_httpx import HTTPXMock, IteratorStream from mlpa.core.classes import AuthorizedChatRequest -from mlpa.core.completions import get_completion, stream_completion +from mlpa.core.completions import ( + _call_litellm_with_retry, + get_completion, + stream_completion, +) from mlpa.core.config import ( ERROR_CODE_BUDGET_LIMIT_EXCEEDED, ERROR_CODE_RATE_LIMIT_EXCEEDED, @@ -61,7 +65,7 @@ async def test_get_completion_success(mocker): # This mock will be the 'client' inside the 'async with' block mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response # Patch the shared HTTP client accessor to return our mock client mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") @@ -75,8 +79,8 @@ async def test_get_completion_success(mocker): # Assert: Verify the behavior and outcome # 1. Check that the httpx client was used to make the correct call - mock_client.post.assert_awaited_once() - _, call_kwargs = mock_client.post.call_args + mock_client.request.assert_awaited_once() + _, call_kwargs = mock_client.request.call_args sent_json = call_kwargs.get("json", {}) assert sent_json["model"] == SAMPLE_REQUEST.model assert sent_json["messages"] == SAMPLE_REQUEST.messages @@ -157,6 +161,35 @@ async def test_get_completion_success(mocker): assert result_data == SUCCESSFUL_CHAT_RESPONSE +async def test_call_litellm_with_retry_retries_on_upstream_429(mocker): + error_text = ( + '{"error":{"message":"litellm.RateLimitError: Vertex_aiException - ' + '[{\\"error\\":{\\"code\\":429,\\"message\\":\\"throttled\\",' + '\\"status\\":\\"RESOURCE_EXHAUSTED\\"}}]",' + '"type":"throttling_error","code":"429"}}' + ) + request = httpx.Request("POST", LITELLM_COMPLETIONS_URL) + response_429 = httpx.Response(429, request=request, content=error_text.encode()) + response_200 = httpx.Response(200, request=request, content=b'{"ok": true}') + + mock_client = AsyncMock() + mock_client.request = AsyncMock(side_effect=[response_429, response_200]) + + mocker.patch.object(_call_litellm_with_retry.retry, "sleep", new=AsyncMock()) + + response = await _call_litellm_with_retry( + mock_client, + "POST", + LITELLM_COMPLETIONS_URL, + {}, + {}, + 1.0, + ) + + assert response.status_code == 200 + assert mock_client.request.await_count == 2 + + async def test_get_completion_litellm_routing_with_fallback(mocker): mock_response = MagicMock() mock_response.json.return_value = SUCCESSFUL_CHAT_RESPONSE @@ -169,7 +202,7 @@ async def test_get_completion_litellm_routing_with_fallback(mocker): mock_response.raise_for_status.return_value = None mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mocker.patch("mlpa.core.completions.get_http_client", return_value=mock_client) mock_metrics = mocker.patch("mlpa.core.completions.metrics") @@ -202,7 +235,7 @@ async def test_get_completion_litellm_routing_skips_invalid_optional_headers(moc mock_response.raise_for_status.return_value = None mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mocker.patch("mlpa.core.completions.get_http_client", return_value=mock_client) mock_metrics = mocker.patch("mlpa.core.completions.metrics") @@ -221,7 +254,7 @@ async def test_get_completion_litellm_routing_skips_negative_duration_ms(mocker) mock_response.raise_for_status.return_value = None mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mocker.patch("mlpa.core.completions.get_http_client", return_value=mock_client) mock_metrics = mocker.patch("mlpa.core.completions.metrics") @@ -245,7 +278,7 @@ async def test_get_completion_http_error(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -280,13 +313,15 @@ async def test_get_completion_network_error(mocker): """ # Arrange: Mock httpx to raise a network error mock_client = AsyncMock() - mock_client.post.side_effect = httpx.TimeoutException("Connection timed out") + mock_client.request.side_effect = httpx.TimeoutException("Connection timed out") mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client mock_metrics = mocker.patch("mlpa.core.completions.metrics") mocker.patch.object(env, "MLPA_DEBUG", True) + mocker.patch.object(_call_litellm_with_retry.retry, "sleep", new=AsyncMock()) + # Act & Assert: Expect an HTTPException with pytest.raises(HTTPException) as exc_info: await get_completion(SAMPLE_REQUEST) @@ -458,7 +493,7 @@ async def test_get_completion_budget_limit_exceeded_429(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -512,7 +547,7 @@ async def test_get_completion_budget_limit_exceeded_400(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -558,7 +593,7 @@ async def test_get_completion_rate_limit_exceeded(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -604,7 +639,7 @@ async def test_get_completion_400_non_rate_limit_error(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client mocker.patch.object(env, "MLPA_DEBUG", False) @@ -634,7 +669,7 @@ async def test_get_completion_429_non_rate_limit_error(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -667,7 +702,7 @@ async def test_get_completion_context_window_exceeded(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -711,7 +746,7 @@ async def test_get_completion_429_invalid_json(mocker): mock_response.raise_for_status.side_effect = mock_http_status_error mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -1095,7 +1130,7 @@ async def test_get_completion_preserves_tools(mocker): mock_response.raise_for_status.return_value = None mock_client = AsyncMock() - mock_client.post.return_value = mock_response + mock_client.request.return_value = mock_response mock_get_client = mocker.patch("mlpa.core.completions.get_http_client") mock_get_client.return_value = mock_client @@ -1104,8 +1139,8 @@ async def test_get_completion_preserves_tools(mocker): await get_completion(request_with_tools) - mock_client.post.assert_awaited_once() - _, call_kwargs = mock_client.post.call_args + mock_client.request.assert_awaited_once() + _, call_kwargs = mock_client.request.call_args sent_json = call_kwargs.get("json", {}) assert sent_json["tools"] == tools From 21a398c929d91ba96ae587cfeb2e6318b80687f0 Mon Sep 17 00:00:00 2001 From: noahpodgurski Date: Mon, 6 Apr 2026 15:06:00 -0400 Subject: [PATCH 2/5] better upstream rate limit JSON handling --- src/mlpa/core/utils.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/mlpa/core/utils.py b/src/mlpa/core/utils.py index 60ea960..c8f2bb0 100644 --- a/src/mlpa/core/utils.py +++ b/src/mlpa/core/utils.py @@ -41,10 +41,22 @@ def is_litellm_upstream_rate_limit(error_text: str) -> bool: """Detect upstream LiteLLM throttling errors for retry.""" if not error_text: return False + # Try to parse as JSON and check fields + try: + error_json = json.loads(error_text) + if ( + error_json.get("status") == "RESOURCE_EXHAUSTED" + or error_json.get("type") == "throttling_error" + ): + return True + except Exception: + pass + # Fallback to normalized string matching + normalized = error_text.replace(" ", "").lower() return ( - "litellm.RateLimitError" in error_text - or '"status": "RESOURCE_EXHAUSTED"' in error_text - or '"type":"throttling_error"' in error_text + "litellm.ratelimiterror" in normalized + or '"status":"resource_exhausted"' in normalized + or '"type":"throttling_error"' in normalized ) From 39db4383fa632a96487e5e84bdc7010f0aa30616 Mon Sep 17 00:00:00 2001 From: noahpodgurski Date: Mon, 6 Apr 2026 15:13:54 -0400 Subject: [PATCH 3/5] fix tests --- src/tests/integration/test_user_signup_cap.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tests/integration/test_user_signup_cap.py b/src/tests/integration/test_user_signup_cap.py index a308ba5..c312ab3 100644 --- a/src/tests/integration/test_user_signup_cap.py +++ b/src/tests/integration/test_user_signup_cap.py @@ -60,7 +60,7 @@ async def post(self, url: str, json=None, headers=None): return _FakeResponse({}) async def request( - self, method: str, url: str, params=None, json=None, headers=None + self, method: str, url: str, params=None, json=None, headers=None, timeout=None ): if method.upper() == "GET": return await self.get(url, params=params, headers=headers) From 1bcf1c007961dc6f2f6143a0ad6c3be9950123e4 Mon Sep 17 00:00:00 2001 From: noahpodgurski Date: Tue, 7 Apr 2026 15:14:27 -0400 Subject: [PATCH 4/5] return 500 as error code when upstream error occurs --- src/mlpa/core/completions.py | 5 +++++ src/mlpa/core/config.py | 1 + src/mlpa/core/prometheus_metrics.py | 1 + 3 files changed, 7 insertions(+) diff --git a/src/mlpa/core/completions.py b/src/mlpa/core/completions.py index 9c1deba..5931b5c 100644 --- a/src/mlpa/core/completions.py +++ b/src/mlpa/core/completions.py @@ -11,6 +11,7 @@ ERROR_CODE_MAX_USERS_REACHED, ERROR_CODE_RATE_LIMIT_EXCEEDED, ERROR_CODE_REQUEST_TOO_LARGE, + ERROR_CODE_UPSTREAM_ERROR, LITELLM_COMPLETION_AUTH_HEADERS, LITELLM_COMPLETIONS_URL, env, @@ -26,6 +27,7 @@ from mlpa.core.utils import ( get_or_create_user, is_context_window_error, + is_litellm_upstream_rate_limit, is_rate_limit_error, litellm_request, log_litellm_retry_attempt, @@ -39,6 +41,7 @@ "86400", ), ERROR_CODE_RATE_LIMIT_EXCEEDED: (PrometheusRejectionReason.RATE_LIMITED, "60"), + ERROR_CODE_UPSTREAM_ERROR: (PrometheusRejectionReason.UPSTREAM_ERROR, "60"), } @@ -90,6 +93,8 @@ def _parse_rate_limit_error(error_text: str, user: str) -> int | None: elif is_rate_limit_error(error_data, ["rate"]): logger.warning(f"Rate limit exceeded for user {user}: {error_text}") return ERROR_CODE_RATE_LIMIT_EXCEEDED + elif is_litellm_upstream_rate_limit(error_text): + return ERROR_CODE_UPSTREAM_ERROR except (json.JSONDecodeError, AttributeError, UnicodeDecodeError): pass diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index 3a9b62b..08b2d98 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -295,6 +295,7 @@ def __init__(self): ERROR_CODE_RATE_LIMIT_EXCEEDED: int = 2 ERROR_CODE_REQUEST_TOO_LARGE: int = 3 ERROR_CODE_MAX_USERS_REACHED: int = 4 +ERROR_CODE_UPSTREAM_ERROR: int = 500 RATE_LIMIT_ERROR_RESPONSE = { 429: { diff --git a/src/mlpa/core/prometheus_metrics.py b/src/mlpa/core/prometheus_metrics.py index cc5a8b6..5f6ce67 100644 --- a/src/mlpa/core/prometheus_metrics.py +++ b/src/mlpa/core/prometheus_metrics.py @@ -14,6 +14,7 @@ class PrometheusRejectionReason(StrEnum): RATE_LIMITED = "rate_limited" PAYLOAD_TOO_LARGE = "payload_too_large" SIGNUP_CAP_EXCEEDED = "signup_cap_exceeded" + UPSTREAM_ERROR = "upstream_error" BUCKETS_FAST_AUTH = (0.001, 0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, float("inf")) From 22c24e19d0b8c9c857ceb1f12fa129c736183a9b Mon Sep 17 00:00:00 2001 From: noahpodgurski Date: Wed, 8 Apr 2026 09:18:29 -0400 Subject: [PATCH 5/5] address comments --- src/mlpa/core/config.py | 2 +- src/mlpa/core/utils.py | 13 +++++-------- src/tests/unit/test_completions.py | 28 ++++++++++++++++++++++++++++ 3 files changed, 34 insertions(+), 9 deletions(-) diff --git a/src/mlpa/core/config.py b/src/mlpa/core/config.py index 08b2d98..2952e13 100644 --- a/src/mlpa/core/config.py +++ b/src/mlpa/core/config.py @@ -295,7 +295,7 @@ def __init__(self): ERROR_CODE_RATE_LIMIT_EXCEEDED: int = 2 ERROR_CODE_REQUEST_TOO_LARGE: int = 3 ERROR_CODE_MAX_USERS_REACHED: int = 4 -ERROR_CODE_UPSTREAM_ERROR: int = 500 +ERROR_CODE_UPSTREAM_ERROR: int = 5 RATE_LIMIT_ERROR_RESPONSE = { 429: { diff --git a/src/mlpa/core/utils.py b/src/mlpa/core/utils.py index c8f2bb0..c9a55c7 100644 --- a/src/mlpa/core/utils.py +++ b/src/mlpa/core/utils.py @@ -62,19 +62,16 @@ def is_litellm_upstream_rate_limit(error_text: str) -> bool: def log_litellm_retry_attempt(retry_state) -> None: exception = retry_state.outcome.exception() + next_wait = getattr(retry_state.next_action, "sleep", "?") if isinstance(exception, httpx.HTTPStatusError): logger.warning( - "Retrying LiteLLM request: attempt " - f"{retry_state.attempt_number}, status_code=" - f"{exception.response.status_code}, " - f"next wait {retry_state.next_action.sleep}s" + f"Retrying LiteLLM request: attempt {retry_state.attempt_number}, " + f"status_code={exception.response.status_code}, next wait {next_wait}s" ) else: logger.warning( - "Retrying LiteLLM request: attempt " - f"{retry_state.attempt_number}, error_type=" - f"{type(exception).__name__}, " - f"next wait {retry_state.next_action.sleep}s" + f"Retrying LiteLLM request: attempt {retry_state.attempt_number}, " + f"error_type={type(exception).__name__}, next wait {next_wait}s" ) diff --git a/src/tests/unit/test_completions.py b/src/tests/unit/test_completions.py index 0d72100..be16310 100644 --- a/src/tests/unit/test_completions.py +++ b/src/tests/unit/test_completions.py @@ -190,6 +190,34 @@ async def test_call_litellm_with_retry_retries_on_upstream_429(mocker): assert mock_client.request.await_count == 2 +async def test_call_litellm_with_retry_exhausts_on_upstream_429(mocker): + error_text = ( + '{"error":{"message":"litellm.RateLimitError: Vertex_aiException - ' + '[{\\"error\\":{\\"code\\":429,\\"message\\":\\"throttled\\",' + '\\"status\\":\\"RESOURCE_EXHAUSTED\\"}}]",' + '"type":"throttling_error","code":"429"}}' + ) + request = httpx.Request("POST", LITELLM_COMPLETIONS_URL) + response_429 = httpx.Response(429, request=request, content=error_text.encode()) + + mock_client = AsyncMock() + mock_client.request = AsyncMock(side_effect=[response_429] * 5) + + mocker.patch.object(_call_litellm_with_retry.retry, "sleep", new=AsyncMock()) + + with pytest.raises(httpx.HTTPStatusError): + await _call_litellm_with_retry( + mock_client, + "POST", + LITELLM_COMPLETIONS_URL, + {}, + {}, + 1.0, + ) + + assert mock_client.request.await_count == 5 + + async def test_get_completion_litellm_routing_with_fallback(mocker): mock_response = MagicMock() mock_response.json.return_value = SUCCESSFUL_CHAT_RESPONSE