diff --git a/packages/toolbox-core/src/toolbox_core/client.py b/packages/toolbox-core/src/toolbox_core/client.py index 60f02bc27..7bb32cdf9 100644 --- a/packages/toolbox-core/src/toolbox_core/client.py +++ b/packages/toolbox-core/src/toolbox_core/client.py @@ -21,6 +21,8 @@ from aiohttp import ClientSession from deprecated import deprecated +from toolbox_core.exceptions import ProtocolNegotiationError + from . import version from .itransport import ITransport from .mcp_transport import ( @@ -28,6 +30,7 @@ McpHttpTransportV20250326, McpHttpTransportV20250618, McpHttpTransportV20251125, + McpHttpTransportV20260618, ) from .protocol import Protocol, ToolSchema from .tool import ToolboxTool @@ -39,6 +42,112 @@ ) +class _McpTransportProxy(ITransport): + """A proxy transport that transparently handles protocol fallback negotiation.""" + + def __init__( + self, + url: str, + session: Optional[ClientSession], + protocol: Protocol, + client_name: Optional[str], + client_version: Optional[str], + telemetry_enabled: bool, + ): + self._url = url + self._session = session + self._client_name = client_name + self._client_version = client_version + self._telemetry_enabled = telemetry_enabled + self._active_transport = self._create_transport(protocol) + + def _create_transport(self, protocol: Protocol) -> ITransport: + match protocol: + case Protocol.MCP_v20260618: + return McpHttpTransportV20260618( + self._url, + self._session, + protocol, + self._client_name, + self._client_version, + telemetry_enabled=self._telemetry_enabled, + ) + case Protocol.MCP_v20251125: + return McpHttpTransportV20251125( + self._url, + self._session, + protocol, + self._client_name, + self._client_version, + telemetry_enabled=self._telemetry_enabled, + ) + case Protocol.MCP_v20250618: + return McpHttpTransportV20250618( + self._url, + self._session, + protocol, + self._client_name, + self._client_version, + telemetry_enabled=self._telemetry_enabled, + ) + case Protocol.MCP_v20250326: + return McpHttpTransportV20250326( + self._url, + self._session, + protocol, + self._client_name, + self._client_version, + telemetry_enabled=self._telemetry_enabled, + ) + case Protocol.MCP_v20241105: + return McpHttpTransportV20241105( + self._url, + self._session, + protocol, + self._client_name, + self._client_version, + telemetry_enabled=self._telemetry_enabled, + ) + case _: + raise ValueError(f"Unsupported MCP protocol version: {protocol}") + + @property + def base_url(self) -> str: + return self._active_transport.base_url + + @property + def _protocol_version(self) -> str: + # We must expose this for tests asserting the current protocol version. + return getattr(self._active_transport, "_protocol_version", "") + + async def _execute_with_fallback( + self, method_name: str, *args: Any, **kwargs: Any + ) -> Any: + try: + return await getattr(self._active_transport, method_name)(*args, **kwargs) + except ProtocolNegotiationError as e: + fallback_protocol = Protocol(e.negotiated_version) + logging.warning( + f"Protocol fallback required. Switching from " + f"{self._protocol_version} to {fallback_protocol.value}" + ) + await self._active_transport.close() + self._active_transport = self._create_transport(fallback_protocol) + return await getattr(self._active_transport, method_name)(*args, **kwargs) + + async def tool_get(self, *args: Any, **kwargs: Any) -> Any: + return await self._execute_with_fallback("tool_get", *args, **kwargs) + + async def tools_list(self, *args: Any, **kwargs: Any) -> Any: + return await self._execute_with_fallback("tools_list", *args, **kwargs) + + async def tool_invoke(self, *args: Any, **kwargs: Any) -> Any: + return await self._execute_with_fallback("tool_invoke", *args, **kwargs) + + async def close(self) -> None: + await self._active_transport.close() + + class ToolboxClient: """ An asynchronous client for interacting with a Toolbox service. @@ -85,45 +194,14 @@ def __init__( "Please use Protocol.MCP_LATEST to use the latest features." ) - match protocol: - case Protocol.MCP_v20251125: - self.__transport = McpHttpTransportV20251125( - url, - session, - protocol, - client_name, - client_version, - telemetry_enabled=telemetry_enabled, - ) - case Protocol.MCP_v20250618: - self.__transport = McpHttpTransportV20250618( - url, - session, - protocol, - client_name, - client_version, - telemetry_enabled=telemetry_enabled, - ) - case Protocol.MCP_v20250326: - self.__transport = McpHttpTransportV20250326( - url, - session, - protocol, - client_name, - client_version, - telemetry_enabled=telemetry_enabled, - ) - case Protocol.MCP_v20241105: - self.__transport = McpHttpTransportV20241105( - url, - session, - protocol, - client_name, - client_version, - telemetry_enabled=telemetry_enabled, - ) - case _: - raise ValueError(f"Unsupported MCP protocol version: {protocol}") + self.__transport = _McpTransportProxy( + url, + session, + protocol, + client_name, + client_version, + telemetry_enabled, + ) self.__client_headers = client_headers if client_headers is not None else {} warn_if_http_and_headers(url, self.__client_headers) diff --git a/packages/toolbox-core/src/toolbox_core/exceptions.py b/packages/toolbox-core/src/toolbox_core/exceptions.py new file mode 100644 index 000000000..384bcd52c --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/exceptions.py @@ -0,0 +1,27 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class ToolboxError(Exception): + """Base exception for all MCP Toolbox errors.""" + + pass + + +class ProtocolNegotiationError(ToolboxError): + """Raised when the server requires a different protocol version during a stateless request.""" + + def __init__(self, negotiated_version: str): + self.negotiated_version = negotiated_version + super().__init__(f"Server requires protocol fallback to {negotiated_version}") diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py index 95a93a79f..ca5d0217f 100644 --- a/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/__init__.py @@ -16,10 +16,12 @@ from .v20250326.mcp import McpHttpTransportV20250326 from .v20250618.mcp import McpHttpTransportV20250618 from .v20251125.mcp import McpHttpTransportV20251125 +from .v20260618.mcp import McpHttpTransportV20260618 __all__ = [ "McpHttpTransportV20241105", "McpHttpTransportV20250326", "McpHttpTransportV20250618", "McpHttpTransportV20251125", + "McpHttpTransportV20260618", ] diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py new file mode 100644 index 000000000..5633bc4ad --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/mcp.py @@ -0,0 +1,297 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from typing import Mapping, Optional, TypeVar + +from pydantic import BaseModel + +from ... import version +from ...exceptions import ProtocolNegotiationError +from ...protocol import ManifestSchema, Protocol, TelemetryAttributes +from .. import telemetry +from ..transport_base import _McpHttpTransportBase +from . import types + +ReceiveResultT = TypeVar("ReceiveResultT", bound=BaseModel) + + +class McpHttpTransportV20260618(_McpHttpTransportBase): + """Transport for the MCP draft Request-Metadata (v2026-06-18) protocol.""" + + async def _send_request( + self, + url: str, + request: types.MCPRequest[ReceiveResultT] | types.MCPNotification, + headers: Optional[Mapping[str, str]] = None, + ) -> ReceiveResultT | None: + """Sends a JSON-RPC request to the MCP server.""" + req_headers = dict(headers or {}) + req_headers["MCP-Protocol-Version"] = self._protocol_version + + # Dynamically update the _meta protocol version in the parameters model + if hasattr(request, "params") and request.params is not None: + if ( + hasattr(request.params, "field_meta") + and request.params.field_meta is not None + ): + request.params.field_meta.protocol_version = self._protocol_version + + params = ( + request.params.model_dump(mode="json", exclude_none=True, by_alias=True) + if isinstance(request.params, BaseModel) + else request.params + ) + + rpc_msg: BaseModel + if isinstance(request, types.MCPNotification): + rpc_msg = types.JSONRPCNotification(method=request.method, params=params) + else: + rpc_msg = types.JSONRPCRequest(method=request.method, params=params) + + payload = rpc_msg.model_dump(mode="json", exclude_none=True) + + async with self._session.post( + url, json=payload, headers=req_headers + ) as response: + if response.status == 400: + try: + json_resp = await response.json() + if "error" in json_resp: + err_val = json_resp["error"] + if isinstance(err_val, dict) and err_val.get("code") == -32004: + server_supported = err_val.get("data", {}).get( + "supported", [] + ) + + client_supported = Protocol.get_supported_mcp_versions() + mutually_supported = [ + v for v in client_supported if v in server_supported + ] + + if mutually_supported: + raise ProtocolNegotiationError(mutually_supported[0]) + else: + raise RuntimeError( + "No mutually supported protocol version. " + f"Client supports: {client_supported}, " + f"Server supports: {server_supported}" + ) + elif ( + isinstance(err_val, str) + and "invalid protocol version" in err_val.lower() + ): + # Legacy 2025-06-18 servers don't use the -32004 code or provide + # a supported versions list. They return this raw string error + # instead. We safely assume 2025-06-18 here. + raise ProtocolNegotiationError(Protocol.MCP_v20250618) + except Exception as e: + if isinstance(e, (RuntimeError, ProtocolNegotiationError)): + raise e + + if not response.ok: + error_text = await response.text() + raise RuntimeError( + "API request failed with status" + f" {response.status} ({response.reason}). Server response:" + f" {error_text}" + ) + + if response.status == 204 or response.content.at_eof(): + return None + + json_resp = await response.json() + + # Check for JSON-RPC Error + if "error" in json_resp: + try: + err = types.JSONRPCError.model_validate(json_resp).error + raise RuntimeError( + f"MCP request failed with code {err.code}: {err.message}" + ) + except Exception: + # Fallback if the error doesn't match our schema exactly + raw_error = json_resp.get("error", {}) + raise RuntimeError(f"MCP request failed: {raw_error}") + + # Parse Result + if isinstance(request, types.MCPRequest): + try: + rpc_resp = types.JSONRPCResponse.model_validate(json_resp) + return request.get_result_model().model_validate(rpc_resp.result) + except Exception as e: + raise RuntimeError(f"Failed to parse JSON-RPC response: {e}") + return None + + async def _initialize_session( + self, headers: Optional[Mapping[str, str]] = None + ) -> None: + """No-op for stateless transport since there is no session handshake.""" + pass + + async def tools_list( + self, + toolset_name: Optional[str] = None, + headers: Optional[Mapping[str, str]] = None, + ) -> ManifestSchema: + """Lists available tools from the server using the MCP protocol.""" + await self._ensure_initialized(headers=headers) + + url = self._mcp_base_url + (toolset_name if toolset_name else "") + + meta = types.MCPMeta( + protocol_version=self._protocol_version, + client_info=types.Implementation( + name=self._client_name or "toolbox-core-python", + version=self._client_version or version.__version__, + ), + client_capabilities=types.ClientCapabilities(), + ) + + if self._telemetry_enabled: + operation_start = time.time() + span, traceparent, tracestate = telemetry.start_span( + self._tracer, + "tools/list", + self._protocol_version, + url, + network_transport="tcp", + ) + if span is not None: + meta.traceparent = traceparent or None + meta.tracestate = tracestate or None + + error: Optional[Exception] = None + try: + result = await self._send_request( + url=url, + request=types.ListToolsRequest( + params=types.ListToolsRequestParams(field_meta=meta) + ), + headers=headers, + ) + if result is None: + raise RuntimeError("Failed to list tools: No response from server.") + + tools_map = {t["name"]: self._convert_tool_schema(t) for t in result.tools} + + return ManifestSchema( + serverVersion="1.0.0", + tools=tools_map, + ) + except Exception as e: + error = e + raise + finally: + if self._telemetry_enabled: + operation_duration = time.time() - operation_start + telemetry.record_operation_duration( + self._operation_duration_histogram, + operation_duration, + "tools/list", + self._protocol_version, + url, + network_transport="tcp", + error=error, + ) + telemetry.end_span(span, error=error) + + async def tool_get( + self, tool_name: str, headers: Optional[Mapping[str, str]] = None + ) -> ManifestSchema: + """Gets a single tool from the server by listing all and filtering.""" + manifest = await self.tools_list(headers=headers) + + if tool_name not in manifest.tools: + raise ValueError(f"Tool '{tool_name}' not found.") + + return ManifestSchema( + serverVersion=manifest.serverVersion, + tools={tool_name: manifest.tools[tool_name]}, + ) + + async def tool_invoke( + self, + tool_name: str, + arguments: dict, + headers: Optional[Mapping[str, str]], + telemetry_attributes: Optional[TelemetryAttributes] = None, + ) -> str: + """Invokes a specific tool on the server using the MCP protocol.""" + await self._ensure_initialized(headers=headers) + + payload = self._build_telemetry_payload(telemetry_attributes) + + meta = types.MCPMeta( + protocol_version=self._protocol_version, + client_info=types.Implementation( + name=self._client_name or "toolbox-core-python", + version=self._client_version or version.__version__, + ), + client_capabilities=types.ClientCapabilities(), + telemetry_attributes=payload, + ) + + span = None + if self._telemetry_enabled: + operation_start = time.time() + span, traceparent, tracestate = telemetry.start_span( + self._tracer, + "tools/call", + self._protocol_version, + self._mcp_base_url, + tool_name=tool_name, + network_transport="tcp", + ) + meta.traceparent = traceparent or None + meta.tracestate = tracestate or None + if span is not None and payload: + for key, value in payload.items(): + span.set_attribute(key, value) + + error: Optional[Exception] = None + try: + result = await self._send_request( + url=self._mcp_base_url, + request=types.CallToolRequest( + params=types.CallToolRequestParams( + name=tool_name, arguments=arguments, field_meta=meta + ) + ), + headers=headers, + ) + + if result is None: + raise RuntimeError( + f"Failed to invoke tool '{tool_name}': No response from server." + ) + + return self._process_tool_result_content(result.content) + except Exception as e: + error = e + raise + finally: + if self._telemetry_enabled: + operation_duration = time.time() - operation_start + telemetry.record_operation_duration( + self._operation_duration_histogram, + operation_duration, + "tools/call", + self._protocol_version, + self._mcp_base_url, + tool_name=tool_name, + network_transport="tcp", + error=error, + ) + telemetry.end_span(span, error=error) diff --git a/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/types.py b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/types.py new file mode 100644 index 000000000..a18eea704 --- /dev/null +++ b/packages/toolbox-core/src/toolbox_core/mcp_transport/v20260618/types.py @@ -0,0 +1,160 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import uuid +from typing import Any, Generic, Literal, Type, TypeVar + +from pydantic import BaseModel, ConfigDict, Field + + +class _BaseMCPModel(BaseModel): + """Base model with common configuration.""" + + model_config = ConfigDict(extra="allow") + + +class JSONRPCRequest(_BaseMCPModel): + jsonrpc: Literal["2.0"] = "2.0" + id: str | int = Field(default_factory=lambda: str(uuid.uuid4())) + method: str + params: dict[str, Any] | None = None + + +class JSONRPCNotification(_BaseMCPModel): + """A notification which does not expect a response (no ID).""" + + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: dict[str, Any] | None = None + + +class JSONRPCResponse(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + result: dict[str, Any] + + +class ErrorData(_BaseMCPModel): + code: int + message: str + data: Any | None = None + + +class JSONRPCError(_BaseMCPModel): + jsonrpc: Literal["2.0"] + id: str | int + error: ErrorData + + +class SamplingCapabilities(_BaseMCPModel): + context: dict[str, Any] | None = None + tools: dict[str, Any] | None = None + + +class ElicitationCapabilities(_BaseMCPModel): + form: dict[str, Any] | None = None + url: dict[str, Any] | None = None + + +class ClientCapabilities(_BaseMCPModel): + experimental: dict[str, Any] | None = None + roots: dict[str, Any] | None = None + sampling: SamplingCapabilities | None = None + elicitation: ElicitationCapabilities | None = None + extensions: dict[str, Any] | None = None + + +class Implementation(_BaseMCPModel): + name: str + version: str + + +class MCPMeta(_BaseMCPModel): + """Metadata for MCP requests. + + Carries the three required fields in io.modelcontextprotocol/* namespace. + """ + + protocol_version: str = Field( + ..., serialization_alias="io.modelcontextprotocol/protocolVersion" + ) + client_info: Implementation = Field( + ..., serialization_alias="io.modelcontextprotocol/clientInfo" + ) + client_capabilities: ClientCapabilities = Field( + ..., serialization_alias="io.modelcontextprotocol/clientCapabilities" + ) + + # Tracing and attributes + traceparent: str | None = None + tracestate: str | None = None + telemetry_attributes: dict[str, Any] | None = Field( + default=None, serialization_alias="dev.mcp-toolbox/telemetry" + ) + + +class ListToolsResult(_BaseMCPModel): + tools: list[dict[str, Any]] + + +class TextContent(_BaseMCPModel): + type: Literal["text"] + text: str + + +class CallToolResult(_BaseMCPModel): + content: list[TextContent] + isError: bool = False + + +ResultT = TypeVar("ResultT", bound=BaseModel) + + +class MCPRequest(_BaseMCPModel, Generic[ResultT]): + method: str + params: dict[str, Any] | BaseModel | None = None + + def get_result_model(self) -> Type[ResultT]: + raise NotImplementedError + + +class MCPNotification(_BaseMCPModel): + method: str + params: dict[str, Any] | BaseModel | None = None + + +class ListToolsRequestParams(_BaseMCPModel): + field_meta: MCPMeta = Field(..., serialization_alias="_meta") + + +class ListToolsRequest(MCPRequest[ListToolsResult]): + method: Literal["tools/list"] = "tools/list" + params: ListToolsRequestParams + + def get_result_model(self) -> Type[ListToolsResult]: + return ListToolsResult + + +class CallToolRequestParams(_BaseMCPModel): + name: str + arguments: dict[str, Any] + field_meta: MCPMeta = Field(..., serialization_alias="_meta") + + +class CallToolRequest(MCPRequest[CallToolResult]): + method: Literal["tools/call"] = "tools/call" + params: CallToolRequestParams + + def get_result_model(self) -> Type[CallToolResult]: + return CallToolResult diff --git a/packages/toolbox-core/src/toolbox_core/protocol.py b/packages/toolbox-core/src/toolbox_core/protocol.py index 378f83fe8..472580a94 100644 --- a/packages/toolbox-core/src/toolbox_core/protocol.py +++ b/packages/toolbox-core/src/toolbox_core/protocol.py @@ -47,17 +47,19 @@ def _empty_string_to_none(cls, value: Any) -> Any: class Protocol(str, Enum): """Defines how the client should choose between communication protocols.""" + MCP_v20260618 = "DRAFT-2026-v1" MCP_v20250618 = "2025-06-18" MCP_v20250326 = "2025-03-26" MCP_v20241105 = "2024-11-05" MCP_v20251125 = "2025-11-25" MCP = MCP_v20250618 - MCP_LATEST = MCP_v20251125 + MCP_LATEST = MCP_v20260618 @staticmethod def get_supported_mcp_versions() -> list[str]: """Returns a list of supported MCP protocol versions.""" return [ + Protocol.MCP_v20260618.value, Protocol.MCP_v20251125.value, Protocol.MCP_v20250618.value, Protocol.MCP_v20250326.value, diff --git a/packages/toolbox-core/tests/conformance/client.py b/packages/toolbox-core/tests/conformance/client.py index 9ab58812c..5d59d593a 100644 --- a/packages/toolbox-core/tests/conformance/client.py +++ b/packages/toolbox-core/tests/conformance/client.py @@ -18,9 +18,16 @@ import sys from toolbox_core.client import ToolboxClient +from toolbox_core.protocol import Protocol async def main(): + """Harness main execution block. + + NOTE: All non-protocol outputs (logs, traces, errors) must be directed to + sys.stderr. The test runner captures stdout for protocol messages only, + printing other content to stdout will pollute the stream and crash the runner. + """ if len(sys.argv) < 2: print("Usage: client.py ", file=sys.stderr) sys.exit(1) @@ -41,7 +48,13 @@ async def main(): client_headers = {"Accept": "application/json, text/event-stream"} - async with ToolboxClient(server_url, client_headers=client_headers) as client: + protocol = Protocol.MCP + if scenario == "request-metadata": + protocol = Protocol.MCP_v20260618 + + async with ToolboxClient( + server_url, client_headers=client_headers, protocol=protocol + ) as client: if scenario == "initialize": await client.load_toolset() print("Client initialization test completed", file=sys.stderr) @@ -51,6 +64,10 @@ async def main(): await add_numbers(a=1, b=2) print("Invoked add_numbers(a=1, b=2)", file=sys.stderr) + elif scenario == "request-metadata": + await client.load_toolset() + print("Client request-metadata test completed", file=sys.stderr) + else: # Default behavior: load default toolset to trigger standard interactions await client.load_toolset() diff --git a/packages/toolbox-core/tests/mcp_transport/test_v20260618.py b/packages/toolbox-core/tests/mcp_transport/test_v20260618.py new file mode 100644 index 000000000..4b1be165b --- /dev/null +++ b/packages/toolbox-core/tests/mcp_transport/test_v20260618.py @@ -0,0 +1,232 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +import pytest_asyncio +from aiohttp import ClientSession + +from toolbox_core.mcp_transport.v20260618 import types +from toolbox_core.mcp_transport.v20260618.mcp import McpHttpTransportV20260618 +from toolbox_core.protocol import ManifestSchema, Protocol + + +def create_fake_tools_list_result(): + return types.ListToolsResult( + tools=[ + { + "name": "get_weather", + "description": "Gets the weather.", + "inputSchema": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + } + ] + ) + + +@pytest_asyncio.fixture( + params=[False, True], ids=["telemetry_disabled", "telemetry_enabled"] +) +async def transport(request, mocker): + if request.param: + mocker.patch("toolbox_core.mcp_transport.telemetry.TELEMETRY_AVAILABLE", True) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.get_tracer", return_value=MagicMock() + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.get_meter", return_value=MagicMock() + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.create_operation_duration_histogram", + return_value=MagicMock(), + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.create_session_duration_histogram", + return_value=MagicMock(), + ) + mocker.patch( + "toolbox_core.mcp_transport.telemetry.start_span", + return_value=(MagicMock(), "00-traceparent", ""), + ) + mocker.patch("toolbox_core.mcp_transport.telemetry.end_span") + mocker.patch("toolbox_core.mcp_transport.telemetry.record_operation_duration") + mocker.patch("toolbox_core.mcp_transport.telemetry.record_session_duration") + mock_session = AsyncMock(spec=ClientSession) + transport = McpHttpTransportV20260618( + "http://fake-server.com", + session=mock_session, + protocol=Protocol.MCP_v20260618, + telemetry_enabled=request.param, + ) + yield transport + await transport.close() + + +@pytest.mark.asyncio +class TestMcpHttpTransportV20260618: + + # --- Request Sending Tests (Standard + Header) --- + + async def test_send_request_success(self, transport): + mock_response = AsyncMock() + mock_response.ok = True + mock_response.status = 200 + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + result = await transport._send_request("url", TestRequest()) + assert result == TestResult() + + async def test_send_request_adds_protocol_header(self, transport): + """Test that the MCP-Protocol-Version header is added.""" + mock_response = AsyncMock() + mock_response.ok = True + mock_response.content = Mock() + mock_response.content.at_eof.return_value = False + mock_response.json.return_value = {"jsonrpc": "2.0", "id": "1", "result": {}} + transport._session.post.return_value.__aenter__.return_value = mock_response + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + await transport._send_request("url", TestRequest()) + + call_args = transport._session.post.call_args + headers = call_args.kwargs["headers"] + assert headers["MCP-Protocol-Version"] == "DRAFT-2026-v1" + + # --- Version Negotiation Tests --- + + async def test_version_negotiation_raises_fallback(self, transport): + """Tests that the client raises ProtocolNegotiationError when the server requests a fallback.""" + from toolbox_core.exceptions import ProtocolNegotiationError + + mock_response_reject = AsyncMock() + mock_response_reject.ok = False + mock_response_reject.status = 400 + mock_response_reject.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": { + "code": -32004, + "message": "Unsupported protocol version", + "data": {"supported": ["DRAFT-2026-v1"]}, + }, + } + + transport._session.post.return_value.__aenter__.return_value = ( + mock_response_reject + ) + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises(ProtocolNegotiationError) as exc_info: + await transport._send_request("url", TestRequest()) + + assert exc_info.value.negotiated_version == "DRAFT-2026-v1" + assert transport._session.post.call_count == 1 + + async def test_version_negotiation_empty_intersection(self, transport): + """Tests that the client errors immediately without retrying when there is no mutual version.""" + mock_response_reject = AsyncMock() + mock_response_reject.ok = False + mock_response_reject.status = 400 + mock_response_reject.json.return_value = { + "jsonrpc": "2.0", + "id": "1", + "error": { + "code": -32004, + "message": "Unsupported protocol version", + "data": {"supported": ["UNSUPPORTED-VERSION"]}, + }, + } + + transport._session.post.return_value.__aenter__.return_value = ( + mock_response_reject + ) + + class TestResult(types.BaseModel): + pass + + class TestRequest(types.MCPRequest[TestResult]): + method: str = "method" + params: dict = {} + + def get_result_model(self): + return TestResult + + with pytest.raises( + RuntimeError, match="No mutually supported protocol version" + ): + await transport._send_request("url", TestRequest()) + + assert transport._session.post.call_count == 1 + + # --- Tool Management Tests --- + + async def test_tools_list_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=create_fake_tools_list_result(), + ) + manifest = await transport.tools_list() + assert isinstance(manifest, ManifestSchema) + assert "get_weather" in manifest.tools + + async def test_tool_invoke_success(self, transport, mocker): + mocker.patch.object(transport, "_ensure_initialized", new_callable=AsyncMock) + mocker.patch.object( + transport, + "_send_request", + new_callable=AsyncMock, + return_value=types.CallToolResult( + content=[types.TextContent(type="text", text="Result")] + ), + ) + result = await transport.tool_invoke("tool", {}, {}) + assert result == "Result" diff --git a/packages/toolbox-core/tests/test_client.py b/packages/toolbox-core/tests/test_client.py index 16a00e3f3..d41b3f581 100644 --- a/packages/toolbox-core/tests/test_client.py +++ b/packages/toolbox-core/tests/test_client.py @@ -248,6 +248,94 @@ async def test_load_tool_not_found_in_manifest(mock_transport, test_tool_str): mock_transport.tool_get_mock.assert_awaited_once_with(REQUESTED_TOOL_NAME, {}) +@pytest.mark.asyncio +async def test_load_tool_protocol_fallback_success(test_tool_str): + """ + Tests that the client successfully swaps transports and retries when a + ProtocolNegotiationError is raised. + """ + TOOL_NAME = "test_tool_1" + manifest = ManifestSchema(serverVersion="0.0.0", tools={TOOL_NAME: test_tool_str}) + + from toolbox_core.exceptions import ProtocolNegotiationError + + # We need to mock the transports that client.py will instantiate + with ( + patch("toolbox_core.client.McpHttpTransportV20260618") as mock_2026_cls, + patch("toolbox_core.client.McpHttpTransportV20250618") as mock_2025_cls, + ): + + mock_2026 = AsyncMock() + mock_2026.base_url = TEST_BASE_URL + mock_2026.tool_get.side_effect = ProtocolNegotiationError("2025-06-18") + mock_2026_cls.return_value = mock_2026 + + mock_2025 = AsyncMock() + mock_2025.base_url = TEST_BASE_URL + mock_2025.tool_get.return_value = manifest + mock_2025.tool_invoke.return_value = "ok_from_fallback" + mock_2025_cls.return_value = mock_2025 + + async with ToolboxClient( + TEST_BASE_URL, protocol=Protocol.MCP_v20260618 + ) as client: + # This should trigger the fallback + loaded_tool = await client.load_tool(TOOL_NAME) + + # Assert the first transport was closed + mock_2026.close.assert_awaited_once() + + # Assert the second transport was instantiated and used + mock_2025_cls.assert_called_once() + mock_2025.tool_get.assert_awaited_once_with(TOOL_NAME, {}) + + # Assert the tool was bound to the *new* transport + assert await loaded_tool("some value") == "ok_from_fallback" + mock_2025.tool_invoke.assert_awaited_once_with( + TOOL_NAME, {"param1": "some value"}, {} + ) + + +@pytest.mark.asyncio +async def test_load_tool_protocol_fallback_infinite_loop_prevention(test_tool_str): + """ + Tests that if the fallback transport *also* raises ProtocolNegotiationError, + the client does not get stuck in an infinite loop. + """ + TOOL_NAME = "test_tool_1" + + from toolbox_core.exceptions import ProtocolNegotiationError + + with ( + patch("toolbox_core.client.McpHttpTransportV20260618") as mock_2026_cls, + patch("toolbox_core.client.McpHttpTransportV20250618") as mock_2025_cls, + ): + + mock_2026 = AsyncMock() + mock_2026.base_url = TEST_BASE_URL + mock_2026.tool_get.side_effect = ProtocolNegotiationError("2025-06-18") + mock_2026_cls.return_value = mock_2026 + + mock_2025 = AsyncMock() + mock_2025.base_url = TEST_BASE_URL + # The fallback transport also throws the error + mock_2025.tool_get.side_effect = ProtocolNegotiationError("2024-11-05") + mock_2025_cls.return_value = mock_2025 + + async with ToolboxClient( + TEST_BASE_URL, protocol=Protocol.MCP_v20260618 + ) as client: + with pytest.raises( + ProtocolNegotiationError, + match="Server requires protocol fallback to 2024-11-05", + ): + await client.load_tool(TOOL_NAME) + + # Assert we tried both, but then let the exception bubble up instead of looping + mock_2026.tool_get.assert_awaited_once() + mock_2025.tool_get.assert_awaited_once() + + class TestAuth: @pytest.fixture def expected_header(self): diff --git a/packages/toolbox-core/tests/test_e2e_mcp.py b/packages/toolbox-core/tests/test_e2e_mcp.py index 6acaeaded..c35695ffa 100644 --- a/packages/toolbox-core/tests/test_e2e_mcp.py +++ b/packages/toolbox-core/tests/test_e2e_mcp.py @@ -24,9 +24,11 @@ from toolbox_core.tool import ToolboxTool +# TODO: Include draft versions in E2E integration tests once the server +# supports SEP-2575 (stateless MCP / Request-Metadata). @pytest_asyncio.fixture( scope="function", - params=Protocol.get_supported_mcp_versions(), + params=[v for v in Protocol.get_supported_mcp_versions() if "DRAFT" not in v], ) async def toolbox(request): """Creates a ToolboxClient instance shared by all tests in this module.""" @@ -98,6 +100,21 @@ async def test_run_tool_missing_params(self, get_n_rows_tool: ToolboxTool): with pytest.raises(TypeError, match="missing a required argument: 'num_rows'"): await get_n_rows_tool() + async def test_protocol_fallback_e2e(self): + """Tests that a client using MCP_LATEST can fallback to an older protocol against a server that doesn't support the latest version.""" + # The E2E server currently does not support DRAFT 2026, so this will trigger a fallback. + async with ToolboxClient( + "http://localhost:5000", protocol=Protocol.MCP_LATEST + ) as client: + tool = await client.load_tool("get-n-rows") + response = await tool(num_rows="1") + assert "row1" in response + # Verify that fallback occurred by checking the transport's final protocol version + assert ( + client._ToolboxClient__transport._protocol_version + != Protocol.MCP_LATEST.value + ) + async def test_run_tool_wrong_param_type(self, get_n_rows_tool: ToolboxTool): """Invoke a tool with wrong param type.""" with pytest.raises( diff --git a/packages/toolbox-core/tests/test_protocol.py b/packages/toolbox-core/tests/test_protocol.py index d8079d8ee..3c50b3225 100644 --- a/packages/toolbox-core/tests/test_protocol.py +++ b/packages/toolbox-core/tests/test_protocol.py @@ -77,7 +77,13 @@ def test_get_supported_mcp_versions(): Tests that get_supported_mcp_versions returns the correct list of versions, sorted from newest to oldest. """ - expected_versions = ["2025-11-25", "2025-06-18", "2025-03-26", "2024-11-05"] + expected_versions = [ + "DRAFT-2026-v1", + "2025-11-25", + "2025-06-18", + "2025-03-26", + "2024-11-05", + ] supported_versions = Protocol.get_supported_mcp_versions() assert supported_versions == expected_versions