diff --git a/src/mcp/shared/progress.py b/src/mcp/shared/progress.py index 19ea5ede2..af064642d 100644 --- a/src/mcp/shared/progress.py +++ b/src/mcp/shared/progress.py @@ -1,4 +1,4 @@ -from contextlib import contextmanager +from contextlib import asynccontextmanager from dataclasses import dataclass, field from pydantic import BaseModel @@ -21,15 +21,29 @@ class ProgressContext: current: float = field(default=0.0, init=False) async def progress(self, amount: float) -> None: + """Update progress by the given amount and send notification.""" self.current += amount - await self.session.send_progress_notification( self.progress_token, self.current, total=self.total ) + async def final_progress(self) -> None: + """Send the final progress notification.""" + if self.total is not None and self.current < self.total: + self.current = self.total + await self.session.send_progress_notification( + self.progress_token, self.current, total=self.total + ) + + +@asynccontextmanager +async def progress(ctx: RequestContext, total: float | None = None): + """Context manager for progress tracking and notification. -@contextmanager -def progress(ctx: RequestContext, total: float | None = None): + Args: + ctx: Request context containing the session and progress token + total: Optional total progress amount + """ if ctx.meta is None or ctx.meta.progressToken is None: raise ValueError("No progress token provided") @@ -37,4 +51,4 @@ def progress(ctx: RequestContext, total: float | None = None): try: yield progress_ctx finally: - pass + await progress_ctx.final_progress() diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 3d3988ce1..075a8e5de 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -193,6 +193,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): # Using BaseSession as a context manager should not block on exit (this # would be very surprising behavior), so make sure to cancel the tasks # in the task group. + self._closed = True self._task_group.cancel_scope.cancel() return await self._task_group.__aexit__(exc_type, exc_val, exc_tb) @@ -256,12 +257,21 @@ async def send_notification(self, notification: SendNotificationT) -> None: Emits a notification, which is a one-way message that does not expect a response. """ - jsonrpc_notification = JSONRPCNotification( - jsonrpc="2.0", - **notification.model_dump(by_alias=True, mode="json", exclude_none=True), - ) + # Skip sending notifications if the session is closed + if self._closed: + return + + try: + jsonrpc_notification = JSONRPCNotification( + jsonrpc="2.0", + **notification.model_dump(by_alias=True, mode="json", exclude_none=True), + ) - await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + await self._write_stream.send(JSONRPCMessage(jsonrpc_notification)) + except Exception: + # Ignore notification send errors during session cleanup + if not self._closed: + raise async def _send_response( self, request_id: RequestId, response: SendResultT | ErrorData