 import logging
 import os
 import shutil
-from typing import Any, Dict, List, Optional
+from contextlib import AsyncExitStack
+from typing import Any
 
-import requests
+import httpx
 
 from dotenv import load_dotenv
 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
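A side note on the import changes (not part of the commit itself): the rewritten annotations rely on PEP 585 built-in generics (dict[...], list[...], Python 3.9+) and PEP 604 unions (X | None, Python 3.10+), which is why the typing aliases are no longer needed. A minimal before/after sketch, using hypothetical names:

    from typing import Any, Dict, List, Optional  # old-style aliases

    def load_servers_old(path: str) -> Dict[str, Any]: ...
    names_old: Optional[List[str]] = None

    # New style, assuming Python 3.10+: built-in generics plus | unions.
    def load_servers_new(path: str) -> dict[str, Any]: ...
    names_new: list[str] | None = None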
@@ -30,7 +31,7 @@ def load_env() -> None:
         load_dotenv()
 
     @staticmethod
-    def load_config(file_path: str) -> Dict[str, Any]:
+    def load_config(file_path: str) -> dict[str, Any]:
         """Load server configuration from JSON file.
 
         Args:
@@ -64,12 +65,13 @@ def llm_api_key(self) -> str:
 class Server:
     """Manages MCP server connections and tool execution."""
 
-    def __init__(self, name: str, config: Dict[str, Any]) -> None:
+    def __init__(self, name: str, config: dict[str, Any]) -> None:
         self.name: str = name
-        self.config: Dict[str, Any] = config
-        self.stdio_context: Optional[Any] = None
-        self.session: Optional[ClientSession] = None
+        self.config: dict[str, Any] = config
+        self.stdio_context: Any | None = None
+        self.session: ClientSession | None = None
         self._cleanup_lock: asyncio.Lock = asyncio.Lock()
+        self.exit_stack: AsyncExitStack = AsyncExitStack()
 
     async def initialize(self) -> None:
         """Initialize the server connection."""
@@ -89,17 +91,16 @@ async def initialize(self) -> None:
             else None,
         )
         try:
-            self.stdio_context = stdio_client(server_params)
-            read, write = await self.stdio_context.__aenter__()
-            self.session = ClientSession(read, write)
-            await self.session.__aenter__()
+            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+            read, write = stdio_transport
+            self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
             await self.session.initialize()
         except Exception as e:
             logging.error(f"Error initializing server {self.name}: {e}")
             await self.cleanup()
             raise
 
-    async def list_tools(self) -> List[Any]:
+    async def list_tools(self) -> list[Any]:
         """List available tools from the server.
 
         Returns:
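For context (not part of the diff), the pattern introduced here keeps every async context manager on a single AsyncExitStack so they can all be unwound later with one call. A minimal, self-contained sketch of the same idea, with a hypothetical open_resource() standing in for stdio_client()/ClientSession():

    import asyncio
    from contextlib import AsyncExitStack, asynccontextmanager

    @asynccontextmanager
    async def open_resource(name: str):
        # Hypothetical stand-in for the MCP transport/session context managers.
        print(f"open {name}")
        try:
            yield name
        finally:
            print(f"close {name}")

    async def main() -> None:
        stack = AsyncExitStack()
        try:
            transport = await stack.enter_async_context(open_resource("transport"))
            session = await stack.enter_async_context(open_resource("session"))
            print("using", transport, session)
        finally:
            # One call exits every entered context manager in reverse order,
            # replacing the manual __aenter__/__aexit__ bookkeeping.
            await stack.aclose()

    asyncio.run(main())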
@@ -124,7 +125,7 @@ async def list_tools(self) -> List[Any]:
     async def execute_tool(
         self,
         tool_name: str,
-        arguments: Dict[str, Any],
+        arguments: dict[str, Any],
         retries: int = 2,
         delay: float = 1.0,
     ) -> Any:
@@ -170,29 +171,9 @@ async def cleanup(self) -> None:
         """Clean up server resources."""
         async with self._cleanup_lock:
             try:
-                if self.session:
-                    try:
-                        await self.session.__aexit__(None, None, None)
-                    except Exception as e:
-                        logging.warning(
-                            f"Warning during session cleanup for {self.name}: {e}"
-                        )
-                    finally:
-                        self.session = None
-
-                if self.stdio_context:
-                    try:
-                        await self.stdio_context.__aexit__(None, None, None)
-                    except (RuntimeError, asyncio.CancelledError) as e:
-                        logging.info(
-                            f"Note: Normal shutdown message for {self.name}: {e}"
-                        )
-                    except Exception as e:
-                        logging.warning(
-                            f"Warning during stdio cleanup for {self.name}: {e}"
-                        )
-                    finally:
-                        self.stdio_context = None
+                await self.exit_stack.aclose()
+                self.session = None
+                self.stdio_context = None
 
             except Exception as e:
                 logging.error(f"Error during cleanup of server {self.name}: {e}")
@@ -201,11 +182,11 @@ class Tool:
     """Represents a tool with its properties and formatting."""
 
     def __init__(
-        self, name: str, description: str, input_schema: Dict[str, Any]
+        self, name: str, description: str, input_schema: dict[str, Any]
     ) -> None:
         self.name: str = name
         self.description: str = description
-        self.input_schema: Dict[str, Any] = input_schema
+        self.input_schema: dict[str, Any] = input_schema
 
     def format_for_llm(self) -> str:
         """Format tool information for LLM.
@@ -237,7 +218,7 @@ class LLMClient:
     def __init__(self, api_key: str) -> None:
         self.api_key: str = api_key
 
-    def get_response(self, messages: List[Dict[str, str]]) -> str:
+    def get_response(self, messages: list[dict[str, str]]) -> str:
         """Get a response from the LLM.
 
         Args:
@@ -247,7 +228,7 @@ def get_response(self, messages: List[Dict[str, str]]) -> str:
             The LLM's response as a string.
 
         Raises:
-            RequestException: If the request to the LLM fails.
+            httpx.RequestError: If the request to the LLM fails.
         """
         url = "https://api.groq.com/openai/v1/chat/completions"
 
@@ -266,16 +247,17 @@ def get_response(self, messages: List[Dict[str, str]]) -> str:
         }
 
         try:
-            response = requests.post(url, headers=headers, json=payload)
-            response.raise_for_status()
-            data = response.json()
-            return data["choices"][0]["message"]["content"]
+            with httpx.Client() as client:
+                response = client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
+                return data["choices"][0]["message"]["content"]
 
-        except requests.exceptions.RequestException as e:
+        except httpx.RequestError as e:
             error_message = f"Error getting LLM response: {str(e)}"
             logging.error(error_message)
 
-            if e.response is not None:
+            if hasattr(e, 'response'):
                 status_code = e.response.status_code
                 logging.error(f"Status code: {status_code}")
                 logging.error(f"Response details: {e.response.text}")
@@ -289,8 +271,8 @@ def get_response(self, messages: List[Dict[str, str]]) -> str:
 class ChatSession:
     """Orchestrates the interaction between user, LLM, and tools."""
 
-    def __init__(self, servers: List[Server], llm_client: LLMClient) -> None:
-        self.servers: List[Server] = servers
+    def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
+        self.servers: list[Server] = servers
         self.llm_client: LLMClient = llm_client
 
     async def cleanup_servers(self) -> None: