Skip to content

Commit 9344142

Browse files
author
cshinaver
committed
downgraded redis to 3.2.1
1 parent 046ed94 commit 9344142

File tree

6 files changed

+798
-832
lines changed

6 files changed

+798
-832
lines changed

.tool-versions

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
python 3.10.16

pyproject.toml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,15 @@ dependencies = [
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1; sys_platform != 'emscripten'",
33-
"fakeredis==2.28.1",
33+
"fakeredis==1.9.0",
34+
"redis==3.2.1",
3435
]
3536

3637
[project.optional-dependencies]
3738
rich = ["rich>=13.9.4"]
3839
cli = ["typer>=0.12.4", "python-dotenv>=1.0.0"]
3940
ws = ["websockets>=15.0.1"]
40-
redis = ["redis>=5.2.1", "types-redis>=4.6.0.20241004"]
41+
redis = ["redis==3.2.1"]
4142

4243
[project.scripts]
4344
mcp = "mcp.cli:app [cli]"
@@ -115,5 +116,7 @@ filterwarnings = [
115116
# This should be fixed on Uvicorn's side.
116117
"ignore::DeprecationWarning:websockets",
117118
"ignore:websockets.server.WebSocketServerProtocol is deprecated:DeprecationWarning",
118-
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel"
119+
"ignore:Returning str or bytes.*:DeprecationWarning:mcp.server.lowlevel",
120+
# Ignore distutils deprecation from Redis 3.2.1
121+
"ignore:The distutils package is deprecated.*:DeprecationWarning"
119122
]

src/mcp/server/message_queue/redis.py

Lines changed: 94 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
import threading
23
from contextlib import asynccontextmanager
34
from typing import Any, cast
45
from uuid import UUID
@@ -12,7 +13,7 @@
1213
from mcp.server.message_queue.base import MessageCallback
1314

