Skip to content

Commit 4816bed

Browse files
committed
fix: improve session cleanup and progress handling
1. Add proper session cleanup handling - Track session state with _closed flag - Handle cancellation and cleanup errors gracefully - Skip notification validation during cleanup 2. Improve progress context - Add final_progress method - Send completion notification in finally block - Handle progress cleanup properly This fix addresses issues with session cleanup causing validation errors and improves progress notification reliability.
1 parent f10665d commit 4816bed

File tree

2 files changed

+48
-11
lines changed

2 files changed

+48
-11
lines changed

src/mcp/shared/progress.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from contextlib import contextmanager
1+
from contextlib import asynccontextmanager
22
from dataclasses import dataclass, field
33

44
from pydantic import BaseModel
@@ -21,20 +21,34 @@ class ProgressContext:
2121
current: float = field(default=0.0, init=False)
2222

2323
async def progress(self, amount: float) -> None:
24+
"""Update progress by the given amount and send notification."""
2425
self.current += amount
25-
2626
await self.session.send_progress_notification(
2727
self.progress_token, self.current, total=self.total
2828
)
2929

30+
async def final_progress(self) -> None:
31+
"""Send the final progress notification."""
32+
if self.total is not None and self.current < self.total:
33+
self.current = self.total
34+
await self.session.send_progress_notification(
35+
self.progress_token, self.current, total=self.total
36+
)
37+
38+
39+
@asynccontextmanager
40+
async def progress(ctx: RequestContext, total: float | None = None):
41+
"""Context manager for progress tracking and notification.
3042
31-
@contextmanager
32-
def progress(ctx: RequestContext, total: float | None = None):
43+
Args:
44+
ctx: Request context containing the session and progress token
45+
total: Optional total progress amount
46+
"""
3347
if ctx.meta is None or ctx.meta.progressToken is None:
3448
raise ValueError("No progress token provided")
3549

3650
progress_ctx = ProgressContext(ctx.session, ctx.meta.progressToken, total)
3751
try:
3852
yield progress_ctx
3953
finally:
40-
pass
54+
await progress_ctx.final_progress()

src/mcp/shared/session.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
193193
# Using BaseSession as a context manager should not block on exit (this
194194
# would be very surprising behavior), so make sure to cancel the tasks
195195
# in the task group.
196+
self._closed = True
196197
self._task_group.cancel_scope.cancel()
197198
return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)
198199

@@ -256,12 +257,21 @@ async def send_notification(self, notification: SendNotificationT) -> None:
256257
Emits a notification, which is a one-way message that does not expect
257258
a response.
258259
"""
259-
jsonrpc_notification = JSONRPCNotification(
260-
jsonrpc="2.0",
261-
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
262-
)
260+
# Skip sending notifications if the session is closed
261+
if self._closed:
262+
return
263+
264+
try:
265+
jsonrpc_notification = JSONRPCNotification(
266+
jsonrpc="2.0",
267+
**notification.model_dump(by_alias=True, mode="json", exclude_none=True),
268+
)
263269

264-
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
270+
await self._write_stream.send(JSONRPCMessage(jsonrpc_notification))
271+
except Exception:
272+
# Ignore notification send errors during session cleanup
273+
if not self._closed:
274+
raise
265275

266276
async def _send_response(
267277
self, request_id: RequestId, response: SendResultT | ErrorData
@@ -279,6 +289,19 @@ async def _send_response(
279289
)
280290
await self._write_stream.send(JSONRPCMessage(jsonrpc_response))
281291

292+
def _should_validate_notification(self, message_root: JSONRPCNotification) -> bool:
293+
"""
294+
Determines if a notification should be validated.
295+
Internal notifications (like cancelled) should be ignored.
296+
"""
297+
try:
298+
return (
299+
getattr(message_root, "method", None) != "cancelled" and
300+
not self._closed
301+
)
302+
except:
303+
return False
304+
282305
async def _receive_loop(self) -> None:
283306
async with (
284307
self._read_stream,
@@ -378,4 +401,4 @@ def incoming_messages(
378401
| ReceiveNotificationT
379402
| Exception
380403
]:
381-
return self._incoming_message_stream_reader
404+
return self._incoming_message_stream_reader

0 commit comments

Comments
 (0)