Skip to content

Commit 827e494

Browse files
committed
feat: add request cancellation and in-flight request tracking
This commit adds support for request cancellation and tracking of in-flight requests in the MCP protocol implementation. The key architectural changes are: 1. Request Lifecycle Management: - Added _in_flight dictionary to BaseSession to track active requests - Requests are tracked from receipt until completion/cancellation - Added proper cleanup via on_complete callback 2. Cancellation Support: - Added CancelledNotification handling in _receive_loop - Implemented cancel() method in RequestResponder - Uses anyio.CancelScope for robust cancellation - Sends error response on cancellation 3. Request Context: - Added request_ctx ContextVar for request context - Ensures proper cleanup after request handling - Maintains request state throughout lifecycle 4. Error Handling: - Improved error propagation for cancelled requests - Added proper cleanup of cancelled requests - Maintains consistency of in-flight tracking This change enables clients to cancel long-running requests and servers to properly clean up resources when requests are cancelled. Github-Issue:#88
1 parent 888bdd3 commit 827e494

File tree

3 files changed

+190
-13
lines changed

3 files changed

+190
-13
lines changed

src/mcp/server/lowlevel/server.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,10 +453,15 @@ async def run(
453453
logger.debug(f"Received message: {message}")
454454

455455
match message:
456-
case RequestResponder(request=types.ClientRequest(root=req)):
457-
await self._handle_request(
458-
message, req, session, raise_exceptions
459-
)
456+
case (
457+
RequestResponder(
458+
request=types.ClientRequest(root=req)
459+
) as responder
460+
):
461+
with responder:
462+
await self._handle_request(
463+
message, req, session, raise_exceptions
464+
)
460465
case types.ClientNotification(root=notify):
461466
await self._handle_notification(notify)
462467

src/mcp/shared/session.py

Lines changed: 55 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from contextlib import AbstractAsyncContextManager
22
from datetime import timedelta
3-
from typing import Generic, TypeVar
3+
from typing import Any, Callable, Generic, TypeVar
44

55
import anyio
66
import anyio.lowlevel
@@ -10,6 +10,7 @@
1010

1111
from mcp.shared.exceptions import McpError
1212
from mcp.types import (
13+
CancelledNotification,
1314
ClientNotification,
1415
ClientRequest,
1516
ClientResult,
@@ -44,21 +45,55 @@ def __init__(
4445
request_meta: RequestParams.Meta | None,
4546
request: ReceiveRequestT,
4647
session: "BaseSession",
48+
on_complete: Callable[["RequestResponder[ReceiveRequestT, SendResultT]"], Any],
4749
) -> None:
4850
self.request_id = request_id
4951
self.request_meta = request_meta
5052
self.request = request
5153
self._session = session
52-
self._responded = False
54+
self._completed = False
55+
self._cancel_scope = anyio.CancelScope()
56+
self._on_complete = on_complete
57+
58+
def __enter__(self) -> "RequestResponder[ReceiveRequestT, SendResultT]":
59+
self._cancel_scope.__enter__()
60+
return self
61+
62+
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
63+
try:
64+
if self._completed:
65+
self._on_complete(self)
66+
finally:
67+
self._cancel_scope.__exit__(exc_type, exc_val, exc_tb)
5368

5469
async def respond(self, response: SendResultT | ErrorData) -> None:
55-
assert not self._responded, "Request already responded to"
56-
self._responded = True
70+
assert not self._completed, "Request already responded to"
5771

72+
if not self.cancelled:
73+
self._completed = True
74+
75+
await self._session._send_response(
76+
request_id=self.request_id, response=response
77+
)
78+
79+
async def cancel(self) -> None:
80+
"""Cancel this request and mark it as completed."""
81+
self._cancel_scope.cancel()
82+
self._completed = True # Mark as completed so it's removed from in_flight
83+
# Send an error response to indicate cancellation
5884
await self._session._send_response(
59-
request_id=self.request_id, response=response
85+
request_id=self.request_id,
86+
response=ErrorData(code=0, message="Request cancelled", data=None),
6087
)
6188

89+
@property
90+
def in_flight(self) -> bool:
91+
return not self._completed and not self.cancelled
92+
93+
@property
94+
def cancelled(self) -> bool:
95+
return self._cancel_scope is not None and self._cancel_scope.cancel_called
96+
6297

6398
class BaseSession(
6499
AbstractAsyncContextManager,
@@ -82,6 +117,7 @@ class BaseSession(
82117
RequestId, MemoryObjectSendStream[JSONRPCResponse | JSONRPCError]
83118
]
84119
_request_id: int
120+
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
85121

86122
def __init__(
87123
self,
@@ -99,6 +135,7 @@ def __init__(
99135
self._receive_request_type = receive_request_type
100136
self._receive_notification_type = receive_notification_type
101137
self._read_timeout_seconds = read_timeout_seconds
138+
self._in_flight = {}
102139

103140
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
104141
anyio.create_memory_object_stream[
@@ -219,27 +256,36 @@ async def _receive_loop(self) -> None:
219256
by_alias=True, mode="json", exclude_none=True
220257
)
221258
)
259+
222260
responder = RequestResponder(
223261
request_id=message.root.id,
224262
request_meta=validated_request.root.params.meta
225263
if validated_request.root.params
226264
else None,
227265
request=validated_request,
228266
session=self,
267+
on_complete=lambda r: self._in_flight.pop(r.request_id, None),
229268
)
230269

270+
self._in_flight[responder.request_id] = responder
231271
await self._received_request(responder)
232-
if not responder._responded:
272+
if not responder._completed:
233273
await self._incoming_message_stream_writer.send(responder)
274+
234275
elif isinstance(message.root, JSONRPCNotification):
235276
notification = self._receive_notification_type.model_validate(
236277
message.root.model_dump(
237278
by_alias=True, mode="json", exclude_none=True
238279
)
239280
)
240-
241-
await self._received_notification(notification)
242-
await self._incoming_message_stream_writer.send(notification)
281+
# Handle cancellation notifications
282+
if isinstance(notification.root, CancelledNotification):
283+
cancelled_id = notification.root.params.requestId
284+
if cancelled_id in self._in_flight:
285+
await self._in_flight[cancelled_id].cancel()
286+
else:
287+
await self._received_notification(notification)
288+
await self._incoming_message_stream_writer.send(notification)
243289
else: # Response or error
244290
stream = self._response_streams.pop(message.root.id, None)
245291
if stream:

tests/shared/test_session.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
from typing import AsyncGenerator
2+
3+
import anyio
4+
import pytest
5+
6+
import mcp.types as types
7+
from mcp.client.session import ClientSession
8+
from mcp.server.lowlevel.server import Server
9+
from mcp.shared.exceptions import McpError
10+
from mcp.shared.memory import create_connected_server_and_client_session
11+
from mcp.types import (
12+
CancelledNotification,
13+
CancelledNotificationParams,
14+
ClientNotification,
15+
ClientRequest,
16+
EmptyResult,
17+
)
18+
19+
20+
@pytest.fixture
21+
def mcp_server() -> Server:
22+
return Server(name="test server")
23+
24+
25+
@pytest.fixture
26+
async def client_connected_to_server(
27+
mcp_server: Server,
28+
) -> AsyncGenerator[ClientSession, None]:
29+
async with create_connected_server_and_client_session(mcp_server) as client_session:
30+
yield client_session
31+
32+
33+
@pytest.mark.anyio
34+
async def test_in_flight_requests_cleared_after_completion(
35+
client_connected_to_server: ClientSession,
36+
):
37+
"""Verify that _in_flight is empty after all requests complete."""
38+
# Send a request and wait for response
39+
response = await client_connected_to_server.send_ping()
40+
assert isinstance(response, EmptyResult)
41+
42+
# Verify _in_flight is empty
43+
assert len(client_connected_to_server._in_flight) == 0
44+
45+
46+
@pytest.mark.anyio
47+
async def test_request_cancellation():
48+
"""Test that requests can be cancelled while in-flight."""
49+
# The tool is already registered in the fixture
50+
51+
ev_tool_called = anyio.Event()
52+
ev_cancelled = anyio.Event()
53+
request_id = None
54+
55+
# Start the request in a separate task so we can cancel it
56+
def make_server() -> Server:
57+
server = Server(name="TestSessionServer")
58+
59+
# Register the tool handler
60+
@server.call_tool()
61+
async def handle_call_tool(name: str, arguments: dict | None) -> list:
62+
nonlocal request_id, ev_tool_called
63+
if name == "slow_tool":
64+
request_id = server.request_context.request_id
65+
ev_tool_called.set()
66+
await anyio.sleep(10) # Long enough to ensure we can cancel
67+
return []
68+
raise ValueError(f"Unknown tool: {name}")
69+
70+
# Register the tool so it shows up in list_tools
71+
@server.list_tools()
72+
async def handle_list_tools() -> list[types.Tool]:
73+
return [
74+
types.Tool(
75+
name="slow_tool",
76+
description="A slow tool that takes 10 seconds to complete",
77+
inputSchema={},
78+
)
79+
]
80+
81+
return server
82+
83+
async def make_request(client_session):
84+
nonlocal ev_cancelled
85+
try:
86+
await client_session.send_request(
87+
ClientRequest(
88+
types.CallToolRequest(
89+
method="tools/call",
90+
params=types.CallToolRequestParams(
91+
name="slow_tool", arguments={}
92+
),
93+
)
94+
),
95+
types.CallToolResult,
96+
)
97+
pytest.fail("Request should have been cancelled")
98+
except McpError as e:
99+
# Expected - request was cancelled
100+
assert "Request cancelled" in str(e)
101+
ev_cancelled.set()
102+
103+
async with create_connected_server_and_client_session(
104+
make_server()
105+
) as client_session:
106+
async with anyio.create_task_group() as tg:
107+
tg.start_soon(make_request, client_session)
108+
109+
# Wait for the request to be in-flight
110+
with anyio.fail_after(1): # Timeout after 1 second
111+
await ev_tool_called.wait()
112+
113+
# Send cancellation notification
114+
assert request_id is not None
115+
await client_session.send_notification(
116+
ClientNotification(
117+
CancelledNotification(
118+
method="notifications/cancelled",
119+
params=CancelledNotificationParams(requestId=request_id),
120+
)
121+
)
122+
)
123+
124+
# Give cancellation time to process
125+
with anyio.fail_after(1):
126+
await ev_cancelled.wait()

0 commit comments

Comments
 (0)