Skip to content

Commit 3b1b213

Browse files
authored
Add message queue for SSE messages POST endpoint (#459)
1 parent 58c5e72 commit 3b1b213

File tree

26 files changed

+1247
-50
lines changed

26 files changed

+1247
-50
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,30 @@ app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app()))
412412

413413
For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes).
414414

415+
#### Message Dispatch Options
416+
417+
By default, the SSE server uses an in-memory message dispatch system for incoming POST messages. For production deployments or distributed scenarios, you can use Redis or implement your own message dispatch system that conforms to the `MessageDispatch` protocol:
418+
419+
```python
420+
# Using the built-in Redis message dispatch
421+
from mcp.server.fastmcp import FastMCP
422+
from mcp.server.message_queue import RedisMessageDispatch
423+
424+
# Create a Redis message dispatch
425+
redis_dispatch = RedisMessageDispatch(
426+
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
427+
)
428+
429+
# Pass the message dispatch instance to the server
430+
mcp = FastMCP("My App", message_queue=redis_dispatch)
431+
```
432+
433+
To use Redis, add the Redis dependency:
434+
435+
```bash
436+
uv add "mcp[redis]"
437+
```
438+
415439
## Examples
416440

417441
### Echo Server

examples/servers/simple-prompt/mcp_simple_prompt/server.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,15 @@ async def get_prompt(
8888
)
8989

9090
if transport == "sse":
91+
from mcp.server.message_queue.redis import RedisMessageDispatch
9192
from mcp.server.sse import SseServerTransport
9293
from starlette.applications import Starlette
9394
from starlette.responses import Response
9495
from starlette.routing import Mount, Route
9596

96-
sse = SseServerTransport("/messages/")
97+
message_dispatch = RedisMessageDispatch("redis://localhost:6379/0")
98+
99+
sse = SseServerTransport("/messages/", message_dispatch=message_dispatch)
97100

