diff --git a/src/mcp/client/stdio/__init__.py b/src/mcp/client/stdio/__init__.py index 83de57a2b..52db01618 100644 --- a/src/mcp/client/stdio/__init__.py +++ b/src/mcp/client/stdio/__init__.py @@ -1,3 +1,4 @@ +import contextvars import os import sys from contextlib import asynccontextmanager @@ -6,6 +7,7 @@ import anyio import anyio.lowlevel +from anyio.abc import Process from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream from anyio.streams.text import TextReceiveStream from pydantic import BaseModel, Field @@ -92,6 +94,9 @@ class StdioServerParameters(BaseModel): """ +PROCESS_VAR: contextvars.ContextVar[Process] = contextvars.ContextVar("process") + + @asynccontextmanager async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stderr): """ @@ -169,9 +174,13 @@ async def stdin_writer(): ): tg.start_soon(stdout_reader) tg.start_soon(stdin_writer) + token = None try: + token = PROCESS_VAR.set(process) yield read_stream, write_stream finally: + if token is not None: + PROCESS_VAR.reset(token) # Clean up process to prevent any dangling orphaned processes if sys.platform == "win32": await terminate_windows_process(process) diff --git a/tests/client/test_stdio.py b/tests/client/test_stdio.py index 95747ffd1..9c8ad317a 100644 --- a/tests/client/test_stdio.py +++ b/tests/client/test_stdio.py @@ -2,7 +2,7 @@ import pytest -from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.client.stdio import PROCESS_VAR, StdioServerParameters, stdio_client from mcp.types import JSONRPCMessage, JSONRPCRequest, JSONRPCResponse tee: str = shutil.which("tee") # type: ignore @@ -34,6 +34,10 @@ async def test_stdio_client(): if len(read_messages) == 2: break + process = PROCESS_VAR.get() + assert process is not None + assert process.returncode is None + assert len(read_messages) == 2 assert read_messages[0] == JSONRPCMessage( root=JSONRPCRequest(jsonrpc="2.0", id=1, method="ping")