|
| 1 | +from typing import Any |
| 2 | + |
1 | 3 | import anyio
|
2 | 4 | import pytest
|
3 | 5 |
|
4 | 6 | import mcp.types as types
|
5 | 7 | from mcp.client.session import DEFAULT_CLIENT_INFO, ClientSession
|
| 8 | +from mcp.shared.context import RequestContext |
6 | 9 | from mcp.shared.message import SessionMessage
|
7 | 10 | from mcp.shared.session import RequestResponder
|
8 | 11 | from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS
|
@@ -380,3 +383,167 @@ async def mock_server():
|
380 | 383 | # Should raise RuntimeError for unsupported version
|
381 | 384 | with pytest.raises(RuntimeError, match="Unsupported protocol version"):
|
382 | 385 | await session.initialize()
|
| 386 | + |
| 387 | + |
| 388 | +@pytest.mark.anyio |
| 389 | +async def test_client_capabilities_default(): |
| 390 | + """Test that client capabilities are properly set with default callbacks""" |
| 391 | + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ |
| 392 | + SessionMessage |
| 393 | + ](1) |
| 394 | + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ |
| 395 | + SessionMessage |
| 396 | + ](1) |
| 397 | + |
| 398 | + received_capabilities = None |
| 399 | + |
| 400 | + async def mock_server(): |
| 401 | + nonlocal received_capabilities |
| 402 | + |
| 403 | + session_message = await client_to_server_receive.receive() |
| 404 | + jsonrpc_request = session_message.message |
| 405 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 406 | + request = ClientRequest.model_validate( |
| 407 | + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 408 | + ) |
| 409 | + assert isinstance(request.root, InitializeRequest) |
| 410 | + received_capabilities = request.root.params.capabilities |
| 411 | + |
| 412 | + result = ServerResult( |
| 413 | + InitializeResult( |
| 414 | + protocolVersion=LATEST_PROTOCOL_VERSION, |
| 415 | + capabilities=ServerCapabilities(), |
| 416 | + serverInfo=Implementation(name="mock-server", version="0.1.0"), |
| 417 | + ) |
| 418 | + ) |
| 419 | + |
| 420 | + async with server_to_client_send: |
| 421 | + await server_to_client_send.send( |
| 422 | + SessionMessage( |
| 423 | + JSONRPCMessage( |
| 424 | + JSONRPCResponse( |
| 425 | + jsonrpc="2.0", |
| 426 | + id=jsonrpc_request.root.id, |
| 427 | + result=result.model_dump( |
| 428 | + by_alias=True, mode="json", exclude_none=True |
| 429 | + ), |
| 430 | + ) |
| 431 | + ) |
| 432 | + ) |
| 433 | + ) |
| 434 | + # Receive initialized notification |
| 435 | + await client_to_server_receive.receive() |
| 436 | + |
| 437 | + async with ( |
| 438 | + ClientSession( |
| 439 | + server_to_client_receive, |
| 440 | + client_to_server_send, |
| 441 | + ) as session, |
| 442 | + anyio.create_task_group() as tg, |
| 443 | + client_to_server_send, |
| 444 | + client_to_server_receive, |
| 445 | + server_to_client_send, |
| 446 | + server_to_client_receive, |
| 447 | + ): |
| 448 | + tg.start_soon(mock_server) |
| 449 | + await session.initialize() |
| 450 | + |
| 451 | + # Assert that capabilities are properly set with defaults |
| 452 | + assert received_capabilities is not None |
| 453 | + assert received_capabilities.sampling is None # No custom sampling callback |
| 454 | + assert received_capabilities.roots is None # No custom list_roots callback |
| 455 | + |
| 456 | + |
| 457 | +@pytest.mark.anyio |
| 458 | +async def test_client_capabilities_with_custom_callbacks(): |
| 459 | + """Test that client capabilities are properly set with custom callbacks""" |
| 460 | + client_to_server_send, client_to_server_receive = anyio.create_memory_object_stream[ |
| 461 | + SessionMessage |
| 462 | + ](1) |
| 463 | + server_to_client_send, server_to_client_receive = anyio.create_memory_object_stream[ |
| 464 | + SessionMessage |
| 465 | + ](1) |
| 466 | + |
| 467 | + received_capabilities = None |
| 468 | + |
| 469 | + async def custom_sampling_callback( |
| 470 | + context: RequestContext["ClientSession", Any], |
| 471 | + params: types.CreateMessageRequestParams, |
| 472 | + ) -> types.CreateMessageResult | types.ErrorData: |
| 473 | + return types.CreateMessageResult( |
| 474 | + role="assistant", |
| 475 | + content=types.TextContent(type="text", text="test"), |
| 476 | + model="test-model", |
| 477 | + ) |
| 478 | + |
| 479 | + async def custom_list_roots_callback( |
| 480 | + context: RequestContext["ClientSession", Any], |
| 481 | + ) -> types.ListRootsResult | types.ErrorData: |
| 482 | + return types.ListRootsResult(roots=[]) |
| 483 | + |
| 484 | + async def mock_server(): |
| 485 | + nonlocal received_capabilities |
| 486 | + |
| 487 | + session_message = await client_to_server_receive.receive() |
| 488 | + jsonrpc_request = session_message.message |
| 489 | + assert isinstance(jsonrpc_request.root, JSONRPCRequest) |
| 490 | + request = ClientRequest.model_validate( |
| 491 | + jsonrpc_request.model_dump(by_alias=True, mode="json", exclude_none=True) |
| 492 | + ) |
| 493 | + assert isinstance(request.root, InitializeRequest) |
| 494 | + received_capabilities = request.root.params.capabilities |
| 495 | + |
| 496 | + result = ServerResult( |
| 497 | + InitializeResult( |
| 498 | + protocolVersion=LATEST_PROTOCOL_VERSION, |
| 499 | + capabilities=ServerCapabilities(), |
| 500 | + serverInfo=Implementation(name="mock-server", version="0.1.0"), |
| 501 | + ) |
| 502 | + ) |
| 503 | + |
| 504 | + async with server_to_client_send: |
| 505 | + await server_to_client_send.send( |
| 506 | + SessionMessage( |
| 507 | + JSONRPCMessage( |
| 508 | + JSONRPCResponse( |
| 509 | + jsonrpc="2.0", |
| 510 | + id=jsonrpc_request.root.id, |
| 511 | + result=result.model_dump( |
| 512 | + by_alias=True, mode="json", exclude_none=True |
| 513 | + ), |
| 514 | + ) |
| 515 | + ) |
| 516 | + ) |
| 517 | + ) |
| 518 | + # Receive initialized notification |
| 519 | + await client_to_server_receive.receive() |
| 520 | + |
| 521 | + async with ( |
| 522 | + ClientSession( |
| 523 | + server_to_client_receive, |
| 524 | + client_to_server_send, |
| 525 | + sampling_callback=custom_sampling_callback, |
| 526 | + list_roots_callback=custom_list_roots_callback, |
| 527 | + ) as session, |
| 528 | + anyio.create_task_group() as tg, |
| 529 | + client_to_server_send, |
| 530 | + client_to_server_receive, |
| 531 | + server_to_client_send, |
| 532 | + server_to_client_receive, |
| 533 | + ): |
| 534 | + tg.start_soon(mock_server) |
| 535 | + await session.initialize() |
| 536 | + |
| 537 | + # Assert that capabilities are properly set with custom callbacks |
| 538 | + assert received_capabilities is not None |
| 539 | + assert ( |
| 540 | + received_capabilities.sampling is not None |
| 541 | + ) # Custom sampling callback provided |
| 542 | + assert isinstance(received_capabilities.sampling, types.SamplingCapability) |
| 543 | + assert ( |
| 544 | + received_capabilities.roots is not None |
| 545 | + ) # Custom list_roots callback provided |
| 546 | + assert isinstance(received_capabilities.roots, types.RootsCapability) |
| 547 | + assert ( |
| 548 | + received_capabilities.roots.listChanged is True |
| 549 | + ) # Should be True for custom callback |
0 commit comments