diff --git a/pyproject.toml b/pyproject.toml index 6ff2601e9..721642612 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dependencies = [ "sse-starlette>=1.6.1", "pydantic-settings>=2.5.2", "uvicorn>=0.23.1; sys_platform != 'emscripten'", + "exceptiongroup>=1.2.0", ] [project.optional-dependencies] diff --git a/src/mcp/client/sse.py b/src/mcp/client/sse.py index 7df251f79..7030ce9e3 100644 --- a/src/mcp/client/sse.py +++ b/src/mcp/client/sse.py @@ -7,6 +7,7 @@ import httpx from anyio.abc import TaskStatus from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream +from exceptiongroup import BaseExceptionGroup, catch from httpx_sse import aconnect_sse import mcp.types as types @@ -19,6 +20,12 @@ def remove_request_params(url: str) -> str: return urljoin(url, urlparse(url).path) +def handle_exception(exc: BaseExceptionGroup[Exception]) -> str: + """Handle ExceptionGroup and Exceptions for Client transport for SSE""" + messages = "; ".join(str(e) for e in exc.exceptions) + raise Exception(f"TaskGroup failed with: {messages}") from None + + @asynccontextmanager async def sse_client( url: str, @@ -41,114 +48,118 @@ async def sse_client( read_stream_writer, read_stream = anyio.create_memory_object_stream(0) write_stream, write_stream_reader = anyio.create_memory_object_stream(0) - async with anyio.create_task_group() as tg: - try: - logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") - async with httpx.AsyncClient(headers=headers) as client: - async with aconnect_sse( - client, - "GET", - url, - timeout=httpx.Timeout(timeout, read=sse_read_timeout), - ) as event_source: - event_source.response.raise_for_status() - logger.debug("SSE connection established") - - async def sse_reader( - task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, - ): - try: - async for sse in event_source.aiter_sse(): - logger.debug(f"Received SSE event: {sse.event}") - match sse.event: - case "endpoint": - endpoint_url = urljoin(url, sse.data) - logger.info( - f"Received endpoint URL: {endpoint_url}" + with catch({Exception: handle_exception}): + logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}") + async with httpx.AsyncClient(headers=headers) as client: + async with aconnect_sse( + client, + "GET", + url, + timeout=httpx.Timeout(timeout, read=sse_read_timeout), + ) as event_source: + event_source.response.raise_for_status() + logger.debug("SSE connection established") + + async def sse_reader( + task_status: TaskStatus[str] = anyio.TASK_STATUS_IGNORED, + ): + try: + async for sse in event_source.aiter_sse(): + logger.debug(f"Received SSE event: {sse.event}") + match sse.event: + case "endpoint": + endpoint_url = urljoin(url, sse.data) + logger.info( + f"Received endpoint URL: {endpoint_url}" + ) + + url_parsed = urlparse(url) + endpoint_parsed = urlparse(endpoint_url) + if ( + url_parsed.netloc + != endpoint_parsed.netloc + or url_parsed.scheme + != endpoint_parsed.scheme + ): + error_msg = ( + "Endpoint origin does not match " + f"connection origin: {endpoint_url}" ) + logger.error(error_msg) + raise ValueError(error_msg) + + task_status.started(endpoint_url) - url_parsed = urlparse(url) - endpoint_parsed = urlparse(endpoint_url) - if ( - url_parsed.netloc != endpoint_parsed.netloc - or url_parsed.scheme - != endpoint_parsed.scheme - ): - error_msg = ( - "Endpoint origin does not match " - f"connection origin: {endpoint_url}" - ) - logger.error(error_msg) - raise ValueError(error_msg) - - task_status.started(endpoint_url) - - case "message": - try: - message = types.JSONRPCMessage.model_validate_json( # noqa: E501 - sse.data - ) - logger.debug( - f"Received server message: {message}" - ) - except Exception as exc: - logger.error( - f"Error parsing server message: {exc}" - ) - await read_stream_writer.send(exc) - continue - - session_message = SessionMessage( - message=message + case "message": + try: + message = types.JSONRPCMessage.model_validate_json( # noqa: E501 + sse.data ) - await read_stream_writer.send(session_message) - case _: - logger.warning( - f"Unknown SSE event: {sse.event}" + logger.debug( + "Received server message: " + f"{message}" ) - except Exception as exc: - logger.error(f"Error in sse_reader: {exc}") - await read_stream_writer.send(exc) - finally: - await read_stream_writer.aclose() - - async def post_writer(endpoint_url: str): - try: - async with write_stream_reader: - async for session_message in write_stream_reader: - logger.debug( - f"Sending client message: {session_message}" + except Exception as exc: + logger.error( + "Error parsing server message: " + f"{exc}" + ) + await read_stream_writer.send(exc) + continue + + session_message = SessionMessage( + message=message ) - response = await client.post( - endpoint_url, - json=session_message.message.model_dump( - by_alias=True, - mode="json", - exclude_none=True, - ), + await read_stream_writer.send( + session_message ) - response.raise_for_status() - logger.debug( - "Client message sent successfully: " - f"{response.status_code}" + case _: + logger.warning( + f"Unknown SSE event: {sse.event}" ) - except Exception as exc: - logger.error(f"Error in post_writer: {exc}") - finally: - await write_stream.aclose() - - endpoint_url = await tg.start(sse_reader) - logger.info( - f"Starting post writer with endpoint URL: {endpoint_url}" - ) - tg.start_soon(post_writer, endpoint_url) + except Exception as exc: + logger.error(f"Error in sse_reader: {exc}") + await read_stream_writer.send(exc) + finally: + await read_stream_writer.aclose() + async def post_writer(endpoint_url: str): try: - yield read_stream, write_stream + async with write_stream_reader: + async for session_message in write_stream_reader: + logger.debug( + f"Sending client message: {session_message}" + ) + response = await client.post( + endpoint_url, + json=session_message.message.model_dump( + by_alias=True, + mode="json", + exclude_none=True, + ), + ) + response.raise_for_status() + logger.debug( + "Client message sent successfully: " + f"{response.status_code}" + ) + except Exception as exc: + logger.error(f"Error in post_writer: {exc}") finally: - tg.cancel_scope.cancel() - finally: - await read_stream_writer.aclose() - await write_stream.aclose() - await read_stream.aclose() - await write_stream_reader.aclose() + await write_stream.aclose() + + try: + async with anyio.create_task_group() as tg: + endpoint_url = await tg.start(sse_reader) + logger.info( + f"Starting post writer with endpoint URL: {endpoint_url}" + ) + tg.start_soon(post_writer, endpoint_url) + + # Move streams outside + yield read_stream, write_stream + finally: + await read_stream_writer.aclose() + await write_stream.aclose() + await read_stream.aclose() + await write_stream_reader.aclose() diff --git a/uv.lock b/uv.lock index e819dbfe8..ee261891e 100644 --- a/uv.lock +++ b/uv.lock @@ -559,6 +559,7 @@ name = "mcp" source = { editable = "." } dependencies = [ { name = "anyio" }, + { name = "exceptiongroup" }, { name = "httpx" }, { name = "httpx-sse" }, { name = "pydantic" }, @@ -607,6 +608,7 @@ docs = [ [package.metadata] requires-dist = [ { name = "anyio", specifier = ">=4.5" }, + { name = "exceptiongroup", specifier = ">=1.2.0" }, { name = "httpx", specifier = ">=0.27" }, { name = "httpx-sse", specifier = ">=0.4" }, { name = "pydantic", specifier = ">=2.7.2,<3.0.0" },