Add message queue for SSE messages POST endpoint #459

Merged
53 commits
10c5af8
initial
akash329d Apr 8, 2025
c3d5efc
readme update
akash329d Apr 8, 2025
b2fce7d
ruff
akash329d Apr 8, 2025
7c82f36
fix typing issues
akash329d Apr 9, 2025
b92f22f
update lock
akash329d Apr 9, 2025
5dbca6e
retrigger tests?
akash329d Apr 9, 2025
badc1e2
revert
akash329d Apr 9, 2025
23665db
clean up test stuff
akash329d Apr 9, 2025
ccd5a13
lock pydantic version
akash329d Apr 9, 2025
fb44020
fix lock
akash329d Apr 9, 2025
efe6da9
wip
akash329d Apr 14, 2025
d625782
fixes
akash329d Apr 14, 2025
78c6aef
Add optional redis dep
akash329d Apr 14, 2025
fad836c
changes
akash329d Apr 14, 2025
fd97501
format / lint
akash329d Apr 14, 2025
4bce7d8
cleanup
akash329d Apr 14, 2025
d6075bb
update lock
akash329d Apr 14, 2025
8ee3a7e
remove redundant comment
akash329d Apr 14, 2025
7cabcea
add a checkpoint
akash329d Apr 14, 2025
5111c92
naming changes
akash329d Apr 15, 2025
09e0cab
logging improvements
akash329d Apr 15, 2025
8d280d8
better channel validation
akash329d Apr 15, 2025
c2bb049
merge
akash329d Apr 15, 2025
87e07b8
formatting and linting
akash329d Apr 15, 2025
b484284
fix naming in server.py
akash329d Apr 15, 2025
0bfd800
Rework to fix POST blocking issue
akash329d Apr 21, 2025
1e81f36
comments fix
akash329d Apr 21, 2025
215cc42
wip
akash329d Apr 22, 2025
8fce8e6
back to b48428486aa90f7529c36e5a78074ac2a2d813bc
akash329d Apr 22, 2025
b2893e6
push message handling onto corresponding SSE session task group
akash329d Apr 22, 2025
e5938d4
format
akash329d Apr 22, 2025
a151f1c
clean up comment and session state
akash329d Apr 22, 2025
d22f46b
shorten comment
akash329d Apr 22, 2025
8d6a20d
remove extra change
akash329d Apr 23, 2025
bb24881
testing
akash329d Apr 24, 2025
564561f
add a cancelscope on the finally
akash329d May 1, 2025
9419ad0
Move to session heartbeat w/ TTL
akash329d May 1, 2025
046ed94
add test for TTL
akash329d May 1, 2025
70547c0
merge conflict
akash329d May 5, 2025
5638653
merge fixes
akash329d May 5, 2025
2437e46
fakeredis dev dep
akash329d May 5, 2025
9664c8a
fmt
akash329d May 5, 2025
30b475b
convert to Pydantic models
akash329d May 5, 2025
0114189
fmt
akash329d May 5, 2025
7081a40
more type fixes
akash329d May 5, 2025
5ae3cc6
test cleanup
akash329d May 5, 2025
46b78f2
rename to message dispatch
akash329d May 5, 2025
e21d514
make int tests better
akash329d May 6, 2025
ee9f4de
lint
akash329d May 6, 2025
206a98a
tests hanging
akash329d May 6, 2025
bb59e5d
do cleanup after test
akash329d May 6, 2025
ca9a54a
fmt
akash329d May 6, 2025
9832c34
clean up int test
akash329d May 6, 2025
24 changes: 24 additions & 0 deletions README.md
@@ -385,6 +385,30 @@ app.router.routes.append(Host('mcp.acme.corp', app=mcp.sse_app()))

For more information on mounting applications in Starlette, see the [Starlette documentation](https://www.starlette.io/routing/#submounting-routes).

#### Message Queue Options

By default, the SSE server uses an in-memory message queue for incoming POST messages. For production deployments or distributed scenarios, you can use Redis:

```python
# Using the built-in Redis message queue
from mcp.server.fastmcp import FastMCP
from mcp.server.message_queue import RedisMessageQueue

# Create a Redis message queue
redis_queue = RedisMessageQueue(
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
)

# Pass the message queue instance to the server
mcp = FastMCP("My App", message_queue=redis_queue)
```

To use Redis, add the Redis dependency:

```bash
uv add "mcp[redis]"
```

## Examples

### Echo Server
10 changes: 8 additions & 2 deletions examples/fastmcp/unicode_example.py
@@ -4,8 +4,14 @@
"""

from mcp.server.fastmcp import FastMCP
from mcp.server.message_queue import RedisMessageQueue

mcp = FastMCP()
# Create a Redis message queue
redis_queue = RedisMessageQueue(
redis_url="redis://localhost:6379/0", prefix="mcp:pubsub:"
)

mcp = FastMCP(message_queue=redis_queue)


@mcp.tool(
@@ -61,4 +67,4 @@ def multilingual_hello() -> str:


if __name__ == "__main__":
mcp.run()
mcp.run(transport="sse")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ dependencies = [
rich = ["rich>=13.9.4"]
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
ws = ["websockets>=15.0.1"]
redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"]

[project.scripts]
mcp = "mcp.cli:app [cli]"
17 changes: 16 additions & 1 deletion src/mcp/server/fastmcp/server.py
@@ -33,6 +33,7 @@
from mcp.server.lowlevel.server import LifespanResultT
from mcp.server.lowlevel.server import Server as MCPServer
from mcp.server.lowlevel.server import lifespan as default_lifespan
from mcp.server.message_queue import MessageQueue
from mcp.server.session import ServerSession, ServerSessionT
from mcp.server.sse import SseServerTransport
from mcp.server.stdio import stdio_server
@@ -76,6 +77,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
sse_path: str = "/sse"
message_path: str = "/messages/"

# SSE message queue settings
message_queue: MessageQueue | None = Field(
None, description="Custom message queue instance"
)

# resource settings
warn_on_duplicate_resources: bool = True

@@ -479,7 +485,16 @@ async def run_sse_async(self) -> None:

def sse_app(self) -> Starlette:
"""Return an instance of the SSE server app."""
sse = SseServerTransport(self.settings.message_path)
message_queue = self.settings.message_queue
if message_queue is None:
from mcp.server.message_queue import InMemoryMessageQueue

message_queue = InMemoryMessageQueue()
logger.info("Using default in-memory message queue")

sse = SseServerTransport(
self.settings.message_path, message_queue=message_queue
)

async def handle_sse(request: Request) -> None:
async with sse.connect_sse(
16 changes: 16 additions & 0 deletions src/mcp/server/message_queue/__init__.py
@@ -0,0 +1,16 @@
"""
Message Queue Module for MCP Server

This module implements queue interfaces for handling
messages between clients and servers.
"""

from mcp.server.message_queue.base import InMemoryMessageQueue, MessageQueue

# Try to import Redis implementation if available
try:
from mcp.server.message_queue.redis import RedisMessageQueue
except ImportError:
RedisMessageQueue = None

__all__ = ["MessageQueue", "InMemoryMessageQueue", "RedisMessageQueue"]
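
Because the Redis import is optional, `RedisMessageQueue` is `None` when the `redis` extra isn't installed, so callers can guard on it. A small hypothetical sketch:

```python
from mcp.server.message_queue import InMemoryMessageQueue, RedisMessageQueue

# Fall back to the in-memory queue when the redis extra isn't installed.
if RedisMessageQueue is not None:
    queue = RedisMessageQueue(redis_url="redis://localhost:6379/0")
else:
    queue = InMemoryMessageQueue()
```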
103 changes: 103 additions & 0 deletions src/mcp/server/message_queue/base.py
@@ -0,0 +1,103 @@
import logging
from collections.abc import Awaitable, Callable
from contextlib import asynccontextmanager
from typing import Protocol, runtime_checkable
from uuid import UUID

import mcp.types as types

logger = logging.getLogger(__name__)

MessageCallback = Callable[[types.JSONRPCMessage | Exception], Awaitable[None]]


@runtime_checkable
class MessageQueue(Protocol):
"""Abstract interface for SSE messaging.

This interface allows messages to be published to sessions and callbacks to be
registered for message handling, enabling multiple servers to handle requests.
"""

async def publish_message(
self, session_id: UUID, message: types.JSONRPCMessage | Exception
) -> bool:
"""Publish a message for the specified session.

Args:
session_id: The UUID of the session this message is for
message: The message to publish

Returns:
bool: True if message was published, False if session not found
"""
...

@asynccontextmanager
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
"""Request-scoped context manager that ensures the listener is active.

Args:
session_id: The UUID of the session to activate
callback: Async callback function to handle messages for this session
"""
yield

async def session_exists(self, session_id: UUID) -> bool:
"""Check if a session exists.

Args:
session_id: The UUID of the session to check

Returns:
bool: True if the session is active, False otherwise
"""
...


class InMemoryMessageQueue:
"""Default in-memory implementation of the MessageQueue interface.

This implementation immediately calls registered callbacks when messages
are received.
"""

def __init__(self) -> None:
self._callbacks: dict[UUID, MessageCallback] = {}
self._active_sessions: set[UUID] = set()

async def publish_message(
self, session_id: UUID, message: types.JSONRPCMessage | Exception
) -> bool:
"""Publish a message for the specified session."""
if not await self.session_exists(session_id):
logger.warning(f"Message received for unknown session {session_id}")
return False

# Call the callback directly if registered
if session_id in self._callbacks:
await self._callbacks[session_id](message)
logger.debug(f"Called callback for session {session_id}")
else:
logger.warning(f"No callback registered for session {session_id}")

return True

@asynccontextmanager
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
"""Request-scoped context manager that ensures the listener is active."""
self._active_sessions.add(session_id)
self._callbacks[session_id] = callback
logger.debug(f"Registered session {session_id} with callback")

try:
yield
finally:
self._active_sessions.discard(session_id)
if session_id in self._callbacks:
del self._callbacks[session_id]
logger.debug(f"Unregistered session {session_id}")

async def session_exists(self, session_id: UUID) -> bool:
"""Check if a session exists."""
return session_id in self._active_sessions
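
A minimal usage sketch of this interface from a transport's point of view (hypothetical driver code, assuming anyio; the callback stands in for the SSE stream writer):

```python
from uuid import uuid4

import anyio

import mcp.types as types
from mcp.server.message_queue import InMemoryMessageQueue


async def demo() -> None:
    queue = InMemoryMessageQueue()
    session_id = uuid4()

    async def on_message(msg: types.JSONRPCMessage | Exception) -> None:
        # The SSE transport would forward msg to the session's stream here.
        print("received:", msg)

    # While the request scope is open, POSTed messages reach the callback.
    async with queue.active_for_request(session_id, on_message):
        ping = types.JSONRPCMessage.model_validate(
            {"jsonrpc": "2.0", "id": 1, "method": "ping"}
        )
        assert await queue.publish_message(session_id, ping)

    # Outside the scope the session is gone, so publishing reports failure.
    assert not await queue.publish_message(session_id, ping)


anyio.run(demo)
```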
143 changes: 143 additions & 0 deletions src/mcp/server/message_queue/redis.py
@@ -0,0 +1,143 @@
import json
import logging
from contextlib import asynccontextmanager
from typing import Any, cast
from uuid import UUID

import anyio
from anyio import CapacityLimiter, lowlevel

import mcp.types as types
from mcp.server.message_queue.base import MessageCallback

try:
import redis.asyncio as redis
except ImportError:
raise ImportError(
"Redis support requires the 'redis' package. "
"Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'"
)

logger = logging.getLogger(__name__)


class RedisMessageQueue:
"""Redis implementation of the MessageQueue interface using pubsub.

This implementation uses Redis pubsub for real-time message distribution across
multiple servers handling the same sessions.
"""

def __init__(
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:"
) -> None:
"""Initialize Redis message queue.

Args:
redis_url: Redis connection string
prefix: Key prefix for Redis channels to avoid collisions
"""
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
self._prefix = prefix
self._active_sessions_key = f"{prefix}active_sessions"
self._callbacks: dict[UUID, MessageCallback] = {}
# Ensures only one polling task runs at a time for message handling
self._limiter = CapacityLimiter(1)
logger.debug(f"Initialized Redis message queue with URL: {redis_url}")

def _session_channel(self, session_id: UUID) -> str:
"""Get the Redis channel for a session."""
return f"{self._prefix}session:{session_id.hex}"

@asynccontextmanager
async def active_for_request(self, session_id: UUID, callback: MessageCallback):
"""Request-scoped context manager that ensures the listener task is running."""
await self._redis.sadd(self._active_sessions_key, session_id.hex)
Contributor:
I think it's somewhat preferable to use separate keys for tracking the active sessions, rather than one big set value, so the set doesn't grow unboundedly.

Contributor Author:
It seems like the max number of Redis keys and the max number of members in a set are equal, so I think there should be no difference in practice? https://redis.io/docs/latest/develop/data-types/sets/

Contributor:
Huh, I guess that seems fine to me, then.
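
A later commit in this PR (9419ad0, "Move to session heartbeat w/ TTL") does move session tracking toward expiring keys. As a hedged sketch only (hypothetical code, not part of this diff), per-session keys with a heartbeat-refreshed TTL could look like:

```python
from uuid import UUID

import redis.asyncio as redis

SESSION_TTL_SECONDS = 3600  # assumed TTL; refreshed by a heartbeat


class PerKeySessionTracker:
    """Tracks each active session as its own Redis key with a TTL."""

    def __init__(
        self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:"
    ) -> None:
        self._redis = redis.from_url(redis_url, decode_responses=True)
        self._prefix = prefix

    def _session_key(self, session_id: UUID) -> str:
        return f"{self._prefix}session_active:{session_id.hex}"

    async def heartbeat(self, session_id: UUID) -> None:
        # SET with EX creates the key and (re)arms its TTL on every heartbeat,
        # so a session that stops heartbeating expires on its own.
        await self._redis.set(
            self._session_key(session_id), "1", ex=SESSION_TTL_SECONDS
        )

    async def session_exists(self, session_id: UUID) -> bool:
        # EXISTS returns the number of matching keys (0 or 1 here).
        return bool(await self._redis.exists(self._session_key(session_id)))
```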

self._callbacks[session_id] = callback
channel = self._session_channel(session_id)
await self._pubsub.subscribe(channel) # type: ignore

logger.debug(f"Registered session {session_id} in Redis with callback")
async with anyio.create_task_group() as tg:
tg.start_soon(self._listen_for_messages)
try:
yield
finally:
tg.cancel_scope.cancel()
Contributor:
One kinda squirrelly part of this is that it's extremely important that the cancel doesn't happen after we've read a message off the Redis connection; otherwise the data will just get entirely lost. Right now, I think there's a failure mode here where one SSE stream being closed at the wrong moment will cause messages from another SSE stream to be dropped.

Looking through the Redis implementation, there are definitely multiple async calls within get_message, so it's possible for the bytes to be read off the connection, then the context to be cancelled, and the message dropped.

I think the way we have to structure this is either with multiple connections, or by wrapping the get_message() + dispatch calls in a shield=True scope: https://anyio.readthedocs.io/en/stable/cancellation.html#shielding. We'd then need to set a timeout on the get_message call, and just silently retry if the read times out. (This essentially gives us known safe points that we're OK being cancelled at.)
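
A hedged sketch of the shielded structure described above (hypothetical, not part of this diff; assumes anyio's CancelScope, a finite get_message timeout, and a _dispatch helper standing in for the channel-parsing logic):

```python
import anyio


async def _listen_for_messages(self) -> None:
    """Listen loop where cancellation can only land at known-safe points."""
    while True:
        # Safe point: a pending cancel is delivered here, between reads.
        await anyio.lowlevel.checkpoint()
        # Shield the read + dispatch so a cancel can't drop a message that
        # has already been pulled off the shared Redis connection.
        with anyio.CancelScope(shield=True):
            message = await self._pubsub.get_message(
                ignore_subscribe_messages=True,
                timeout=0.1,  # finite timeout so the loop returns to a safe point
            )
            if message is not None:
                await self._dispatch(message)  # hypothetical helper
```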

await self._pubsub.unsubscribe(channel) # type: ignore
await self._redis.srem(self._active_sessions_key, session_id.hex)
del self._callbacks[session_id]
logger.debug(f"Unregistered session {session_id} from Redis")

async def _listen_for_messages(self) -> None:
"""Background task that listens for messages on subscribed channels."""
async with self._limiter:
while True:
await lowlevel.checkpoint()
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
ignore_subscribe_messages=True,
timeout=None, # type: ignore
)
if message is None:
continue

# Extract session ID from channel name
channel: str = cast(str, message["channel"])
if not channel.startswith(self._prefix):
continue

session_hex = channel.split(":")[-1]
try:
session_id = UUID(hex=session_hex)
except ValueError:
logger.error(f"Invalid session channel: {channel}")
continue

data: str = cast(str, message["data"])
msg: None | types.JSONRPCMessage | Exception = None
try:
json_data = json.loads(data)
if isinstance(json_data, dict):
json_dict: dict[str, Any] = json_data
if json_dict.get("_exception", False):
msg = Exception(
f"{json_dict['type']}: {json_dict['message']}"
)
else:
msg = types.JSONRPCMessage.model_validate_json(data)

if msg and session_id in self._callbacks:
await self._callbacks[session_id](msg)
except Exception as e:
logger.error(f"Failed to process message: {e}")

async def publish_message(
self, session_id: UUID, message: types.JSONRPCMessage | Exception
) -> bool:
"""Publish a message for the specified session."""
if not await self.session_exists(session_id):
logger.warning(f"Message received for unknown session {session_id}")
return False

if isinstance(message, Exception):
data = json.dumps(
{
"_exception": True,
"type": type(message).__name__,
"message": str(message),
}
)
else:
data = message.model_dump_json()

channel = self._session_channel(session_id)
await self._redis.publish(channel, data) # type: ignore[attr-defined]
logger.debug(f"Published message to Redis channel for session {session_id}")
return True

async def session_exists(self, session_id: UUID) -> bool:
"""Check if a session exists."""
return bool(
await self._redis.sismember(self._active_sessions_key, session_id.hex) # type: ignore[attr-defined]
)
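
For illustration, the exception envelope used above round-trips like this (a sketch reusing the same serialization; the ValueError is a made-up example):

```python
import json

# What publish_message writes to the channel for an Exception:
data = json.dumps({"_exception": True, "type": "ValueError", "message": "boom"})

# What _listen_for_messages reconstructs on the subscriber side:
json_dict = json.loads(data)
if json_dict.get("_exception", False):
    msg = Exception(f"{json_dict['type']}: {json_dict['message']}")
    print(msg)  # ValueError: boom
```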