|
5 | 5 | from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
6 | 6 | from mcp.shared.message import SessionMessage
|
7 | 7 | from mcp.shared.session import RequestResponder
|
| 8 | +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS |
8 | 9 | from mcp.types import (
|
9 | 10 | LATEST_PROTOCOL_VERSION,
|
10 | 11 | ClientNotification,
|
@@ -250,3 +251,132 @@ async def mock_server():
|
250 | 251 |
|
251 | 252 | # Assert that the default client info was sent
|
252 | 253 | assert received_client_info == DEFAULT_CLIENT_INFO
|
| 254 | + |
| 255 | + |
| 256 | +@pytest.mark.anyio |
| 257 | +async def test_client_session_version_negotiation_success(): |
| 258 | + """Test successful version negotiation with supported version""" |
| 259 | + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ |
| 260 | + SessionMessage |
| 261 | + ](1) |
| 262 | + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ |
| 263 | + SessionMessage |
| 264 | + ](1) |
| 265 | + |
| 266 | + async def mock_server(): |
| 267 | + session_message = await client_to_server_receive.receive() |
| 268 | + jsonrpc_request = session_message.message |
| 269 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 270 | + request = ClientRequest.model_validate( |
| 271 | + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 272 | + ) |
| 273 | + assert isinstance(request.root, InitializeRequest) |
| 274 | + |
| 275 | + # Verify client sent the latest protocol version |
| 276 | + assert request.root.params.protocolVersion == LATEST_PROTOCOL_VERSION |
| 277 | + |
| 278 | + # Server responds with a supported older version |
| 279 | + result = ServerResult( |
| 280 | + InitializeResult( |
| 281 | + protocolVersion="2024-11-05", |
| 282 | + capabilities=ServerCapabilities(), |
| 283 | + serverInfo=Implementation(name="mock-server", version="0.1.0"), |
| 284 | + ) |
| 285 | + ) |
| 286 | + |
| 287 | + async with server_to_client_send: |
| 288 | + await server_to_client_send.send( |
| 289 | + SessionMessage( |
| 290 | + JSONRPCMessage( |
| 291 | + JSONRPCResponse( |
| 292 | + jsonrpc="2.0", |
| 293 | + id=jsonrpc_request.root.id, |
| 294 | + result=result.model_dump( |
| 295 | + by_alias=True, mode="json", exclude_none=True |
| 296 | + ), |
| 297 | + ) |
| 298 | + ) |
| 299 | + ) |
| 300 | + ) |
| 301 | + # Receive initialized notification |
| 302 | + await client_to_server_receive.receive() |
| 303 | + |
| 304 | + async with ( |
| 305 | + ClientSession( |
| 306 | + server_to_client_receive, |
| 307 | + client_to_server_send, |
| 308 | + ) as session, |
| 309 | + anyio.create_task_group() as tg, |
| 310 | + client_to_server_send, |
| 311 | + client_to_server_receive, |
| 312 | + server_to_client_send, |
| 313 | + server_to_client_receive, |
| 314 | + ): |
| 315 | + tg.start_soon(mock_server) |
| 316 | + result = await session.initialize() |
| 317 | + |
| 318 | + # Assert the result with negotiated version |
| 319 | + assert isinstance(result, InitializeResult) |
| 320 | + assert result.protocolVersion == "2024-11-05" |
| 321 | + assert result.protocolVersion in SUPPORTED_PROTOCOL_VERSIONS |
| 322 | + |
| 323 | + |
| 324 | +@pytest.mark.anyio |
| 325 | +async def test_client_session_version_negotiation_failure(): |
| 326 | + """Test version negotiation failure with unsupported version""" |
| 327 | + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ |
| 328 | + SessionMessage |
| 329 | + ](1) |
| 330 | + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ |
| 331 | + SessionMessage |
| 332 | + ](1) |
| 333 | + |
| 334 | + async def mock_server(): |
| 335 | + session_message = await client_to_server_receive.receive() |
| 336 | + jsonrpc_request = session_message.message |
| 337 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 338 | + request = ClientRequest.model_validate( |
| 339 | + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 340 | + ) |
| 341 | + assert isinstance(request.root, InitializeRequest) |
| 342 | + |
| 343 | + # Server responds with an unsupported version |
| 344 | + result = ServerResult( |
| 345 | + InitializeResult( |
| 346 | + protocolVersion="2020-01-01", # Unsupported old version |
| 347 | + capabilities=ServerCapabilities(), |
| 348 | + serverInfo=Implementation(name="mock-server", version="0.1.0"), |
| 349 | + ) |
| 350 | + ) |
| 351 | + |
| 352 | + async with server_to_client_send: |
| 353 | + await server_to_client_send.send( |
| 354 | + SessionMessage( |
| 355 | + JSONRPCMessage( |
| 356 | + JSONRPCResponse( |
| 357 | + jsonrpc="2.0", |
| 358 | + id=jsonrpc_request.root.id, |
| 359 | + result=result.model_dump( |
| 360 | + by_alias=True, mode="json", exclude_none=True |
| 361 | + ), |
| 362 | + ) |
| 363 | + ) |
| 364 | + ) |
| 365 | + ) |
| 366 | + |
| 367 | + async with ( |
| 368 | + ClientSession( |
| 369 | + server_to_client_receive, |
| 370 | + client_to_server_send, |
| 371 | + ) as session, |
| 372 | + anyio.create_task_group() as tg, |
| 373 | + client_to_server_send, |
| 374 | + client_to_server_receive, |
| 375 | + server_to_client_send, |
| 376 | + server_to_client_receive, |
| 377 | + ): |
| 378 | + tg.start_soon(mock_server) |
| 379 | + |
| 380 | + # Should raise RuntimeError for unsupported version |
| 381 | + with pytest.raises(RuntimeError, match="Unsupported protocol version"): |
| 382 | + await session.initialize() |
0 commit comments