diff --git a/src/mcp/server/sse.py b/src/mcp/server/sse.py index cc41a80d6..a6350a39b 100644 --- a/src/mcp/server/sse.py +++ b/src/mcp/server/sse.py @@ -100,10 +100,26 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): write_stream, write_stream_reader = anyio.create_memory_object_stream(0) session_id = uuid4() - session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}" self._read_stream_writers[session_id] = read_stream_writer logger.debug(f"Created new session with ID: {session_id}") + # Determine the full path for the message endpoint to be sent to the client. + # scope['root_path'] is the prefix where the current Starlette app + # instance is mounted. + # e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix". + root_path = scope.get("root_path", "") + + # self._endpoint is the path *within* this app, e.g., "/messages". + # Concatenating them gives the full absolute path from the server root. + # e.g., "" + "/messages" -> "/messages" + # e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages" + full_message_path_for_client = root_path.rstrip("/") + self._endpoint + + # This is the URI (path + query) the client will use to POST messages. + client_post_uri_data = ( + f"{quote(full_message_path_for_client)}?session_id={session_id.hex}" + ) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[ dict[str, Any] ](0) @@ -111,8 +127,10 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send): async def sse_writer(): logger.debug("Starting SSE writer") async with sse_stream_writer, write_stream_reader: - await sse_stream_writer.send({"event": "endpoint", "data": session_uri}) - logger.debug(f"Sent endpoint event: {session_uri}") + await sse_stream_writer.send( + {"event": "endpoint", "data": client_post_uri_data} + ) + logger.debug(f"Sent endpoint event: {client_post_uri_data}") async for session_message in write_stream_reader: logger.debug(f"Sending message via SSE: {session_message}") diff --git a/tests/shared/test_sse.py b/tests/shared/test_sse.py index 4558bb88c..e55983e01 100644 --- a/tests/shared/test_sse.py +++ b/tests/shared/test_sse.py @@ -252,3 +252,69 @@ async def test_sse_client_timeout( return pytest.fail("the client should have timed out and returned an error already") + + +def run_mounted_server(server_port: int) -> None: + app = make_server_app() + main_app = Starlette(routes=[Mount("/mounted_app", app=app)]) + server = uvicorn.Server( + config=uvicorn.Config( + app=main_app, host="127.0.0.1", port=server_port, log_level="error" + ) + ) + print(f"starting server on {server_port}") + server.run() + + # Give server time to start + while not server.started: + print("waiting for server to start") + time.sleep(0.5) + + +@pytest.fixture() +def mounted_server(server_port: int) -> Generator[None, None, None]: + proc = multiprocessing.Process( + target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True + ) + print("starting process") + proc.start() + + # Wait for server to be running + max_attempts = 20 + attempt = 0 + print("waiting for server to start") + while attempt < max_attempts: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.connect(("127.0.0.1", server_port)) + break + except ConnectionRefusedError: + time.sleep(0.1) + attempt += 1 + else: + raise RuntimeError(f"Server failed to start after {max_attempts} attempts") + + yield + + print("killing server") + # Signal the server to stop + proc.kill() + proc.join(timeout=2) + if proc.is_alive(): + print("server process failed to terminate") + + +@pytest.mark.anyio +async def test_sse_client_basic_connection_mounted_app( + mounted_server: None, server_url: str +) -> None: + async with sse_client(server_url + "/mounted_app/sse") as streams: + async with ClientSession(*streams) as session: + # Test initialization + result = await session.initialize() + assert isinstance(result, InitializeResult) + assert result.serverInfo.name == SERVER_NAME + + # Test ping + ping_result = await session.send_ping() + assert isinstance(ping_result, EmptyResult)