Skip to content

Commit 7f94bef

Browse files
Client sampling and roots capabilities set to None if not implemented (#802)
Co-authored-by: ihrpr <[email protected]>
1 parent d55cb2b commit 7f94bef

File tree

3 files changed

+177
-4
lines changed

3 files changed

+177
-4
lines changed

src/mcp/client/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,18 @@ def __init__(
116116
self._message_handler = message_handler or _default_message_handler
117117

118118
async def initialize(self) -> types.InitializeResult:
119-
sampling = types.SamplingCapability()
120-
roots = types.RootsCapability(
119+
sampling = (
120+
types.SamplingCapability()
121+
if self._sampling_callback is not _default_sampling_callback
122+
else None
123+
)
124+
roots = (
121125
# TODO: Should this be based on whether we
122126
# _will_ send notifications, or only whether
123127
# they're supported?
124-
listChanged=True,
128+
types.RootsCapability(listChanged=True)
129+
if self._list_roots_callback is not _default_list_roots_callback
130+
else None
125131
)
126132

127133
result = await self.send_request(

src/mcp/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ class RootsCapability(BaseModel):
218218

219219

220220
class SamplingCapability(BaseModel):
221-
"""Capability for logging operations."""
221+
"""Capability for sampling operations."""
222222

223223
model_config = ConfigDict(extra="allow")
224224

tests/client/test_session.py

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
from typing import Any
2+
13
import anyio
24
import pytest
35

46
import mcp.types as types
57
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
8+
from mcp.shared.context import RequestContext
69
from mcp.shared.message import SessionMessage
710
from mcp.shared.session import RequestResponder
811
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
@@ -380,3 +383,167 @@ async def mock_server():
380383
# Should raise RuntimeError for unsupported version
381384
with pytest.raises(RuntimeError, match="Unsupported protocol version"):
382385
await session.initialize()
386+
387+
388+
@pytest.mark.anyio
389+
async def test_client_capabilities_default():
390+
"""Test that client capabilities are properly set with default callbacks"""
391+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
392+
SessionMessage
393+
](1)
394+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
395+
SessionMessage
396+
](1)
397+
398+
received_capabilities = None
399+
400+
async def mock_server():
401+
nonlocal received_capabilities
402+
403+
session_message = await client_to_server_receive.receive()
404+
jsonrpc_request = session_message.message
405+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
406+
request = ClientRequest.model_validate(
407+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
408+
)
409+
assert isinstance(request.root, InitializeRequest)
410+
received_capabilities = request.root.params.capabilities
411+
412+
result = ServerResult(
413+
InitializeResult(
414+
protocolVersion=LATEST_PROTOCOL_VERSION,
415+
capabilities=ServerCapabilities(),
416+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
417+
)
418+
)
419+
420+
async with server_to_client_send:
421+
await server_to_client_send.send(
422+
SessionMessage(
423+
JSONRPCMessage(
424+
JSONRPCResponse(
425+
jsonrpc="2.0",
426+
id=jsonrpc_request.root.id,
427+
result=result.model_dump(
428+
by_alias=True, mode="json", exclude_none=True
429+
),
430+
)
431+
)
432+
)
433+
)
434+
# Receive initialized notification
435+
await client_to_server_receive.receive()
436+
437+
async with (
438+
ClientSession(
439+
server_to_client_receive,
440+
client_to_server_send,
441+
) as session,
442+
anyio.create_task_group() as tg,
443+
client_to_server_send,
444+
client_to_server_receive,
445+
server_to_client_send,
446+
server_to_client_receive,
447+
):
448+
tg.start_soon(mock_server)
449+
await session.initialize()
450+
451+
# Assert that capabilities are properly set with defaults
452+
assert received_capabilities is not None
453+
assert received_capabilities.sampling is None # No custom sampling callback
454+
assert received_capabilities.roots is None # No custom list_roots callback
455+
456+
457+
@pytest.mark.anyio
458+
async def test_client_capabilities_with_custom_callbacks():
459+
"""Test that client capabilities are properly set with custom callbacks"""
460+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
461+
SessionMessage
462+
](1)
463+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
464+
SessionMessage
465+
](1)
466+
467+
received_capabilities = None
468+
469+
async def custom_sampling_callback(
470+
context: RequestContext["ClientSession", Any],
471+
params: types.CreateMessageRequestParams,
472+
) -> types.CreateMessageResult | types.ErrorData:
473+
return types.CreateMessageResult(
474+
role="assistant",
475+
content=types.TextContent(type="text", text="test"),
476+
model="test-model",
477+
)
478+
479+
async def custom_list_roots_callback(
480+
context: RequestContext["ClientSession", Any],
481+
) -> types.ListRootsResult | types.ErrorData:
482+
return types.ListRootsResult(roots=[])
483+
484+
async def mock_server():
485+
nonlocal received_capabilities
486+
487+
session_message = await client_to_server_receive.receive()
488+
jsonrpc_request = session_message.message
489+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
490+
request = ClientRequest.model_validate(
491+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
492+
)
493+
assert isinstance(request.root, InitializeRequest)
494+
received_capabilities = request.root.params.capabilities
495+
496+
result = ServerResult(
497+
InitializeResult(
498+
protocolVersion=LATEST_PROTOCOL_VERSION,
499+
capabilities=ServerCapabilities(),
500+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
501+
)
502+
)
503+
504+
async with server_to_client_send:
505+
await server_to_client_send.send(
506+
SessionMessage(
507+
JSONRPCMessage(
508+
JSONRPCResponse(
509+
jsonrpc="2.0",
510+
id=jsonrpc_request.root.id,
511+
result=result.model_dump(
512+
by_alias=True, mode="json", exclude_none=True
513+
),
514+
)
515+
)
516+
)
517+
)
518+
# Receive initialized notification
519+
await client_to_server_receive.receive()
520+
521+
async with (
522+
ClientSession(
523+
server_to_client_receive,
524+
client_to_server_send,
525+
sampling_callback=custom_sampling_callback,
526+
list_roots_callback=custom_list_roots_callback,
527+
) as session,
528+
anyio.create_task_group() as tg,
529+
client_to_server_send,
530+
client_to_server_receive,
531+
server_to_client_send,
532+
server_to_client_receive,
533+
):
534+
tg.start_soon(mock_server)
535+
await session.initialize()
536+
537+
# Assert that capabilities are properly set with custom callbacks
538+
assert received_capabilities is not None
539+
assert (
540+
received_capabilities.sampling is not None
541+
) # Custom sampling callback provided
542+
assert isinstance(received_capabilities.sampling, types.SamplingCapability)
543+
assert (
544+
received_capabilities.roots is not None
545+
) # Custom list_roots callback provided
546+
assert isinstance(received_capabilities.roots, types.RootsCapability)
547+
assert (
548+
received_capabilities.roots.listChanged is True
549+
) # Should be True for custom callback

0 commit comments

Comments
 (0)