1415
try:
15-
import redis.asyncio as redis
16+
import redis
1617
except ImportError:
1718
raise ImportError(
1819
"Redis support requires the 'redis' package. "
@@ -40,19 +41,45 @@ def __init__(
4041
prefix: Key prefix for Redis channels to avoid collisions
4142
session_ttl: TTL in seconds for session keys (default: 1 hour)
4243
"""
43-
self._redis = redis.from_url(redis_url, decode_responses=True) # type: ignore
44-
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True) # type: ignore
44+
# Parse Redis URL and create connection
45+
if redis_url.startswith("redis://"):
46+
host_port = redis_url.replace("redis://", "").split("/")[0]
47+
if ":" in host_port:
48+
host, port = host_port.split(":")
49+
port = int(port)
50+
else:
51+
host, port = host_port, 6379
52+
db = int(redis_url.split("/")[-1]) if "/" in redis_url else 0
53+
self._redis = redis.StrictRedis(host=host, port=port, db=db, decode_responses=True)
54+
else:
55+
self._redis = redis.StrictRedis.from_url(redis_url, decode_responses=True)
56+
57+
self._pubsub = self._redis.pubsub(ignore_subscribe_messages=True)
4558
self._prefix = prefix
4659
self._session_ttl = session_ttl
4760
# Maps session IDs to the callback and task group for that SSE session.
4861
self._session_state: dict[UUID, tuple[MessageCallback, TaskGroup]] = {}
62+
# Thread for pubsub listening
63+
self._pubsub_thread = None
64+
# Lock for thread safety
65+
self._lock = threading.RLock()
66+
# Tracks current subscriptions
67+
self._subscriptions: set[str] = set()
4968
# Ensures only one polling task runs at a time for message handling
5069
self._limiter = CapacityLimiter(1)
70+
# Active sessions set key
71+
self._active_sessions_key = f"{self._prefix}active_sessions"
5172
logger.debug(f"Redis message dispatch initialized: {redis_url}")
5273

5374
async def close(self):
54-
await self._pubsub.aclose() # type: ignore
55-
await self._redis.aclose() # type: ignore
75+
"""Close Redis connections."""
76+
# Stop pubsub thread if running
77+
if self._pubsub_thread:
78+
self._pubsub_thread.stop()
79+
80+
# Clean up pubsub and connection
81+
self._pubsub.close()
82+
# Redis connection in 3.2.1 doesn't need explicit closing
5683

5784
def _session_channel(self, session_id: UUID) -> str:
5885
"""Get the Redis channel for a session."""
@@ -66,24 +93,55 @@ def _session_key(self, session_id: UUID) -> str:
6693
async def subscribe(self, session_id: UUID, callback: MessageCallback):
6794
"""Request-scoped context manager that subscribes to messages for a session."""
6895
session_key = self._session_key(session_id)
69-
await self._redis.setex(session_key, self._session_ttl, "1") # type: ignore
96+
97+
# Run Redis operations in anyio's run_sync to make blocking calls non-blocking
98+
await anyio.to_thread.run_sync(
99+
lambda: self._redis.setex(session_key, self._session_ttl, "1")
100+
)
101+
102+
# Add to active sessions set
103+
await anyio.to_thread.run_sync(
104+
lambda: self._redis.sadd(self._active_sessions_key, session_id.hex)
105+
)
70106

71107
channel = self._session_channel(session_id)
72-
await self._pubsub.subscribe(channel) # type: ignore
73-
108+
109+
# Use lock for thread safety
110+
with self._lock:
111+
# Subscribe to channel
112+
await anyio.to_thread.run_sync(lambda: self._pubsub.subscribe(channel))
113+
self._subscriptions.add(channel)
114+
74115
logger.debug(f"Subscribing to Redis channel for session {session_id}")
116+
75117
async with anyio.create_task_group() as tg:
76118
self._session_state[session_id] = (callback, tg)
77-
tg.start_soon(self._listen_for_messages)
119+
# Start message listener if not running
120+
if not self._pubsub_thread:
121+
tg.start_soon(self._listen_for_messages)
122+
78123
# Start heartbeat for this session
79124
tg.start_soon(self._session_heartbeat, session_id)
80125
try:
81126
yield
82127
finally:
83128
with anyio.CancelScope(shield=True):
84129
tg.cancel_scope.cancel()
85-
await self._pubsub.unsubscribe(channel) # type: ignore
86-
await self._redis.delete(session_key) # type: ignore
130+
131+
# Unsubscribe
132+
with self._lock:
133+
await anyio.to_thread.run_sync(
134+
lambda: self._pubsub.unsubscribe(channel)
135+
)
136+
self._subscriptions.discard(channel)
137+
138+
# Delete session key and remove from active sessions
139+
await anyio.to_thread.run_sync(lambda: self._redis.delete(session_key))
140+
await anyio.to_thread.run_sync(
141+
lambda: self._redis.srem(self._active_sessions_key, session_id.hex)
142+
)
143+
144+
# Clean up session state
87145
del self._session_state[session_id]
88146
logger.debug(f"Unsubscribed from Redis channel: {session_id}")
89147

@@ -96,7 +154,10 @@ async def _session_heartbeat(self, session_id: UUID) -> None:
96154
# Refresh TTL at half the TTL interval to avoid expiration
97155
await anyio.sleep(self._session_ttl / 2)
98156
with anyio.CancelScope(shield=True):
99-
await self._redis.expire(session_key, self._session_ttl) # type: ignore
157+
# Run in thread to avoid blocking
158+
await anyio.to_thread.run_sync(
159+
lambda: self._redis.expire(session_key, self._session_ttl)
160+
)
100161
except anyio.get_cancelled_exc_class():
101162
break
102163
except Exception as e:
@@ -123,33 +184,39 @@ async def _listen_for_messages(self) -> None:
123184
"""Background task that listens for messages on subscribed channels."""
124185
async with self._limiter:
125186
while True:
187+
# Check for cancellation
126188
await lowlevel.checkpoint()
189+
190+
# Use a shield to prevent cancellation during message processing
127191
with CancelScope(shield=True):
128-
message: None | dict[str, Any] = await self._pubsub.get_message( # type: ignore
129-
ignore_subscribe_messages=True,
130-
timeout=0.1, # type: ignore
192+
# Get message with non-blocking call using thread
193+
message = await anyio.to_thread.run_sync(
194+
lambda: self._pubsub.get_message(
195+
ignore_subscribe_messages=True,
196+
timeout=0.1,
197+
)
131198
)
199+
132200
if message is None:
201+
# No message available, sleep briefly and try again
202+
await anyio.sleep(0.01)
133203
continue
134204

135205
channel: str = cast(str, message["channel"])
136206
session_id = self._extract_session_id(channel)
137207
if session_id is None:
138-
logger.debug(
139-
f"Ignoring message from non-MCP channel: {channel}"
140-
)
208+
logger.debug(f"Ignoring message from non-MCP channel: {channel}")
141209
continue
142210

143211
data: str = cast(str, message["data"])
144212
try:
145213
if session_state := self._session_state.get(session_id):
214+
# Process message in task group
146215
session_state[1].start_soon(
147216
self._handle_message, session_id, data
148217
)
149218
else:
150-
logger.warning(
151-
f"Message dropped: unknown session {session_id}"
152-
)
219+
logger.warning(f"Message dropped: unknown session {session_id}")
153220
except Exception as e:
154221
logger.error(f"Error processing message for {session_id}: {e}")
155222

@@ -186,11 +253,15 @@ async def publish_message(
186253
data = message.model_dump_json()
187254

188255
channel = self._session_channel(session_id)
189-
await self._redis.publish(channel, data) # type: ignore[attr-defined]
256+
257+
# Run publish in thread to avoid blocking
258+
await anyio.to_thread.run_sync(lambda: self._redis.publish(channel, data))
190259
logger.debug(f"Message published to Redis channel for session {session_id}")
191260
return True
192261

193262
async def session_exists(self, session_id: UUID) -> bool:
194263
"""Check if a session exists."""
195264
session_key = self._session_key(session_id)
196-
return bool(await self._redis.exists(session_key)) # type: ignore
265+
# Run exists command in thread to avoid blocking
266+
exists = await anyio.to_thread.run_sync(lambda: self._redis.exists(session_key))
267+
return bool(exists)

tests/server/message_queue/test_redis.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@
99

1010
# Set up fakeredis for testing
1111
try:
12-
from fakeredis import aioredis as fake_redis
12+
import fakeredis
13+
# Older fakeredis (v1.9.0) doesn't have aioredis module
14+
fake_redis = fakeredis
1315
except ImportError:
1416
pytest.skip(
1517
"fakeredis is required for testing Redis functionality", allow_module_level=True
@@ -20,7 +22,7 @@
2022
async def redis_dispatch():
2123
"""Create a Redis message dispatch with a fake Redis client."""
2224
# Mock the redis module entirely within RedisMessageDispatch
23-
with patch("mcp.server.message_queue.redis.redis", fake_redis.FakeRedis):
25+
with patch("mcp.server.message_queue.redis.redis.StrictRedis", fake_redis.FakeStrictRedis):
2426
from mcp.server.message_queue.redis import RedisMessageDispatch
2527

2628
dispatch = RedisMessageDispatch(session_ttl=5) # Shorter TTL for testing
@@ -53,7 +55,8 @@ async def test_session_ttl(redis_dispatch):
5355

5456
async with redis_dispatch.subscribe(session_id, AsyncMock()):
5557
session_key = redis_dispatch._session_key(session_id)
56-
ttl = await redis_dispatch._redis.ttl(session_key) # type: ignore
58+
# Updated for synchronous Redis client
59+
ttl = await anyio.to_thread.run_sync(lambda: redis_dispatch._redis.ttl(session_key))
5760
assert ttl > 0
5861
assert ttl <= redis_dispatch._session_ttl
5962

@@ -67,14 +70,14 @@ async def test_session_heartbeat(redis_dispatch):
6770
session_key = redis_dispatch._session_key(session_id)
6871

6972
# Initial TTL
70-
initial_ttl = await redis_dispatch._redis.ttl(session_key) # type: ignore
73+
initial_ttl = await anyio.to_thread.run_sync(lambda: redis_dispatch._redis.ttl(session_key))
7174
assert initial_ttl > 0
7275

7376
# Wait for heartbeat to run
7477
await anyio.sleep(redis_dispatch._session_ttl / 2 + 0.5)
7578

7679
# TTL should be refreshed
77-
refreshed_ttl = await redis_dispatch._redis.ttl(session_key) # type: ignore
80+
refreshed_ttl = await anyio.to_thread.run_sync(lambda: redis_dispatch._redis.ttl(session_key))
7881
assert refreshed_ttl > 0
7982
assert refreshed_ttl <= redis_dispatch._session_ttl
8083

tests/server/message_queue/test_redis_integration.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323
# Set up fakeredis for testing
2424
try:
25-
from fakeredis import aioredis as fake_redis
25+
import fakeredis
26+
# Older fakeredis (v1.9.0) doesn't have aioredis module
27+
fake_redis = fakeredis
2628
except ImportError:
2729
pytest.skip(
2830
"fakeredis is required for testing Redis functionality", allow_module_level=True
@@ -64,10 +66,10 @@ async def handle_call_tool(name: str, args: dict) -> list[TextContent]:
6466
def make_redis_server_app() -> Starlette:
6567
"""Create test Starlette app with SSE transport and Redis message dispatch"""
6668
# Create a mock Redis instance
67-
mock_redis = fake_redis.FakeRedis()
69+
mock_redis = fake_redis.FakeStrictRedis(decode_responses=True)
6870

6971
# Patch the redis module within RedisMessageDispatch
70-
with patch("mcp.server.message_queue.redis.redis", mock_redis):
72+
with patch("mcp.server.message_queue.redis.redis.StrictRedis", lambda *args, **kwargs: mock_redis):
7173
from mcp.server.message_queue.redis import RedisMessageDispatch
7274

7375
# Create Redis message dispatch with mock redis
@@ -173,11 +175,11 @@ async def test_redis_integration_tool_call(server: None, server_url: str) -> Non
173175
async def test_redis_integration_session_lifecycle() -> None:
174176
"""Test that sessions are properly added to and removed from Redis using direct Redis access"""
175177
# Create a fresh Redis instance with decode_responses=True to get str instead of bytes
176-
mock_redis = fake_redis.FakeRedis(decode_responses=True)
178+
mock_redis = fake_redis.FakeStrictRedis(decode_responses=True)
177179
active_sessions_key = "mcp:pubsub:active_sessions"
178180

179181
# Mock Redis in RedisMessageDispatch
180-
with patch("mcp.server.message_queue.redis.redis.from_url", return_value=mock_redis):
182+
with patch("mcp.server.message_queue.redis.redis.StrictRedis", lambda *args, **kwargs: mock_redis):
181183
from mcp.server.message_queue.redis import RedisMessageDispatch
182184

183185
# Create Redis message dispatch with our specific mock redis instance
@@ -197,7 +199,7 @@ async def mock_callback(message):
197199
await anyio.sleep(0.05)
198200

199201
# Check that session was added to Redis
200-
active_sessions = await mock_redis.smembers(active_sessions_key)
202+
active_sessions = mock_redis.smembers(active_sessions_key)
201203
assert len(active_sessions) == 1
202204
assert list(active_sessions)[0] == session_id.hex
203205

@@ -208,7 +210,7 @@ async def mock_callback(message):
208210
await anyio.sleep(0.05)
209211

210212
# After context exit, verify the session was removed
211-
final_sessions = await mock_redis.smembers(active_sessions_key)
213+
final_sessions = mock_redis.smembers(active_sessions_key)
212214
assert len(final_sessions) == 0
213215
assert not await message_dispatch.session_exists(session_id)
214216

@@ -217,10 +219,10 @@ async def mock_callback(message):
217219
async def test_redis_integration_message_publishing_direct() -> None:
218220
"""Test message publishing through Redis channels using direct Redis access"""
219221
# Create a fresh Redis instance with decode_responses=True to get str instead of bytes
220-
mock_redis = fake_redis.FakeRedis(decode_responses=True)
222+
mock_redis = fake_redis.FakeStrictRedis(decode_responses=True)
221223

222224
# Mock Redis in RedisMessageDispatch
223-
with patch("mcp.server.message_queue.redis.redis.from_url", return_value=mock_redis):
225+
with patch("mcp.server.message_queue.redis.redis.StrictRedis", lambda *args, **kwargs: mock_redis):
224226
from mcp.server.message_queue.redis import RedisMessageDispatch
225227
from mcp.types import JSONRPCMessage, JSONRPCRequest
226228

@@ -252,8 +254,8 @@ async def message_callback(message):
252254
assert success
253255

254256
# Give some time for the message to be processed
255-
# Use a shorter sleep since we're in controlled test environment
256-
await anyio.sleep(0.1)
257+
# Use a longer sleep since we're with older Redis version
258+
await anyio.sleep(0.5)
257259

258260
# Verify that the message was received
259261
assert len(messages_received) > 0, "No messages were received through the callback"

0 commit comments

Comments
 (0)