Skip to content

Commit 96726fa

Browse files
committed
feat: add request cancellation and in-flight request tracking
1 parent 0d7be0a commit 96726fa

File tree

1 file changed

+55
-18
lines changed

1 file changed

+55
-18
lines changed

src/mcp/shared/session.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from mcp.shared.exceptions import McpError
1212
from mcp.types import (
13+
CancelledNotification,
1314
ClientNotification,
1415
ClientRequest,
1516
ClientResult,
@@ -44,20 +45,36 @@ def __init__(
4445
request_meta: RequestParams.Meta | None,
4546
request: ReceiveRequestT,
4647
session: "BaseSession",
48+
cancel_scope: anyio.CancelScope | None,
4749
) -> None:
4850
self.request_id = request_id
4951
self.request_meta = request_meta
5052
self.request = request
5153
self._session = session
5254
self._responded = False
55+
self._cancel_scope = cancel_scope
5356

5457
async def respond(self, response: SendResultT | ErrorData) -> None:
5558
assert not self._responded, "Request already responded to"
56-
self._responded = True
5759

58-
await self._session._send_response(
59-
request_id=self.request_id, response=response
60-
)
60+
if not self.cancelled:
61+
self._responded = True
62+
63+
await self._session._send_response(
64+
request_id=self.request_id, response=response
65+
)
66+
67+
async def cancel(self) -> None:
68+
if self._cancel_scope is not None:
69+
self._cancel_scope.cancel()
70+
71+
@property
72+
def in_flight(self) -> bool:
73+
return not self._responded and not self.cancelled
74+
75+
@property
76+
def cancelled(self) -> bool:
77+
return self._cancel_scope is not None and self._cancel_scope.cancel_called
6178

6279

6380
class BaseSession(
@@ -205,12 +222,21 @@ async def _send_response(
205222
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
206223

207224
async def _receive_loop(self) -> None:
225+
in_flight: dict[RequestId, RequestResponder] = {}
226+
208227
async with (
209228
self._read_stream,
210229
self._write_stream,
211230
self._incoming_message_stream_writer,
212231
):
213232
async for message in self._read_stream:
233+
# Clean up completed requests
234+
in_flight = {
235+
req_id: responder
236+
for req_id, responder in in_flight.items()
237+
if responder.in_flight
238+
}
239+
214240
if isinstance(message, Exception):
215241
await self._incoming_message_stream_writer.send(message)
216242
elif isinstance(message.root, JSONRPCRequest):
@@ -219,27 +245,38 @@ async def _receive_loop(self) -> None:
219245
by_alias=True, mode="json", exclude_none=True
220246
)
221247
)
222-
responder = RequestResponder(
223-
request_id=message.root.id,
224-
request_meta=validated_request.root.params.meta
225-
if validated_request.root.params
226-
else None,
227-
request=validated_request,
228-
session=self,
229-
)
230248

231-
await self._received_request(responder)
232-
if not responder._responded:
233-
await self._incoming_message_stream_writer.send(responder)
249+
with anyio.CancelScope() as scope:
250+
responder = RequestResponder(
251+
request_id=message.root.id,
252+
request_meta=validated_request.root.params.meta
253+
if validated_request.root.params
254+
else None,
255+
request=validated_request,
256+
session=self,
257+
cancel_scope=scope,
258+
)
259+
260+
in_flight[message.root.id] = responder
261+
262+
await self._received_request(responder)
263+
if not responder._responded:
264+
await self._incoming_message_stream_writer.send(responder)
265+
234266
elif isinstance(message.root, JSONRPCNotification):
235267
notification = self._receive_notification_type.model_validate(
236268
message.root.model_dump(
237269
by_alias=True, mode="json", exclude_none=True
238270
)
239271
)
240-
241-
await self._received_notification(notification)
242-
await self._incoming_message_stream_writer.send(notification)
272+
# Handle cancellation notifications
273+
if isinstance(notification.root, CancelledNotification):
274+
cancelled_id = notification.root.params.requestId
275+
if cancelled_id in in_flight:
276+
await in_flight[cancelled_id].cancel()
277+
else:
278+
await self._received_notification(notification)
279+
await self._incoming_message_stream_writer.send(notification)
243280
else: # Response or error
244281
stream = self._response_streams.pop(message.root.id, None)
245282
if stream:

0 commit comments

Comments
 (0)