-
Notifications
You must be signed in to change notification settings - Fork 1.6k
Add message queue for SSE messages POST endpoint #459
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
10c5af8
c3d5efc
b2fce7d
7c82f36
b92f22f
5dbca6e
badc1e2
23665db
ccd5a13
fb44020
efe6da9
d625782
78c6aef
fad836c
fd97501
4bce7d8
d6075bb
8ee3a7e
7cabcea
5111c92
09e0cab
8d280d8
c2bb049
87e07b8
b484284
0bfd800
1e81f36
215cc42
8fce8e6
b2893e6
e5938d4
a151f1c
d22f46b
8d6a20d
bb24881
564561f
9419ad0
046ed94
70547c0
5638653
2437e46
9664c8a
30b475b
0114189
7081a40
5ae3cc6
46b78f2
e21d514
ee9f4de
206a98a
bb59e5d
ca9a54a
9832c34
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Message Queue Module for MCP Server | ||
|
||
This module implements queue interfaces for handling | ||
messages between clients and servers. | ||
""" | ||
|
||
from mcp.server.message_queue.base import InMemoryMessageQueue, MessageQueue | ||
|
||
# Try to import Redis implementation if available | ||
try: | ||
from mcp.server.message_queue.redis import RedisMessageQueue | ||
except ImportError: | ||
RedisMessageQueue = None | ||
|
||
__all__ = ["MessageQueue", "InMemoryMessageQueue", "RedisMessageQueue"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,103 @@ | ||
import logging | ||
from collections.abc import Awaitable, Callable | ||
from contextlib import asynccontextmanager | ||
from typing import Protocol, runtime_checkable | ||
from uuid import UUID | ||
|
||
import mcp.types as types | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
MessageCallback = Callable[[types.JSONRPCMessage | Exception], Awaitable[None]] | ||
|
||
|
||
@runtime_checkable | ||
class MessageQueue(Protocol): | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Abstract interface for SSE messaging. | ||
|
||
This interface allows messages to be published to sessions and callbacks to be | ||
registered for message handling, enabling multiple servers to handle requests. | ||
""" | ||
|
||
async def publish_message( | ||
self, session_id: UUID, message: types.JSONRPCMessage | Exception | ||
) -> bool: | ||
"""Publish a message for the specified session. | ||
|
||
Args: | ||
session_id: The UUID of the session this message is for | ||
message: The message to publish | ||
|
||
Returns: | ||
bool: True if message was published, False if session not found | ||
""" | ||
... | ||
|
||
@asynccontextmanager | ||
async def active_for_request(self, session_id: UUID, callback: MessageCallback): | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""Request-scoped context manager that ensures the listener is active. | ||
|
||
Args: | ||
session_id: The UUID of the session to activate | ||
callback: Async callback function to handle messages for this session | ||
""" | ||
yield | ||
|
||
async def session_exists(self, session_id: UUID) -> bool: | ||
"""Check if a session exists. | ||
|
||
Args: | ||
session_id: The UUID of the session to check | ||
|
||
Returns: | ||
bool: True if the session is active, False otherwise | ||
""" | ||
... | ||
|
||
|
||
class InMemoryMessageQueue: | ||
"""Default in-memory implementation of the MessageQueue interface. | ||
|
||
This implementation immediately calls registered callbacks when messages | ||
are received. | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._callbacks: dict[UUID, MessageCallback] = {} | ||
self._active_sessions: set[UUID] = set() | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
async def publish_message( | ||
self, session_id: UUID, message: types.JSONRPCMessage | Exception | ||
) -> bool: | ||
"""Publish a message for the specified session.""" | ||
if not await self.session_exists(session_id): | ||
logger.warning(f"Message received for unknown session {session_id}") | ||
return False | ||
|
||
# Call the callback directly if registered | ||
if session_id in self._callbacks: | ||
await self._callbacks[session_id](message) | ||
logger.debug(f"Called callback for session {session_id}") | ||
else: | ||
logger.warning(f"No callback registered for session {session_id}") | ||
|
||
return True | ||
|
||
@asynccontextmanager | ||
async def active_for_request(self, session_id: UUID, callback: MessageCallback): | ||
"""Request-scoped context manager that ensures the listener is active.""" | ||
self._active_sessions.add(session_id) | ||
self._callbacks[session_id] = callback | ||
logger.debug(f"Registered session {session_id} with callback") | ||
|
||
try: | ||
yield | ||
finally: | ||
self._active_sessions.discard(session_id) | ||
if session_id in self._callbacks: | ||
del self._callbacks[session_id] | ||
logger.debug(f"Unregistered session {session_id}") | ||
|
||
async def session_exists(self, session_id: UUID) -> bool: | ||
"""Check if a session exists.""" | ||
return session_id in self._active_sessions |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,143 @@ | ||
import json | ||
import logging | ||
from contextlib import asynccontextmanager | ||
from typing import Any, cast | ||
from uuid import UUID | ||
|
||
import anyio | ||
from anyio import CapacityLimiter, lowlevel | ||
|
||
import mcp.types as types | ||
from mcp.server.message_queue.base import MessageCallback | ||
|
||
try: | ||
import redis.asyncio as redis | ||
except ImportError: | ||
raise ImportError( | ||
"Redis support requires the 'redis' package. " | ||
"Install it with: 'uv add redis' or 'uv add \"mcp[redis]\"'" | ||
) | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class RedisMessageQueue: | ||
"""Redis implementation of the MessageQueue interface using pubsub. | ||
|
||
This implementation uses Redis pubsub for real-time message distribution across | ||
multiple servers handling the same sessions. | ||
""" | ||
|
||
def __init__( | ||
self, redis_url: str = "redis://localhost:6379/0", prefix: str = "mcp:pubsub:" | ||
) -> None: | ||
"""Initialize Redis message queue. | ||
|
||
Args: | ||
redis_url: Redis connection string | ||
prefix: Key prefix for Redis channels to avoid collisions | ||
""" | ||
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore | ||
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore | ||
self._prefix = prefix | ||
self._active_sessions_key = f"{prefix}active_sessions" | ||
self._callbacks: dict[UUID, MessageCallback] = {} | ||
# Ensures only one polling task runs at a time for message handling | ||
self._limiter = CapacityLimiter(1) | ||
logger.debug(f"Initialized Redis message queue with URL: {redis_url}") | ||
|
||
def _session_channel(self, session_id: UUID) -> str: | ||
"""Get the Redis channel for a session.""" | ||
return f"{self._prefix}session:{session_id.hex}" | ||
|
||
@asynccontextmanager | ||
async def active_for_request(self, session_id: UUID, callback: MessageCallback): | ||
"""Request-scoped context manager that ensures the listener task is running.""" | ||
await self._redis.sadd(self._active_sessions_key, session_id.hex) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's somewhat preferable to use separate keys for tracking the active sessions, rather than one big set value, so the set doesn't grow unboundedly. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems like the max # of redis keys & the max number of members in a set are equal, so I think there should be no difference in practice? https://redis.io/docs/latest/develop/data-types/sets/ There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Huh, I guess that seems fine to me, then. |
||
self._callbacks[session_id] = callback | ||
channel = self._session_channel(session_id) | ||
await self._pubsub.subscribe(channel) # type: ignore | ||
|
||
logger.debug(f"Registered session {session_id} in Redis with callback") | ||
async with anyio.create_task_group() as tg: | ||
tg.start_soon(self._listen_for_messages) | ||
try: | ||
yield | ||
finally: | ||
tg.cancel_scope.cancel() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. One kinda squirrelly part of this is that it's extremely important that the cancel doesn't happen after we've read a message off the Redis connection, otherwise the data will just get entirely lost. Right now, I think there's a failure mode here where one SSE stream being closed at the wrong moment will cause messages from another SSE stream to be dropped. Looking through the Redis implementation, there are definitely multiple I think the way we have to structure this is either with multiple connections, or by wrapping the |
||
await self._pubsub.unsubscribe(channel) # type: ignore | ||
await self._redis.srem(self._active_sessions_key, session_id.hex) | ||
del self._callbacks[session_id] | ||
logger.debug(f"Unregistered session {session_id} from Redis") | ||
|
||
async def _listen_for_messages(self) -> None: | ||
"""Background task that listens for messages on subscribed channels.""" | ||
async with self._limiter: | ||
while True: | ||
await lowlevel.checkpoint() | ||
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore | ||
ignore_subscribe_messages=True, | ||
timeout=None, # type: ignore | ||
) | ||
if message is None: | ||
continue | ||
|
||
# Extract session ID from channel name | ||
channel: str = cast(str, message["channel"]) | ||
if not channel.startswith(self._prefix): | ||
continue | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
session_hex = channel.split(":")[-1] | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
try: | ||
session_id = UUID(hex=session_hex) | ||
except ValueError: | ||
logger.error(f"Invalid session channel: {channel}") | ||
continue | ||
|
||
data: str = cast(str, message["data"]) | ||
msg: None | types.JSONRPCMessage | Exception = None | ||
try: | ||
json_data = json.loads(data) | ||
if isinstance(json_data, dict): | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
json_dict: dict[str, Any] = json_data | ||
if json_dict.get("_exception", False): | ||
msg = Exception( | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
f"{json_dict['type']}: {json_dict['message']}" | ||
) | ||
else: | ||
msg = types.JSONRPCMessage.model_validate_json(data) | ||
|
||
if msg and session_id in self._callbacks: | ||
await self._callbacks[session_id](msg) | ||
akash329d marked this conversation as resolved.
Show resolved
Hide resolved
|
||
except Exception as e: | ||
logger.error(f"Failed to process message: {e}") | ||
|
||
async def publish_message( | ||
self, session_id: UUID, message: types.JSONRPCMessage | Exception | ||
) -> bool: | ||
"""Publish a message for the specified session.""" | ||
if not await self.session_exists(session_id): | ||
logger.warning(f"Message received for unknown session {session_id}") | ||
return False | ||
|
||
if isinstance(message, Exception): | ||
data = json.dumps( | ||
{ | ||
"_exception": True, | ||
"type": type(message).__name__, | ||
"message": str(message), | ||
} | ||
) | ||
else: | ||
data = message.model_dump_json() | ||
|
||
channel = self._session_channel(session_id) | ||
await self._redis.publish(channel, data) # type: ignore[attr-defined] | ||
logger.debug(f"Published message to Redis channel for session {session_id}") | ||
return True | ||
|
||
async def session_exists(self, session_id: UUID) -> bool: | ||
"""Check if a session exists.""" | ||
return bool( | ||
await self._redis.sismember(self._active_sessions_key, session_id.hex) # type: ignore[attr-defined] | ||
) |
Uh oh!
There was an error while loading. Please reload this page.