 import logging
+import threading
 from contextlib import asynccontextmanager
 from typing import Any, cast
 from uuid import UUID
 from mcp.server.message_queue.base import MessageCallback

 try:
-    import redis.asyncio as redis
+    import redis
 except ImportError:
     raise ImportError(
         "Redis support requires the 'redis' package. "
@@ -40,19 +41,45 @@ def __init__(
             prefix: Key prefix for Redis channels to avoid collisions
             session_ttl: TTL in seconds for session keys (default: 1 hour)
         """
-        self._redis = redis.from_url(redis_url, decode_responses=True)  # type: ignore
-        self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)  # type: ignore
+        # Parse Redis URL and create connection
+        if redis_url.startswith("redis://"):
+            host_port = redis_url.replace("redis://", "").split("/")[0]
+            if ":" in host_port:
+                host, port = host_port.split(":")
+                port = int(port)
+            else:
+                host, port = host_port, 6379
+            db = int(redis_url.split("/")[-1]) if "/" in redis_url else 0
+            self._redis = redis.StrictRedis(host=host, port=port, db=db, decode_responses=True)
+        else:
+            self._redis = redis.StrictRedis.from_url(redis_url, decode_responses=True)
+
+        self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)
         self._prefix = prefix
         self._session_ttl = session_ttl
         # Maps session IDs to the callback and task group for that SSE session.
         self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {}
+        # Thread for pubsub listening
+        self._pubsub_thread = None
+        # Lock for thread safety
+        self._lock = threading.RLock()
+        # Tracks current subscriptions
+        self._subscriptions: set[str] = set()
         # Ensures only one polling task runs at a time for message handling
         self._limiter = CapacityLimiter(1)
+        # Active sessions set key
+        self._active_sessions_key = f"{self._prefix}active_sessions"
         logger.debug(f"Redis message dispatch initialized: {redis_url}")

     async def close(self):
-        await self._pubsub.aclose()  # type: ignore
-        await self._redis.aclose()  # type: ignore
+        """Close Redis connections."""
+        # Stop pubsub thread if running
+        if self._pubsub_thread:
+            self._pubsub_thread.stop()
+
+        # Clean up pubsub and connection
+        self._pubsub.close()
+        # Redis connection in 3.2.1 doesn't need explicit closing

     def _session_channel(self, session_id: UUID) -> str:
         """Get the Redis channel for a session."""
@@ -66,24 +93,55 @@ def _session_key(self, session_id: UUID) -> str:
     async def subscribe(self, session_id: UUID, callback: MessageCallback):
         """Request-scoped context manager that subscribes to messages for a session."""
         session_key = self._session_key(session_id)
-        await self._redis.setex(session_key, self._session_ttl, "1")  # type: ignore
+
+        # Run Redis operations in anyio's run_sync to make blocking calls non-blocking
+        await anyio.to_thread.run_sync(
+            lambda: self._redis.setex(session_key, self._session_ttl, "1")
+        )
+
+        # Add to active sessions set
+        await anyio.to_thread.run_sync(
+            lambda: self._redis.sadd(self._active_sessions_key, session_id.hex)
+        )

         channel = self._session_channel(session_id)
-        await self._pubsub.subscribe(channel)  # type: ignore
-
+
+        # Use lock for thread safety
+        with self._lock:
+            # Subscribe to channel
+            await anyio.to_thread.run_sync(lambda: self._pubsub.subscribe(channel))
+            self._subscriptions.add(channel)
+
         logger.debug(f"Subscribing to Redis channel for session {session_id}")
+
         async with anyio.create_task_group() as tg:
             self._session_state[session_id] = (callback, tg)
-            tg.start_soon(self._listen_for_messages)
+            # Start message listener if not running
+            if not self._pubsub_thread:
+                tg.start_soon(self._listen_for_messages)
+
             # Start heartbeat for this session
             tg.start_soon(self._session_heartbeat, session_id)
             try:
                 yield
             finally:
                 with anyio.CancelScope(shield=True):
                     tg.cancel_scope.cancel()
-                    await self._pubsub.unsubscribe(channel)  # type: ignore
-                    await self._redis.delete(session_key)  # type: ignore
+
+                    # Unsubscribe
+                    with self._lock:
+                        await anyio.to_thread.run_sync(
+                            lambda: self._pubsub.unsubscribe(channel)
+                        )
+                        self._subscriptions.discard(channel)
+
+                    # Delete session key and remove from active sessions
+                    await anyio.to_thread.run_sync(lambda: self._redis.delete(session_key))
+                    await anyio.to_thread.run_sync(
+                        lambda: self._redis.srem(self._active_sessions_key, session_id.hex)
+                    )
+
+                    # Clean up session state
                     del self._session_state[session_id]
                     logger.debug(f"Unsubscribed from Redis channel: {session_id}")

@@ -96,7 +154,10 @@ async def _session_heartbeat(self, session_id: UUID) -> None:
                 # Refresh TTL at half the TTL interval to avoid expiration
                 await anyio.sleep(self._session_ttl / 2)
                 with anyio.CancelScope(shield=True):
-                    await self._redis.expire(session_key, self._session_ttl)  # type: ignore
+                    # Run in thread to avoid blocking
+                    await anyio.to_thread.run_sync(
+                        lambda: self._redis.expire(session_key, self._session_ttl)
+                    )
             except anyio.get_cancelled_exc_class():
                 break
             except Exception as e:
@@ -123,33 +184,39 @@ async def _listen_for_messages(self) -> None:
         """Background task that listens for messages on subscribed channels."""
         async with self._limiter:
             while True:
+                # Check for cancellation
                 await lowlevel.checkpoint()
+
+                # Use a shield to prevent cancellation during message processing
                 with CancelScope(shield=True):
-                    message: None | dict[str, Any] = await self._pubsub.get_message(  # type: ignore
-                        ignore_subscribe_messages=True,
-                        timeout=0.1,  # type: ignore
+                    # Get message with non-blocking call using thread
+                    message = await anyio.to_thread.run_sync(
+                        lambda: self._pubsub.get_message(
+                            ignore_subscribe_messages=True,
+                            timeout=0.1,
+                        )
                     )
+
                     if message is None:
+                        # No message available, sleep briefly and try again
+                        await anyio.sleep(0.01)
                         continue

                     channel: str = cast(str, message["channel"])
                     session_id = self._extract_session_id(channel)
                     if session_id is None:
-                        logger.debug(
-                            f"Ignoring message from non-MCP channel: {channel}"
-                        )
+                        logger.debug(f"Ignoring message from non-MCP channel: {channel}")
                         continue

                     data: str = cast(str, message["data"])
                     try:
                         if session_state := self._session_state.get(session_id):
+                            # Process message in task group
                             session_state[1].start_soon(
                                 self._handle_message, session_id, data
                             )
                         else:
-                            logger.warning(
-                                f"Message dropped: unknown session {session_id}"
-                            )
+                            logger.warning(f"Message dropped: unknown session {session_id}")
                     except Exception as e:
                         logger.error(f"Error processing message for {session_id}: {e}")

@@ -186,11 +253,15 @@ async def publish_message(
         data = message.model_dump_json()

         channel = self._session_channel(session_id)
-        await self._redis.publish(channel, data)  # type: ignore[attr-defined]
+
+        # Run publish in thread to avoid blocking
+        await anyio.to_thread.run_sync(lambda: self._redis.publish(channel, data))
         logger.debug(f"Message published to Redis channel for session {session_id}")
         return True

     async def session_exists(self, session_id: UUID) -> bool:
         """Check if a session exists."""
         session_key = self._session_key(session_id)
-        return bool(await self._redis.exists(session_key))  # type: ignore
+        # Run exists command in thread to avoid blocking
+        exists = await anyio.to_thread.run_sync(lambda: self._redis.exists(session_key))
+        return bool(exists)
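
The pattern applied throughout this change is to offload blocking redis-py calls (plain commands and pub/sub polling alike) to a worker thread with anyio.to_thread.run_sync. Below is a minimal, self-contained sketch of that pattern; it assumes a Redis server reachable at localhost:6379 and uses illustrative key and channel names, not the module's real ones.

import anyio
import redis


async def main() -> None:
    # Hypothetical setup: a Redis server reachable at localhost:6379, db 0.
    r = redis.StrictRedis(host="localhost", port=6379, db=0, decode_responses=True)
    pubsub = r.pubsub(ignore_subscribe_messages=True)

    # Blocking commands run on a worker thread so the event loop is never blocked.
    await anyio.to_thread.run_sync(lambda: r.setex("example:session", 60, "1"))
    await anyio.to_thread.run_sync(lambda: pubsub.subscribe("example:channel"))

    # Pub/sub is polled the same way: short-timeout get_message() calls in a thread.
    await anyio.to_thread.run_sync(lambda: r.publish("example:channel", "hello"))
    for _ in range(50):
        message = await anyio.to_thread.run_sync(
            lambda: pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
        )
        if message is not None:
            print(message["data"])  # prints "hello"
            break
        await anyio.sleep(0.01)


anyio.run(main)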