Skip to content

Commit 46603d1

Browse files
committed
Properly clean up response streams in BaseSession
Wraps the request handling in a try/finally block to ensure that response streams are properly closed and removed from the tracking dictionary, even if an exception occurs during request processing. This change also prevents response_stream and response_stream_reader instances from piling up on _exit_stack over the course of the session. Github-Issue:#169
1 parent 70115b9 commit 46603d1

File tree

4 files changed

+113
-42
lines changed

4 files changed

+113
-42
lines changed

Diff for: pyproject.toml

+2-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ dependencies = [
2626
"httpx>=0.27",
2727
"httpx-sse>=0.4",
2828
"pydantic>=2.7.2,<3.0.0",
29-
"starlette>=0.27",
29+
"starlette>=0.46.2",
3030
"sse-starlette>=1.6.1",
3131
"pydantic-settings>=2.5.2",
3232
"uvicorn>=0.23.1",
@@ -53,6 +53,7 @@ dev = [
5353
"pytest-flakefinder>=1.1.0",
5454
"pytest-xdist>=3.6.1",
5555
"pytest-examples>=0.0.14",
56+
"starlette>=0.46.2",
5657
]
5758
docs = [
5859
"mkdocs>=1.6.1",

Diff for: src/mcp/shared/session.py

+36-34
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ def __init__(
187187
self._receive_notification_type = receive_notification_type
188188
self._read_timeout_seconds = read_timeout_seconds
189189
self._in_flight = {}
190-
191190
self._exit_stack = AsyncExitStack()
192191

193192
async def __aenter__(self) -> Self:
@@ -230,42 +229,45 @@ async def send_request(
230229
](1)
231230
self._response_streams[request_id] = response_stream
232231

233-
self._exit_stack.push_async_callback(lambda: response_stream.aclose())
234-
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
235-
236-
jsonrpc_request = JSONRPCRequest(
237-
jsonrpc="2.0",
238-
id=request_id,
239-
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
240-
)
241-
242-
# TODO: Support progress callbacks
243-
244-
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
245-
246232
try:
247-
with anyio.fail_after(
248-
None
249-
if self._read_timeout_seconds is None
250-
else self._read_timeout_seconds.total_seconds()
251-
):
252-
response_or_error = await response_stream_reader.receive()
253-
except TimeoutError:
254-
raise McpError(
255-
ErrorData(
256-
code=httpx.codes.REQUEST_TIMEOUT,
257-
message=(
258-
f"Timed out while waiting for response to "
259-
f"{request.__class__.__name__}. Waited "
260-
f"{self._read_timeout_seconds} seconds."
261-
),
262-
)
233+
jsonrpc_request = JSONRPCRequest(
234+
jsonrpc="2.0",
235+
id=request_id,
236+
**request.model_dump(by_alias=True, mode="json", exclude_none=True),
263237
)
264238

265-
if isinstance(response_or_error, JSONRPCError):
266-
raise McpError(response_or_error.error)
267-
else:
268-
return result_type.model_validate(response_or_error.result)
239+
# TODO: Support progress callbacks
240+
241+
await self._write_stream.send(JSONRPCMessage(jsonrpc_request))
242+
243+
try:
244+
with anyio.fail_after(
245+
None
246+
if self._read_timeout_seconds is None
247+
else self._read_timeout_seconds.total_seconds()
248+
):
249+
response_or_error = await response_stream_reader.receive()
250+
except TimeoutError:
251+
raise McpError(
252+
ErrorData(
253+
code=httpx.codes.REQUEST_TIMEOUT,
254+
message=(
255+
f"Timed out while waiting for response to "
256+
f"{request.__class__.__name__}. Waited "
257+
f"{self._read_timeout_seconds} seconds."
258+
),
259+
)
260+
)
261+
262+
if isinstance(response_or_error, JSONRPCError):
263+
raise McpError(response_or_error.error)
264+
else:
265+
return result_type.model_validate(response_or_error.result)
266+
267+
finally:
268+
self._response_streams.pop(request_id, None)
269+
await response_stream.aclose()
270+
await response_stream_reader.aclose()
269271

270272
async def send_notification(self, notification: SendNotificationT) -> None:
271273
"""

Diff for: tests/client/test_resource_cleanup.py

+68
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
from unittest.mock import patch
2+
3+
import anyio
4+
import pytest
5+
6+
from mcp.shared.session import BaseSession
7+
from mcp.types import (
8+
ClientRequest,
9+
EmptyResult,
10+
PingRequest,
11+
)
12+
13+
14+
@pytest.mark.anyio
15+
async def test_send_request_stream_cleanup():
16+
"""
17+
Test that send_request properly cleans up streams when an exception occurs.
18+
19+
This test mocks out most of the session functionality to focus on stream cleanup.
20+
"""
21+
22+
# Create a mock session with the minimal required functionality
23+
class TestSession(BaseSession):
24+
async def _send_response(self, request_id, response):
25+
pass
26+
27+
# Create streams
28+
write_stream_send, write_stream_receive = anyio.create_memory_object_stream(1)
29+
read_stream_send, read_stream_receive = anyio.create_memory_object_stream(1)
30+
31+
# Create the session
32+
session = TestSession(
33+
read_stream_receive,
34+
write_stream_send,
35+
object, # Request type doesn't matter for this test
36+
object, # Notification type doesn't matter for this test
37+
)
38+
39+
# Create a test request
40+
request = ClientRequest(
41+
PingRequest(
42+
method="ping",
43+
)
44+
)
45+
46+
# Patch the _write_stream.send method to raise an exception
47+
async def mock_send(*args, **kwargs):
48+
raise RuntimeError("Simulated network error")
49+
50+
# Record the response streams before the test
51+
initial_stream_count = len(session._response_streams)
52+
53+
# Run the test with the patched method
54+
with patch.object(session._write_stream, "send", mock_send):
55+
with pytest.raises(RuntimeError):
56+
await session.send_request(request, EmptyResult)
57+
58+
# Verify that no response streams were leaked
59+
assert len(session._response_streams) == initial_stream_count, (
60+
f"Expected {initial_stream_count} response streams after request, "
61+
f"but found {len(session._response_streams)}"
62+
)
63+
64+
# Clean up
65+
await write_stream_send.aclose()
66+
await write_stream_receive.aclose()
67+
await read_stream_send.aclose()
68+
await read_stream_receive.aclose()

Diff for: uv.lock

+7-7
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)