Skip to content

Properly infer prefix for SSE messages #659

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/mcp/server/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,37 @@ async def connect_sse(self, scope: Scope, receive: Receive, send: Send):
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

session_id = uuid4()
session_uri = f"{quote(self._endpoint)}?session_id={session_id.hex}"
self._read_stream_writers[session_id] = read_stream_writer
logger.debug(f"Created new session with ID: {session_id}")

# Determine the full path for the message endpoint to be sent to the client.
# scope['root_path'] is the prefix where the current Starlette app
# instance is mounted.
# e.g., "" if top-level, or "/api_prefix" if mounted under "/api_prefix".
root_path = scope.get("root_path", "")

# self._endpoint is the path *within* this app, e.g., "/messages".
# Concatenating them gives the full absolute path from the server root.
# e.g., "" + "/messages" -> "/messages"
# e.g., "/api_prefix" + "/messages" -> "/api_prefix/messages"
full_message_path_for_client = root_path.rstrip("/") + self._endpoint

# This is the URI (path + query) the client will use to POST messages.
client_post_uri_data = (
f"{quote(full_message_path_for_client)}?session_id={session_id.hex}"
)

sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[
dict[str, Any]
](0)

async def sse_writer():
logger.debug("Starting SSE writer")
async with sse_stream_writer, write_stream_reader:
await sse_stream_writer.send({"event": "endpoint", "data": session_uri})
logger.debug(f"Sent endpoint event: {session_uri}")
await sse_stream_writer.send(
{"event": "endpoint", "data": client_post_uri_data}
)
logger.debug(f"Sent endpoint event: {client_post_uri_data}")

async for session_message in write_stream_reader:
logger.debug(f"Sending message via SSE: {session_message}")
Expand Down
66 changes: 66 additions & 0 deletions tests/shared/test_sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,69 @@ async def test_sse_client_timeout(
return

pytest.fail("the client should have timed out and returned an error already")


def run_mounted_server(server_port: int) -> None:
app = make_server_app()
main_app = Starlette(routes=[Mount("/mounted_app", app=app)])
server = uvicorn.Server(
config=uvicorn.Config(
app=main_app, host="127.0.0.1", port=server_port, log_level="error"
)
)
print(f"starting server on {server_port}")
server.run()

# Give server time to start
while not server.started:
print("waiting for server to start")
time.sleep(0.5)


@pytest.fixture()
def mounted_server(server_port: int) -> Generator[None, None, None]:
proc = multiprocessing.Process(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note for the future, need to extract it into a function...

target=run_mounted_server, kwargs={"server_port": server_port}, daemon=True
)
print("starting process")
proc.start()

# Wait for server to be running
max_attempts = 20
attempt = 0
print("waiting for server to start")
while attempt < max_attempts:
try:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.connect(("127.0.0.1", server_port))
break
except ConnectionRefusedError:
time.sleep(0.1)
attempt += 1
else:
raise RuntimeError(f"Server failed to start after {max_attempts} attempts")

yield

print("killing server")
# Signal the server to stop
proc.kill()
proc.join(timeout=2)
if proc.is_alive():
print("server process failed to terminate")


@pytest.mark.anyio
async def test_sse_client_basic_connection_mounted_app(
mounted_server: None, server_url: str
) -> None:
async with sse_client(server_url + "/mounted_app/sse") as streams:
async with ClientSession(*streams) as session:
# Test initialization
result = await session.initialize()
assert isinstance(result, InitializeResult)
assert result.serverInfo.name == SERVER_NAME

# Test ping
ping_result = await session.send_ping()
assert isinstance(ping_result, EmptyResult)
Loading