Skip to content

Commit c4beb3e

Browse files
jerome3o-anthropicClaude
and
Claude
authored
Support custom client info throughout client APIs (#474)
Co-authored-by: Claude <[email protected]>
1 parent da54ea0 commit c4beb3e

File tree

4 files changed

+142
-3
lines changed

4 files changed

+142
-3
lines changed

Diff for: src/mcp/client/__main__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,13 @@ async def message_handler(
3838
async def run_session(
3939
read_stream: MemoryObjectReceiveStream[JSONRPCMessage | Exception],
4040
write_stream: MemoryObjectSendStream[JSONRPCMessage],
41+
client_info: types.Implementation | None = None,
4142
):
4243
async with ClientSession(
43-
read_stream, write_stream, message_handler=message_handler
44+
read_stream,
45+
write_stream,
46+
message_handler=message_handler,
47+
client_info=client_info,
4448
) as session:
4549
logger.info("Initializing session")
4650
await session.initialize()

Diff for: src/mcp/client/session.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from mcp.shared.session import BaseSession, RequestResponder
1111
from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
1212

13+
DEFAULT_CLIENT_INFO = types.Implementation(name="mcp", version="0.1.0")
14+
1315

1416
class SamplingFnT(Protocol):
1517
async def __call__(
@@ -97,6 +99,7 @@ def __init__(
9799
list_roots_callback: ListRootsFnT | None = None,
98100
logging_callback: LoggingFnT | None = None,
99101
message_handler: MessageHandlerFnT | None = None,
102+
client_info: types.Implementation | None = None,
100103
) -> None:
101104
super().__init__(
102105
read_stream,
@@ -105,6 +108,7 @@ def __init__(
105108
types.ServerNotification,
106109
read_timeout_seconds=read_timeout_seconds,
107110
)
111+
self._client_info = client_info or DEFAULT_CLIENT_INFO
108112
self._sampling_callback = sampling_callback or _default_sampling_callback
109113
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
110114
self._logging_callback = logging_callback or _default_logging_callback
@@ -130,7 +134,7 @@ async def initialize(self) -> types.InitializeResult:
130134
experimental=None,
131135
roots=roots,
132136
),
133-
clientInfo=types.Implementation(name="mcp", version="0.1.0"),
137+
clientInfo=self._client_info,
134138
),
135139
)
136140
),

Diff for: src/mcp/shared/memory.py

+3
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import anyio
1111
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1212

13+
import mcp.types as types
1314
from mcp.client.session import (
1415
ClientSession,
1516
ListRootsFnT,
@@ -65,6 +66,7 @@ async def create_connected_server_and_client_session(
6566
list_roots_callback: ListRootsFnT | None = None,
6667
logging_callback: LoggingFnT | None = None,
6768
message_handler: MessageHandlerFnT | None = None,
69+
client_info: types.Implementation | None = None,
6870
raise_exceptions: bool = False,
6971
) -> AsyncGenerator[ClientSession, None]:
7072
"""Creates a ClientSession that is connected to a running MCP server."""
@@ -95,6 +97,7 @@ async def create_connected_server_and_client_session(
9597
list_roots_callback=list_roots_callback,
9698
logging_callback=logging_callback,
9799
message_handler=message_handler,
100+
client_info=client_info,
98101
) as client_session:
99102
await client_session.initialize()
100103
yield client_session

Diff for: tests/client/test_session.py

+129-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import pytest
33

44
import mcp.types as types
5-
from mcp.client.session import ClientSession
5+
from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
66
from mcp.shared.session import RequestResponder
77
from mcp.types import (
88
LATEST_PROTOCOL_VERSION,
@@ -111,3 +111,131 @@ async def message_handler(
111111
# Check that the client sent the initialized notification
112112
assert initialized_notification
113113
assert isinstance(initialized_notification.root, InitializedNotification)
114+
115+
116+
@pytest.mark.anyio
117+
async def test_client_session_custom_client_info():
118+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
119+
JSONRPCMessage
120+
](1)
121+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
122+
JSONRPCMessage
123+
](1)
124+
125+
custom_client_info = Implementation(name="test-client", version="1.2.3")
126+
received_client_info = None
127+
128+
async def mock_server():
129+
nonlocal received_client_info
130+
131+
jsonrpc_request = await client_to_server_receive.receive()
132+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
133+
request = ClientRequest.model_validate(
134+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
135+
)
136+
assert isinstance(request.root, InitializeRequest)
137+
received_client_info = request.root.params.clientInfo
138+
139+
result = ServerResult(
140+
InitializeResult(
141+
protocolVersion=LATEST_PROTOCOL_VERSION,
142+
capabilities=ServerCapabilities(),
143+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
144+
)
145+
)
146+
147+
async with server_to_client_send:
148+
await server_to_client_send.send(
149+
JSONRPCMessage(
150+
JSONRPCResponse(
151+
jsonrpc="2.0",
152+
id=jsonrpc_request.root.id,
153+
result=result.model_dump(
154+
by_alias=True, mode="json", exclude_none=True
155+
),
156+
)
157+
)
158+
)
159+
# Receive initialized notification
160+
await client_to_server_receive.receive()
161+
162+
async with (
163+
ClientSession(
164+
server_to_client_receive,
165+
client_to_server_send,
166+
client_info=custom_client_info,
167+
) as session,
168+
anyio.create_task_group() as tg,
169+
client_to_server_send,
170+
client_to_server_receive,
171+
server_to_client_send,
172+
server_to_client_receive,
173+
):
174+
tg.start_soon(mock_server)
175+
await session.initialize()
176+
177+
# Assert that the custom client info was sent
178+
assert received_client_info == custom_client_info
179+
180+
181+
@pytest.mark.anyio
182+
async def test_client_session_default_client_info():
183+
client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[
184+
JSONRPCMessage
185+
](1)
186+
server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[
187+
JSONRPCMessage
188+
](1)
189+
190+
received_client_info = None
191+
192+
async def mock_server():
193+
nonlocal received_client_info
194+
195+
jsonrpc_request = await client_to_server_receive.receive()
196+
assert isinstance(jsonrpc_request.root, JSONRPCRequest)
197+
request = ClientRequest.model_validate(
198+
jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True)
199+
)
200+
assert isinstance(request.root, InitializeRequest)
201+
received_client_info = request.root.params.clientInfo
202+
203+
result = ServerResult(
204+
InitializeResult(
205+
protocolVersion=LATEST_PROTOCOL_VERSION,
206+
capabilities=ServerCapabilities(),
207+
serverInfo=Implementation(name="mock-server", version="0.1.0"),
208+
)
209+
)
210+
211+
async with server_to_client_send:
212+
await server_to_client_send.send(
213+
JSONRPCMessage(
214+
JSONRPCResponse(
215+
jsonrpc="2.0",
216+
id=jsonrpc_request.root.id,
217+
result=result.model_dump(
218+
by_alias=True, mode="json", exclude_none=True
219+
),
220+
)
221+
)
222+
)
223+
# Receive initialized notification
224+
await client_to_server_receive.receive()
225+
226+
async with (
227+
ClientSession(
228+
server_to_client_receive,
229+
client_to_server_send,
230+
) as session,
231+
anyio.create_task_group() as tg,
232+
client_to_server_send,
233+
client_to_server_receive,
234+
server_to_client_send,
235+
server_to_client_receive,
236+
):
237+
tg.start_soon(mock_server)
238+
await session.initialize()
239+
240+
# Assert that the default client info was sent
241+
assert received_client_info == DEFAULT_CLIENT_INFO

0 commit comments

Comments
 (0)