98101
async def handle_sse(request):
99102
async with sse.connect_sse(

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies = [
3737
rich = ["rich>=13.9.4"]
3838
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
3939
ws = ["websockets>=15.0.1"]
40+
redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"]
4041

4142
[project.scripts]
4243
mcp = "mcp.cli:app [cli]"
@@ -55,6 +56,7 @@ dev = [
5556
"pytest-xdist>=3.6.1",
5657
"pytest-examples>=0.0.14",
5758
"pytest-pretty>=1.2.0",
59+
"fakeredis==2.28.1",
5860
]
5961
docs = [
6062
"mkdocs>=1.6.1",

src/mcp/client/sse.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,9 @@ async def sse_reader(
9898
await read_stream_writer.send(exc)
9999
continue
100100

101-
session_message = SessionMessage(message)
101+
session_message = SessionMessage(
102+
message=message
103+
)
102104
await read_stream_writer.send(session_message)
103105
case _:
104106
logger.warning(
@@ -148,3 +150,5 @@ async def post_writer(endpoint_url: str):
148150
finally:
149151
await read_stream_writer.aclose()
150152
await write_stream.aclose()
153+
await read_stream.aclose()
154+
await write_stream_reader.aclose()

src/mcp/client/stdio/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ async def stdout_reader():
144144
await read_stream_writer.send(exc)
145145
continue
146146

147-
session_message = SessionMessage(message)
147+
session_message = SessionMessage(message=message)
148148
await read_stream_writer.send(session_message)
149149
except anyio.ClosedResourceError:
150150
await anyio.lowlevel.checkpoint()

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ async def _handle_sse_event(
153153
):
154154
message.root.id = original_request_id
155155

156-
session_message = SessionMessage(message)
156+
session_message = SessionMessage(message=message)
157157
await read_stream_writer.send(session_message)
158158

159159
# Call resumption token callback if we have an ID
@@ -286,7 +286,7 @@ async def _handle_json_response(
286286
try:
287287
content = await response.aread()
288288
message = JSONRPCMessage.model_validate_json(content)
289-
session_message = SessionMessage(message)
289+
session_message = SessionMessage(message=message)
290290
await read_stream_writer.send(session_message)
291291
except Exception as exc:
292292
logger.error(f"Error parsing JSON response: {exc}")
@@ -333,7 +333,7 @@ async def _send_session_terminated_error(
333333
id=request_id,
334334
error=ErrorData(code=32600, message="Session terminated"),
335335
)
336-
session_message = SessionMessage(JSONRPCMessage(jsonrpc_error))
336+
session_message = SessionMessage(message=JSONRPCMessage(jsonrpc_error))
337337
await read_stream_writer.send(session_message)
338338

339339
async def post_writer(

src/mcp/client/websocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ async def ws_reader():
6060
async for raw_text in ws:
6161
try:
6262
message = types.JSONRPCMessage.model_validate_json(raw_text)
63-
session_message = SessionMessage(message)
63+
session_message = SessionMessage(message=message)
6464
await read_stream_writer.send(session_message)
6565
except ValidationError as exc:
6666
# If JSON parse or model validation fails, send the exception

src/mcp/server/fastmcp/server.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from mcp.server.lowlevel.server import LifespanResultT
4545
from mcp.server.lowlevel.server import Server as MCPServer
4646
from mcp.server.lowlevel.server import lifespan as default_lifespan
47+
from mcp.server.message_queue import MessageDispatch
4748
from mcp.server.session import ServerSession, ServerSessionT
4849
from mcp.server.sse import SseServerTransport
4950
from mcp.server.stdio import stdio_server
@@ -90,6 +91,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
9091
sse_path: str = "/sse"
9192
message_path: str = "/messages/"
9293

94+
# SSE message queue settings
95+
message_dispatch: MessageDispatch | None = Field(
96+
None, description="Custom message dispatch instance"
97+
)
98+
9399
# resource settings
94100
warn_on_duplicate_resources: bool = True
95101

@@ -569,12 +575,21 @@ async def run_sse_async(self) -> None:
569575

570576
def sse_app(self) -> Starlette:
571577
"""Return an instance of the SSE server app."""
578+
message_dispatch = self.settings.message_dispatch
579+
if message_dispatch is None:
580+
from mcp.server.message_queue import InMemoryMessageDispatch
581+
582+
message_dispatch = InMemoryMessageDispatch()
583+
logger.info("Using default in-memory message dispatch")
584+
572585
from starlette.middleware import Middleware
573586
from starlette.routing import Mount, Route
574587

575588
# Set up auth context and dependencies
576589

577-
sse = SseServerTransport(self.settings.message_path)
590+
sse = SseServerTransport(
591+
self.settings.message_path, message_dispatch=message_dispatch
592+
)
578593

579594
async def handle_sse(scope: Scope, receive: Receive, send: Send):
580595
# Add client ID from auth context into request context if available
@@ -589,7 +604,14 @@ async def handle_sse(scope: Scope, receive: Receive, send: Send):
589604
streams[1],
590605
self._mcp_server.create_initialization_options(),
591606
)
592-
return Response()
607+
return Response()
608+
609+
@asynccontextmanager
610+
async def lifespan(app: Starlette):
611+
try:
612+
yield
613+
finally:
614+
await message_dispatch.close()
593615

594616
# Create routes
595617
routes: list[Route | Mount] = []
@@ -666,7 +688,10 @@ async def sse_endpoint(request: Request) -> None:
666688

667689
# Create Starlette app with routes and middleware
668690
return Starlette(
669-
debug=self.settings.debug, routes=routes, middleware=middleware
691+
debug=self.settings.debug,
692+
routes=routes,
693+
middleware=middleware,
694+
lifespan=lifespan,
670695
)
671696

672697
async def list_prompts(self) -> list[MCPPrompt]:
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""
2+
Message Dispatch Module for MCP Server
3+
4+
This module implements dispatch interfaces for handling
5+
messages between clients and servers.
6+
"""
7+
8+
from mcp.server.message_queue.base import InMemoryMessageDispatch, MessageDispatch
9+
10+
# Try to import Redis implementation if available
11+
try:
12+
from mcp.server.message_queue.redis import RedisMessageDispatch
13+
except ImportError:
14+
RedisMessageDispatch = None
15+
16+
__all__ = ["MessageDispatch", "InMemoryMessageDispatch", "RedisMessageDispatch"]

src/mcp/server/message_queue/base.py

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
import logging
2+
from collections.abc import Awaitable, Callable
3+
from contextlib import asynccontextmanager
4+
from typing import Protocol, runtime_checkable
5+
from uuid import UUID
6+
7+
from pydantic import ValidationError
8+
9+
from mcp.shared.message import SessionMessage
10+
11+
logger = logging.getLogger(__name__)
12+
13+
MessageCallback = Callable[[SessionMessage | Exception], Awaitable[None]]
14+
15+
16+
@runtime_checkable
17+
class MessageDispatch(Protocol):
18+
"""Abstract interface for SSE message dispatching.
19+
20+
This interface allows messages to be published to sessions and callbacks to be
21+
registered for message handling, enabling multiple servers to handle requests.
22+
"""
23+
24+
async def publish_message(
25+
self, session_id: UUID, message: SessionMessage | str
26+
) -> bool:
27+
"""Publish a message for the specified session.
28+
29+
Args:
30+
session_id: The UUID of the session this message is for
31+
message: The message to publish (SessionMessage or str for invalid JSON)
32+
33+
Returns:
34+
bool: True if message was published, False if session not found
35+
"""
36+
...
37+
38+
@asynccontextmanager
39+
async def subscribe(self, session_id: UUID, callback: MessageCallback):
40+
"""Request-scoped context manager that subscribes to messages for a session.
41+
42+
Args:
43+
session_id: The UUID of the session to subscribe to
44+
callback: Async callback function to handle messages for this session
45+
"""
46+
yield
47+
48+
async def session_exists(self, session_id: UUID) -> bool:
49+
"""Check if a session exists.
50+
51+
Args:
52+
session_id: The UUID of the session to check
53+
54+
Returns:
55+
bool: True if the session is active, False otherwise
56+
"""
57+
...
58+
59+
async def close(self) -> None:
60+
"""Close the message dispatch."""
61+
...
62+
63+
64+
class InMemoryMessageDispatch:
65+
"""Default in-memory implementation of the MessageDispatch interface.
66+
67+
This implementation immediately dispatches messages to registered callbacks when
68+
messages are received without any queuing behavior.
69+
"""
70+
71+
def __init__(self) -> None:
72+
self._callbacks: dict[UUID, MessageCallback] = {}
73+
74+
async def publish_message(
75+
self, session_id: UUID, message: SessionMessage | str
76+
) -> bool:
77+
"""Publish a message for the specified session."""
78+
if session_id not in self._callbacks:
79+
logger.warning(f"Message dropped: unknown session {session_id}")
80+
return False
81+
82+
# Parse string messages or recreate original ValidationError
83+
if isinstance(message, str):
84+
try:
85+
callback_argument = SessionMessage.model_validate_json(message)
86+
except ValidationError as exc:
87+
callback_argument = exc
88+
else:
89+
callback_argument = message
90+
91+
# Call the callback with either valid message or recreated ValidationError
92+
await self._callbacks[session_id](callback_argument)
93+
94+
logger.debug(f"Message dispatched to session {session_id}")
95+
return True
96+
97+
@asynccontextmanager
98+
async def subscribe(self, session_id: UUID, callback: MessageCallback):
99+
"""Request-scoped context manager that subscribes to messages for a session."""
100+
self._callbacks[session_id] = callback
101+
logger.debug(f"Subscribing to messages for session {session_id}")
102+
103+
try:
104+
yield
105+
finally:
106+
if session_id in self._callbacks:
107+
del self._callbacks[session_id]
108+
logger.debug(f"Unsubscribed from session {session_id}")
109+
110+
async def session_exists(self, session_id: UUID) -> bool:
111+
"""Check if a session exists."""
112+
return session_id in self._callbacks
113+
114+
async def close(self) -> None:
115+
"""Close the message dispatch."""
116+
pass

0 commit comments

Comments
 (0)