diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index 9a119c633..d4ab57f24 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -47,7 +47,7 @@ # Reconnection defaults DEFAULT_RECONNECTION_DELAY_MS = 1000 # 1 second fallback when server doesn't provide retry -MAX_RECONNECTION_ATTEMPTS = 2 # Max retry attempts before giving up +MAX_RECONNECTION_ATTEMPTS = 5 # Max retry attempts before giving up class StreamableHTTPError(Exception): @@ -380,7 +380,7 @@ async def _handle_reconnection( ) -> None: """Reconnect with Last-Event-ID to resume stream after server disconnect.""" # Bail if max retries exceeded - if attempt >= MAX_RECONNECTION_ATTEMPTS: # pragma: no cover + if attempt >= MAX_RECONNECTION_ATTEMPTS: logger.debug(f"Max reconnection attempts ({MAX_RECONNECTION_ATTEMPTS}) exceeded") return @@ -421,9 +421,9 @@ async def _handle_reconnection( await event_source.response.aclose() return - # Stream ended again without response - reconnect again (reset attempt counter) + # Stream ended again without response - reconnect again logger.info("SSE stream disconnected, reconnecting...") - await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, 0) + await self._handle_reconnection(ctx, reconnect_last_event_id, reconnect_retry_ms, attempt + 1) except Exception as e: # pragma: no cover logger.debug(f"Reconnection failed: {e}") # Try to reconnect again if we still have an event ID diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index 3d5770fb6..f26bc7921 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -14,7 +14,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass, field from typing import Any -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from urllib.parse import urlparse import anyio @@ -29,6 +29,7 @@ from mcp import MCPError, types from mcp.client.session import ClientSession +from mcp.client.streamable_http import RequestContext as HTTPRequestContext from mcp.client.streamable_http import StreamableHTTPTransport, streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.server.streamable_http import ( @@ -2318,3 +2319,64 @@ async def test_streamable_http_client_preserves_custom_with_mcp_headers( assert "content-type" in headers_data assert headers_data["content-type"] == "application/json" + + +@pytest.mark.anyio +async def test_reconnection_attempt_counter_increments_on_clean_disconnect( + event_server: tuple[SimpleEventStore, str], +) -> None: + """Verify that _handle_reconnection increments attempt on clean stream close. + + Previously, the attempt counter was reset to 0 on clean disconnect, causing + MAX_RECONNECTION_ATTEMPTS to be ineffective and allowing infinite reconnect + loops when a server repeatedly accepts then closes the stream without responding. + + With the fix (attempt+1), MAX_RECONNECTION_ATTEMPTS is respected for clean + disconnects too, and the client gives up after a bounded number of retries. + + Uses tool_with_multiple_stream_closes with more checkpoints than MAX_RECONNECTION_ATTEMPTS + so that the attempt counter is exercised all the way to the limit. + """ + import mcp.client.streamable_http as streamable_http_module + + _, server_url = event_server + + attempts_seen: list[int] = [] + original_handle_reconnection = StreamableHTTPTransport._handle_reconnection + + async def spy_handle_reconnection( + self: StreamableHTTPTransport, + ctx: HTTPRequestContext, + last_event_id: str, + retry_interval_ms: int | None = None, + attempt: int = 0, + ) -> None: + attempts_seen.append(attempt) + await original_handle_reconnection(self, ctx, last_event_id, retry_interval_ms, attempt) + + with ( + patch.object(streamable_http_module, "MAX_RECONNECTION_ATTEMPTS", 2), + patch.object(StreamableHTTPTransport, "_handle_reconnection", spy_handle_reconnection), + ): + with anyio.move_on_after(8): + async with streamable_http_client(f"{server_url}/mcp") as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + try: + # Use more checkpoints than MAX_RECONNECTION_ATTEMPTS=2. + # Each checkpoint closes then reopens the stream. With attempt+1, + # the counter increments and hits the limit after MAX attempts. + await session.call_tool( + "tool_with_multiple_stream_closes", + {"checkpoints": 3, "sleep_time": 0.6}, + ) + except Exception: # pragma: no cover + pass + + # With the fix: attempts seen are [0, 1, 2] — counter increments on each clean close. + # The bail at attempt=2 (>= MAX=2) covers the MAX_RECONNECTION_ATTEMPTS guard. + # Without the fix: attempts would repeat [0, 0, 0, ...] forever. + assert attempts_seen == [0, 1, 2], ( + f"Expected attempt counter to increment [0, 1, 2], got {attempts_seen}. " + "This indicates the reconnect counter is not incrementing on clean disconnects." + )