Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 43 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ async def _default_elicitation_callback(
context: RequestContext[ClientSession],
params: types.ElicitRequestParams,
) -> types.ElicitResult | types.ErrorData:
return types.ErrorData( # pragma: no cover
return types.ErrorData(
code=types.INVALID_REQUEST,
message="Elicitation not supported",
)
Expand Down Expand Up @@ -216,6 +216,48 @@ def experimental(self) -> ExperimentalClientFeatures:
self._experimental_features = ExperimentalClientFeatures(self)
return self._experimental_features

def set_sampling_callback(self, callback: SamplingFnT | None) -> None:
"""Update the sampling callback.

Note: Client capabilities are advertised to the server during :meth:`initialize`
and will not be re-negotiated when this setter is called. If a sampling
callback is set after initialization, the server may not be aware of the
capability.

Args:
callback: The new sampling callback, or ``None`` to restore the default
(which rejects all sampling requests with an error).
"""
self._sampling_callback = callback or _default_sampling_callback

def set_elicitation_callback(self, callback: ElicitationFnT | None) -> None:
"""Update the elicitation callback.

Note: Client capabilities are advertised to the server during :meth:`initialize`
and will not be re-negotiated when this setter is called. If an elicitation
callback is set after initialization, the server may not be aware of the
capability.

Args:
callback: The new elicitation callback, or ``None`` to restore the default
(which rejects all elicitation requests with an error).
"""
self._elicitation_callback = callback or _default_elicitation_callback

def set_list_roots_callback(self, callback: ListRootsFnT | None) -> None:
"""Update the list roots callback.

Note: Client capabilities are advertised to the server during :meth:`initialize`
and will not be re-negotiated when this setter is called. If a list-roots
callback is set after initialization, the server may not be aware of the
capability.

Args:
callback: The new list roots callback, or ``None`` to restore the default
(which rejects all list-roots requests with an error).
"""
self._list_roots_callback = callback or _default_list_roots_callback

async def send_ping(self, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
"""Send a ping request."""
return await self.send_request(types.PingRequest(params=types.RequestParams(_meta=meta)), types.EmptyResult)
Expand Down
51 changes: 51 additions & 0 deletions tests/client/test_elicitation_callback.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import annotations

import pytest
from pydantic import BaseModel, Field

from mcp import Client
from mcp.client.session import ClientSession
from mcp.server.mcpserver import Context, MCPServer
from mcp.shared._context import RequestContext
from mcp.types import ElicitRequestParams, ElicitResult, TextContent


class AnswerSchema(BaseModel):
answer: str = Field(description="The user's answer")


@pytest.mark.anyio
async def test_set_elicitation_callback():
server = MCPServer("test")

updated_answer = "Updated answer"

async def updated_callback(
context: RequestContext[ClientSession],
params: ElicitRequestParams,
) -> ElicitResult:
return ElicitResult(action="accept", content={"answer": updated_answer})

@server.tool("ask")
async def ask(prompt: str, ctx: Context) -> str:
result = await ctx.elicit(message=prompt, schema=AnswerSchema)
if result.action == "accept" and result.data:
return result.data.answer
return "no answer" # pragma: no cover

async with Client(server) as client:
# Before setting callback — default rejects with error
result = await client.call_tool("ask", {"prompt": "question?"})
assert result.is_error is True

# Set new callback — should succeed
client.session.set_elicitation_callback(updated_callback)
result = await client.call_tool("ask", {"prompt": "question?"})
assert result.is_error is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == updated_answer

# Reset to None — back to default error
client.session.set_elicitation_callback(None)
result = await client.call_tool("ask", {"prompt": "question?"})
assert result.is_error is True
39 changes: 39 additions & 0 deletions tests/client/test_list_roots_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,42 @@ async def test_list_roots(context: Context, message: str):
assert result.is_error is True
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported"


@pytest.mark.anyio
async def test_set_list_roots_callback():
server = MCPServer("test")

updated_result = ListRootsResult(
roots=[
Root(uri=FileUrl("file://users/fake/updated"), name="Updated Root"),
]
)

async def updated_callback(
context: RequestContext[ClientSession],
) -> ListRootsResult:
return updated_result

@server.tool("get_roots")
async def get_roots(context: Context, param: str) -> bool:
roots = await context.session.list_roots()
assert roots == updated_result
return True

async with Client(server) as client:
# Before setting callback — default rejects with error
result = await client.call_tool("get_roots", {"param": "x"})
assert result.is_error is True

# Set new callback — should succeed
client.session.set_list_roots_callback(updated_callback)
result = await client.call_tool("get_roots", {"param": "x"})
assert result.is_error is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

# Reset to None — back to default error
client.session.set_list_roots_callback(None)
result = await client.call_tool("get_roots", {"param": "x"})
assert result.is_error is True
44 changes: 44 additions & 0 deletions tests/client/test_sampling_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,50 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool:
assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported"


@pytest.mark.anyio
async def test_set_sampling_callback():
server = MCPServer("test")

updated_return = CreateMessageResult(
role="assistant",
content=TextContent(type="text", text="Updated response"),
model="updated-model",
stop_reason="endTurn",
)

async def updated_callback(
context: RequestContext[ClientSession],
params: CreateMessageRequestParams,
) -> CreateMessageResult:
return updated_return

@server.tool("do_sample")
async def do_sample(message: str, ctx: Context) -> bool:
value = await ctx.session.create_message(
messages=[SamplingMessage(role="user", content=TextContent(type="text", text=message))],
max_tokens=100,
)
assert value == updated_return
return True

async with Client(server) as client:
# Before setting callback — default rejects with error
result = await client.call_tool("do_sample", {"message": "test"})
assert result.is_error is True

# Set new callback — should succeed
client.session.set_sampling_callback(updated_callback)
result = await client.call_tool("do_sample", {"message": "test"})
assert result.is_error is False
assert isinstance(result.content[0], TextContent)
assert result.content[0].text == "true"

# Reset to None — back to default error
client.session.set_sampling_callback(None)
result = await client.call_tool("do_sample", {"message": "test"})
assert result.is_error is True


@pytest.mark.anyio
async def test_create_message_backwards_compat_single_content():
"""Test backwards compatibility: create_message without tools returns single content."""
Expand Down
Loading