forked from modelcontextprotocol/python-sdk
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path mcp_stdio_client.py
136 lines (107 loc) · 4.34 KB
/
mcp_stdio_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from mcp import ClientSession, ListToolsResult, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.types import CallToolResult
from mcp import Tool as MCPTool
from contextlib import AsyncExitStack
from typing import Any
import asyncio
import logging
# Module-scoped logger, named after this module via __name__.
logger = logging.getLogger(name=__name__)
class NotificationLoggingClientSession(ClientSession):
    """ClientSession that echoes notification traffic to stdout for debugging.

    Incoming server notifications and outgoing progress notifications are
    printed, then handed to the parent ``ClientSession`` implementation so the
    session's normal protocol behavior is kept rather than replaced.
    """

    def __init__(self, read_stream, write_stream, *args, **kwargs):
        """Create the session.

        Args:
            read_stream: Stream the session reads server messages from.
            write_stream: Stream the session writes client messages to.
            *args: Extra positional ``ClientSession`` arguments, forwarded as-is.
            **kwargs: Extra keyword ``ClientSession`` arguments (e.g. timeouts
                or callbacks), forwarded as-is.
        """
        print("NOTIFICATION LOGGING CLIENT SESSION")
        super().__init__(read_stream, write_stream, *args, **kwargs)

    async def _received_notification(self, notification):
        """Print an incoming server notification, then run the parent handler.

        Args:
            notification: The notification object received from the server.
        """
        print(f"NOTIFICATION:{notification}")
        print("NOTIFICATION-END")
        # Delegate so whatever ClientSession itself does with notifications
        # still happens; this override is meant to observe, not to replace.
        await super()._received_notification(notification)

    async def send_progress_notification(
        self, progress_token, progress, total=None, *args, **kwargs
    ):
        """Print a progress notification, then actually send it to the server.

        Args:
            progress_token: Token identifying the operation being reported on.
            progress: Current progress value.
            total: Optional total value for the operation.
            *args: Any further positional arguments accepted by the parent
                method, forwarded unchanged.
            **kwargs: Any further keyword arguments accepted by the parent
                method, forwarded unchanged.
        """
        print(f"PROGRESS:{progress_token}")
        print("PROGRESS-END")
        # Without this call the notification would only be printed and never
        # transmitted, silently breaking progress reporting.
        await super().send_progress_notification(
            progress_token, progress, total, *args, **kwargs
        )
# adapted from mcp-python-sdk/examples/clients/simple-chatbot/mcp_simple_chatbot/main.py
class MCPClient:
    """Manages one MCP server connection over stdio and executes its tools.

    Typical use: ``await client.initialize()``, then ``get_available_tools()``
    and/or ``call_tool()``, and finally ``await client.cleanup()``.
    """

    def __init__(self, name, server_params: StdioServerParameters, errlog=None):
        """Store connection settings; nothing is started until ``initialize``.

        Args:
            name: Human-readable server name used in error messages.
            server_params: Parameters describing how to launch the stdio server.
            errlog: Optional text stream handed to ``stdio_client`` for the
                server process's stderr. When None, ``stdio_client``'s own
                default is used.
        """
        self.name = name
        self.server_params = server_params
        self.errlog = errlog
        # Retained for callers that read it; this class only ever resets it.
        self.stdio_context: Any | None = None
        self.session: ClientSession | None = None
        self._cleanup_lock: asyncio.Lock = asyncio.Lock()
        self.exit_stack: AsyncExitStack = AsyncExitStack()

    async def initialize(self) -> None:
        """Start the stdio transport and perform the MCP initialize handshake.

        Raises:
            Exception: Any error while starting the transport or initializing
                the session is logged, resources are released via
                ``cleanup()``, and the error is re-raised.
        """
        try:
            if self.errlog is None:
                transport_cm = stdio_client(self.server_params)
            else:
                # Forward the caller-supplied stderr sink; it was previously
                # stored but never used.
                transport_cm = stdio_client(self.server_params, errlog=self.errlog)
            read, write = await self.exit_stack.enter_async_context(transport_cm)
            session = await self.exit_stack.enter_async_context(
                NotificationLoggingClientSession(read, write)
            )
            await session.initialize()
            self.session = session
        except Exception as e:
            logger.error("Error initializing server: %s", e)
            await self.cleanup()
            raise

    async def get_available_tools(self) -> list[MCPTool]:
        """List available tools from the server.

        Returns:
            The tools reported in the server's ``list_tools`` response.

        Raises:
            RuntimeError: If the server is not initialized.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")

        tools_response = await self.session.list_tools()
        # NOTE: pagination is not handled; only the tools in this single
        # response are returned.
        return tools_response.tools

    async def call_tool(
        self,
        tool_name: str,
        arguments: dict[str, Any],
        retries: int = 2,
        delay: float = 1.0,
    ) -> CallToolResult:
        """Execute a tool, retrying on failure.

        Args:
            tool_name: Name of the tool to execute.
            arguments: Tool arguments.
            retries: Total number of attempts (not additional retries). Values
                below 1 are treated as 1 so the tool is always called at least
                once instead of silently returning None.
            delay: Delay between attempts in seconds.

        Returns:
            Tool execution result.

        Raises:
            RuntimeError: If the server is not initialized.
            Exception: The last error, if every attempt fails.
        """
        if not self.session:
            raise RuntimeError(f"Server {self.name} not initialized")

        attempts = max(1, retries)
        for attempt in range(1, attempts + 1):
            try:
                logger.info("Executing %s...", tool_name)
                return await self.session.call_tool(tool_name, arguments)
            except Exception as e:
                logger.warning(
                    "Error executing tool: %s. Attempt %d of %d.",
                    e,
                    attempt,
                    attempts,
                )
                if attempt == attempts:
                    logger.error("Max retries reached. Failing.")
                    raise
                logger.info("Retrying in %s seconds...", delay)
                await asyncio.sleep(delay)

    async def cleanup(self) -> None:
        """Close the session and transport.

        Errors raised while closing are logged, not propagated. The session
        and transport references are cleared even when closing fails, so the
        client never keeps pointing at a half-closed session.
        """
        async with self._cleanup_lock:
            try:
                await self.exit_stack.aclose()
            except Exception as e:
                logger.error("Error during cleanup of server %s: %s", self.name, e)
            finally:
                self.session = None
                self.stdio_context = None