diff --git a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py index b5faffedb..71d4e5a37 100644 --- a/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py +++ b/examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py @@ -7,7 +7,7 @@ import click import mcp.types as types from mcp.server.lowlevel import Server -from mcp.server.streamableHttp import ( +from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, StreamableHTTPServerTransport, ) diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py new file mode 100644 index 000000000..1e0042428 --- /dev/null +++ b/src/mcp/client/streamable_http.py @@ -0,0 +1,258 @@ +""" +StreamableHTTP Client Transport Module + +This module implements the StreamableHTTP transport for MCP clients, +providing support for HTTP POST requests with optional SSE streaming responses +and session management. +""" + +import logging +from contextlib import asynccontextmanager +from datetime import timedelta +from typing import Any + +import anyio +import httpx +from httpx_sse import EventSource, aconnect_sse + +from mcp.types import ( + ErrorData, + JSONRPCError, + JSONRPCMessage, + JSONRPCNotification, + JSONRPCRequest, +) + +logger = logging.getLogger(__name__) + +# Header names +MCP_SESSION_ID_HEADER = "mcp-session-id" +LAST_EVENT_ID_HEADER = "last-event-id" + +# Content types +CONTENT_TYPE_JSON = "application/json" +CONTENT_TYPE_SSE = "text/event-stream" + + +@asynccontextmanager +async def streamablehttp_client( + url: str, + headers: dict[str, Any] | None = None, + timeout: timedelta = timedelta(seconds=30), + sse_read_timeout: timedelta = timedelta(seconds=60 * 5), +): + """ + Client transport for StreamableHTTP. + + `sse_read_timeout` determines how long (in seconds) the client will wait for a new + event before disconnecting. All other HTTP operations are controlled by `timeout`. + + Yields: + Tuple of (read_stream, write_stream, terminate_callback) + """ + + read_stream_writer, read_stream = anyio.create_memory_object_stream[ + JSONRPCMessage | Exception + ](0) + write_stream, write_stream_reader = anyio.create_memory_object_stream[ + JSONRPCMessage + ](0) + + async def get_stream(): + """ + Optional GET stream for server-initiated messages + """ + nonlocal session_id + try: + # Only attempt GET if we have a session ID + if not session_id: + return + + get_headers = request_headers.copy() + get_headers[MCP_SESSION_ID_HEADER] = session_id + + async with aconnect_sse( + client, + "GET", + url, + headers=get_headers, + timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("GET SSE connection established") + + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + message = JSONRPCMessage.model_validate_json(sse.data) + logger.debug(f"GET message: {message}") + await read_stream_writer.send(message) + except Exception as exc: + logger.error(f"Error parsing GET message: {exc}") + await read_stream_writer.send(exc) + else: + logger.warning(f"Unknown SSE event from GET: {sse.event}") + except Exception as exc: + # GET stream is optional, so don't propagate errors + logger.debug(f"GET stream error (non-fatal): {exc}") + + async def post_writer(client: httpx.AsyncClient): + nonlocal session_id + try: + async with write_stream_reader: + async for message in write_stream_reader: + # Add session ID to headers if we have one + post_headers = request_headers.copy() + if session_id: + post_headers[MCP_SESSION_ID_HEADER] = session_id + + logger.debug(f"Sending client message: {message}") + + # Handle initial initialization request + is_initialization = ( + isinstance(message.root, JSONRPCRequest) + and message.root.method == "initialize" + ) + if ( + isinstance(message.root, JSONRPCNotification) + and message.root.method == "notifications/initialized" + ): + tg.start_soon(get_stream) + + async with client.stream( + "POST", + url, + json=message.model_dump( + by_alias=True, mode="json", exclude_none=True + ), + headers=post_headers, + ) as response: + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + # Check for 404 (session expired/invalid) + if response.status_code == 404: + if isinstance(message.root, JSONRPCRequest): + jsonrpc_error = JSONRPCError( + jsonrpc="2.0", + id=message.root.id, + error=ErrorData( + code=32600, + message="Session terminated", + ), + ) + await read_stream_writer.send( + JSONRPCMessage(jsonrpc_error) + ) + continue + response.raise_for_status() + + # Extract session ID from response headers + if is_initialization: + new_session_id = response.headers.get(MCP_SESSION_ID_HEADER) + if new_session_id: + session_id = new_session_id + logger.info(f"Received session ID: {session_id}") + + # Handle different response types + content_type = response.headers.get("content-type", "").lower() + + if content_type.startswith(CONTENT_TYPE_JSON): + try: + content = await response.aread() + json_message = JSONRPCMessage.model_validate_json( + content + ) + await read_stream_writer.send(json_message) + except Exception as exc: + logger.error(f"Error parsing JSON response: {exc}") + await read_stream_writer.send(exc) + + elif content_type.startswith(CONTENT_TYPE_SSE): + # Parse SSE events from the response + try: + event_source = EventSource(response) + async for sse in event_source.aiter_sse(): + if sse.event == "message": + try: + await read_stream_writer.send( + JSONRPCMessage.model_validate_json( + sse.data + ) + ) + except Exception as exc: + logger.exception("Error parsing message") + await read_stream_writer.send(exc) + else: + logger.warning(f"Unknown event: {sse.event}") + + except Exception as e: + logger.exception("Error reading SSE stream:") + await read_stream_writer.send(e) + + else: + # For 202 Accepted with no body + if response.status_code == 202: + logger.debug("Received 202 Accepted") + continue + + error_msg = f"Unexpected content type: {content_type}" + logger.error(error_msg) + await read_stream_writer.send(ValueError(error_msg)) + + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + + async def terminate_session(): + """ + Terminate the session by sending a DELETE request. + """ + nonlocal session_id + if not session_id: + return # No session to terminate + + try: + delete_headers = request_headers.copy() + delete_headers[MCP_SESSION_ID_HEADER] = session_id + + response = await client.delete( + url, + headers=delete_headers, + ) + + if response.status_code == 405: + # Server doesn't allow client-initiated termination + logger.debug("Server does not allow session termination") + elif response.status_code != 200: + logger.warning(f"Session termination failed: {response.status_code}") + except Exception as exc: + logger.warning(f"Session termination failed: {exc}") + + async with anyio.create_task_group() as tg: + try: + logger.info(f"Connecting to StreamableHTTP endpoint: {url}") + # Set up headers with required Accept header + request_headers = { + "Accept": f"{CONTENT_TYPE_JSON}, {CONTENT_TYPE_SSE}", + "Content-Type": CONTENT_TYPE_JSON, + **(headers or {}), + } + # Track session ID if provided by server + session_id: str | None = None + + async with httpx.AsyncClient( + headers=request_headers, + timeout=httpx.Timeout(timeout.seconds, read=sse_read_timeout.seconds), + follow_redirects=True, + ) as client: + tg.start_soon(post_writer, client) + try: + yield read_stream, write_stream, terminate_session + finally: + tg.cancel_scope.cancel() + finally: + await read_stream_writer.aclose() + await write_stream.aclose() diff --git a/src/mcp/server/streamableHttp.py b/src/mcp/server/streamable_http.py similarity index 100% rename from src/mcp/server/streamableHttp.py rename to src/mcp/server/streamable_http.py diff --git a/tests/server/test_streamableHttp.py b/tests/shared/test_streamable_http.py similarity index 69% rename from tests/server/test_streamableHttp.py rename to tests/shared/test_streamable_http.py index f612575c3..48af09536 100644 --- a/tests/server/test_streamableHttp.py +++ b/tests/shared/test_streamable_http.py @@ -1,7 +1,7 @@ """ -Tests for the StreamableHTTP server transport validation. +Tests for the StreamableHTTP server and client transport. -This file contains tests for request validation in the StreamableHTTP transport. +Contains tests for both server and client sides of the StreamableHTTP transport. """ import contextlib @@ -13,6 +13,7 @@ from uuid import uuid4 import anyio +import httpx import pytest import requests import uvicorn @@ -22,18 +23,16 @@ from starlette.responses import Response from starlette.routing import Mount +from mcp.client.session import ClientSession +from mcp.client.streamable_http import streamablehttp_client from mcp.server import Server -from mcp.server.streamableHttp import ( +from mcp.server.streamable_http import ( MCP_SESSION_ID_HEADER, SESSION_ID_PATTERN, StreamableHTTPServerTransport, ) from mcp.shared.exceptions import McpError -from mcp.types import ( - ErrorData, - TextContent, - Tool, -) +from mcp.types import InitializeResult, TextContent, TextResourceContents, Tool # Test constants SERVER_NAME = "test_streamable_http_server" @@ -64,11 +63,7 @@ async def handle_read_resource(uri: AnyUrl) -> str | bytes: await anyio.sleep(2.0) return f"Slow response from {uri.host}" - raise McpError( - error=ErrorData( - code=404, message="OOPS! no resource with that URI was found" - ) - ) + raise ValueError(f"Unknown resource: {uri}") @self.list_tools() async def handle_list_tools() -> list[Tool]: @@ -77,11 +72,23 @@ async def handle_list_tools() -> list[Tool]: name="test_tool", description="A test tool", inputSchema={"type": "object", "properties": {}}, - ) + ), + Tool( + name="test_tool_with_standalone_notification", + description="A test tool that sends a notification", + inputSchema={"type": "object", "properties": {}}, + ), ] @self.call_tool() async def handle_call_tool(name: str, args: dict) -> list[TextContent]: + # When the tool is called, send a notification to test GET stream + if name == "test_tool_with_standalone_notification": + ctx = self.request_context + await ctx.session.send_resource_updated( + uri=AnyUrl("http://test_resource") + ) + return [TextContent(type="text", text=f"Called {name}")] @@ -630,3 +637,219 @@ def test_get_validation(basic_server, basic_server_url): ) assert response.status_code == 406 assert "Not Acceptable" in response.text + + +# Client-specific fixtures +@pytest.fixture +async def http_client(basic_server, basic_server_url): + """Create test client matching the SSE test pattern.""" + async with httpx.AsyncClient(base_url=basic_server_url) as client: + yield client + + +@pytest.fixture +async def initialized_client_session(basic_server, basic_server_url): + """Create initialized StreamableHTTP client session.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + await session.initialize() + yield session + + +@pytest.mark.anyio +async def test_streamablehttp_client_basic_connection(basic_server, basic_server_url): + """Test basic client connection with initialization.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + +@pytest.mark.anyio +async def test_streamablehttp_client_resource_read(initialized_client_session): + """Test client resource read functionality.""" + response = await initialized_client_session.read_resource( + uri=AnyUrl("foobar://test-resource") + ) + assert len(response.contents) == 1 + assert response.contents[0].uri == AnyUrl("foobar://test-resource") + assert response.contents[0].text == "Read test-resource" + + +@pytest.mark.anyio +async def test_streamablehttp_client_tool_invocation(initialized_client_session): + """Test client tool invocation.""" + # First list tools + tools = await initialized_client_session.list_tools() + assert len(tools.tools) == 2 + assert tools.tools[0].name == "test_tool" + + # Call the tool + result = await initialized_client_session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_error_handling(initialized_client_session): + """Test error handling in client.""" + with pytest.raises(McpError) as exc_info: + await initialized_client_session.read_resource( + uri=AnyUrl("unknown://test-error") + ) + assert exc_info.value.error.code == 0 + assert "Unknown resource: unknown://test-error" in exc_info.value.error.message + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_persistence( + basic_server, basic_server_url +): + """Test that session ID persists across requests.""" + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make multiple requests to verify session persistence + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # Read a resource + resource = await session.read_resource(uri=AnyUrl("foobar://test-persist")) + assert isinstance(resource.contents[0], TextResourceContents) is True + content = resource.contents[0] + assert isinstance(content, TextResourceContents) + assert content.text == "Read test-persist" + + +@pytest.mark.anyio +async def test_streamablehttp_client_json_response( + json_response_server, json_server_url +): + """Test client with JSON response mode.""" + async with streamablehttp_client(f"{json_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, + write_stream, + ) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Check tool listing + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # Call a tool and verify JSON response handling + result = await session.call_tool("test_tool", {}) + assert len(result.content) == 1 + assert result.content[0].type == "text" + assert result.content[0].text == "Called test_tool" + + +@pytest.mark.anyio +async def test_streamablehttp_client_get_stream(basic_server, basic_server_url): + """Test GET stream functionality for server-initiated messages.""" + import mcp.types as types + from mcp.shared.session import RequestResponder + + notifications_received = [] + + # Define message handler to capture notifications + async def message_handler( + message: RequestResponder[types.ServerRequest, types.ClientResult] + | types.ServerNotification + | Exception, + ) -> None: + if isinstance(message, types.ServerNotification): + notifications_received.append(message) + + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + _, + ): + async with ClientSession( + read_stream, write_stream, message_handler=message_handler + ) as session: + # Initialize the session - this triggers the GET stream setup + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Call the special tool that sends a notification + await session.call_tool("test_tool_with_standalone_notification", {}) + + # Verify we received the notification + assert len(notifications_received) > 0 + + # Verify the notification is a ResourceUpdatedNotification + resource_update_found = False + for notif in notifications_received: + if isinstance(notif.root, types.ResourceUpdatedNotification): + assert str(notif.root.params.uri) == "http://test_resource/" + resource_update_found = True + + assert ( + resource_update_found + ), "ResourceUpdatedNotification not received via GET stream" + + +@pytest.mark.anyio +async def test_streamablehttp_client_session_termination( + basic_server, basic_server_url +): + """Test client session termination functionality.""" + + # Create the streamablehttp_client with a custom httpx client to capture headers + async with streamablehttp_client(f"{basic_server_url}/mcp") as ( + read_stream, + write_stream, + terminate_session, + ): + async with ClientSession(read_stream, write_stream) as session: + # Initialize the session + result = await session.initialize() + assert isinstance(result, InitializeResult) + + # Make a request to confirm session is working + tools = await session.list_tools() + assert len(tools.tools) == 2 + + # After exiting ClientSession context, explicitly terminate the session + await terminate_session() + with pytest.raises( + McpError, + match="Session terminated", + ): + await session.list_tools() diff --git a/uv.lock b/uv.lock index cbdc33471..06dd240b2 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" [options] @@ -547,6 +548,7 @@ requires-dist = [ { name = "uvicorn", marker = "sys_platform != 'emscripten'", specifier = ">=0.23.1" }, { name = "websockets", marker = "extra == 'ws'", specifier = ">=15.0.1" }, ] +provides-extras = ["cli", "rich", "ws"] [package.metadata.requires-dev] dev = [