Commit 466e1e8

refactor: modernize type hints and improve async context handling
- Update type hints to use Python 3.10 syntax (dict, list, X | None)
- Replace requests with httpx for HTTP client consistency
- Improve async context management using AsyncExitStack
- Simplify server cleanup method
1 parent a0216c3 commit 466e1e8

File tree

1 file changed: +30 -48 lines

  • examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 30 additions & 48 deletions
@@ -3,9 +3,10 @@
 import logging
 import os
 import shutil
-from typing import Any, Dict, List, Optional
+from contextlib import AsyncExitStack
+from typing import Any
 
-import requests
+import httpx
 from dotenv import load_dotenv
 from mcp import ClientSession, StdioServerParameters
 from mcp.client.stdio import stdio_client
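The import changes above mirror the rest of the diff: `Dict`, `List`, and `Optional` become unnecessary once annotations use built-in generics (PEP 585, Python 3.9+) and the `|` union operator (PEP 604, Python 3.10+). A minimal before/after sketch with an illustrative function (not from this file):

```python
# Before (typing-module style):
#   from typing import Dict, List, Optional
#   def pick(data: Dict[str, List[int]], key: str) -> Optional[List[int]]: ...

# After (Python 3.10+ style, as used throughout this commit):
def pick(data: dict[str, list[int]], key: str) -> list[int] | None:
    """Return the list stored under key, or None if it is missing."""
    return data.get(key)


assert pick({"a": [1]}, "a") == [1]
assert pick({"a": [1]}, "b") is None
```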
@@ -30,7 +31,7 @@ def load_env() -> None:
         load_dotenv()
 
     @staticmethod
-    def load_config(file_path: str) -> Dict[str, Any]:
+    def load_config(file_path: str) -> dict[str, Any]:
         """Load server configuration from JSON file.
 
         Args:
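Only the signature changes in this hunk; the docstring says the method loads server configuration from a JSON file. A hedged sketch of what such a loader typically looks like (the real body is outside this diff and may differ):

```python
import json
from typing import Any


def load_config(file_path: str) -> dict[str, Any]:
    """Illustrative sketch: parse a JSON config file into a dict."""
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


# Usage sketch (assuming a config file exists at this path):
# config = load_config("servers_config.json")
```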
@@ -64,12 +65,13 @@ def llm_api_key(self) -> str:
 class Server:
     """Manages MCP server connections and tool execution."""
 
-    def __init__(self, name: str, config: Dict[str, Any]) -> None:
+    def __init__(self, name: str, config: dict[str, Any]) -> None:
         self.name: str = name
-        self.config: Dict[str, Any] = config
-        self.stdio_context: Optional[Any] = None
-        self.session: Optional[ClientSession] = None
+        self.config: dict[str, Any] = config
+        self.stdio_context: Any | None = None
+        self.session: ClientSession | None = None
         self._cleanup_lock: asyncio.Lock = asyncio.Lock()
+        self.exit_stack: AsyncExitStack = AsyncExitStack()
 
     async def initialize(self) -> None:
         """Initialize the server connection."""
@@ -89,17 +91,16 @@ async def initialize(self) -> None:
             else None,
         )
         try:
-            self.stdio_context = stdio_client(server_params)
-            read, write = await self.stdio_context.__aenter__()
-            self.session = ClientSession(read, write)
-            await self.session.__aenter__()
+            stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
+            read, write = stdio_transport
+            self.session = await self.exit_stack.enter_async_context(ClientSession(read, write))
             await self.session.initialize()
         except Exception as e:
             logging.error(f"Error initializing server {self.name}: {e}")
             await self.cleanup()
             raise
 
-    async def list_tools(self) -> List[Any]:
+    async def list_tools(self) -> list[Any]:
         """List available tools from the server.
 
         Returns:
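`enter_async_context()` awaits the context manager's `__aenter__`, returns its result, and registers the matching `__aexit__` to run when the stack is closed, which is exactly what the removed manual calls did by hand. A self-contained sketch of the two styles this hunk swaps between (the `resource` manager is invented for illustration):

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager


@asynccontextmanager
async def resource(name: str):
    print(f"enter {name}")
    try:
        yield name
    finally:
        print(f"exit {name}")


async def manual_style() -> None:
    # Old pattern: call __aenter__ yourself and remember to call __aexit__ later.
    ctx = resource("transport")
    value = await ctx.__aenter__()
    try:
        print(f"using {value}")
    finally:
        await ctx.__aexit__(None, None, None)


async def exit_stack_style() -> None:
    # New pattern: the stack records the cleanup and runs it on close.
    async with AsyncExitStack() as stack:
        value = await stack.enter_async_context(resource("transport"))
        print(f"using {value}")


asyncio.run(manual_style())
asyncio.run(exit_stack_style())
```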
@@ -124,7 +125,7 @@ async def list_tools(self) -> List[Any]:
     async def execute_tool(
         self,
         tool_name: str,
-        arguments: Dict[str, Any],
+        arguments: dict[str, Any],
         retries: int = 2,
         delay: float = 1.0,
     ) -> Any:
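Only the signature appears in this hunk, but the `retries` and `delay` parameters point at a retry-with-delay loop around the tool call. A hedged sketch of such a loop (not the actual body, which is outside this diff):

```python
import asyncio
import logging
from typing import Any, Awaitable, Callable


async def call_with_retries(
    call: Callable[[], Awaitable[Any]],
    retries: int = 2,
    delay: float = 1.0,
) -> Any:
    """Illustrative retry helper: `retries` extra attempts, `delay` seconds apart."""
    attempt = 0
    while True:
        try:
            return await call()
        except Exception as e:
            attempt += 1
            if attempt > retries:
                raise
            logging.warning(f"Attempt {attempt} failed ({e}); retrying in {delay}s")
            await asyncio.sleep(delay)
```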
@@ -170,29 +171,9 @@ async def cleanup(self) -> None:
         """Clean up server resources."""
         async with self._cleanup_lock:
             try:
-                if self.session:
-                    try:
-                        await self.session.__aexit__(None, None, None)
-                    except Exception as e:
-                        logging.warning(
-                            f"Warning during session cleanup for {self.name}: {e}"
-                        )
-                    finally:
-                        self.session = None
-
-                if self.stdio_context:
-                    try:
-                        await self.stdio_context.__aexit__(None, None, None)
-                    except (RuntimeError, asyncio.CancelledError) as e:
-                        logging.info(
-                            f"Note: Normal shutdown message for {self.name}: {e}"
-                        )
-                    except Exception as e:
-                        logging.warning(
-                            f"Warning during stdio cleanup for {self.name}: {e}"
-                        )
-                    finally:
-                        self.stdio_context = None
+                await self.exit_stack.aclose()
+                self.session = None
+                self.stdio_context = None
             except Exception as e:
                 logging.error(f"Error during cleanup of server {self.name}: {e}")
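The nested blocks of manual `__aexit__` calls collapse into a single `aclose()` because `AsyncExitStack` unwinds everything it entered, in reverse order, and keeps unwinding even if one cleanup raises. A small runnable sketch of that behaviour, using `push_async_callback` to stand in for the cleanups the stack registers (the resource names are invented):

```python
import asyncio
from contextlib import AsyncExitStack


async def close(name: str) -> None:
    print(f"closed {name}")


async def main() -> None:
    stack = AsyncExitStack()
    # Register cleanups, much as enter_async_context() registers __aexit__.
    stack.push_async_callback(close, "stdio transport")
    stack.push_async_callback(close, "client session")
    # One aclose() replaces the removed try/except/finally blocks;
    # callbacks run in reverse (LIFO) order, mirroring nested `async with`.
    await stack.aclose()
    # Output:
    #   closed client session
    #   closed stdio transport


asyncio.run(main())
```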

@@ -201,11 +182,11 @@ class Tool:
     """Represents a tool with its properties and formatting."""
 
     def __init__(
-        self, name: str, description: str, input_schema: Dict[str, Any]
+        self, name: str, description: str, input_schema: dict[str, Any]
     ) -> None:
         self.name: str = name
         self.description: str = description
-        self.input_schema: Dict[str, Any] = input_schema
+        self.input_schema: dict[str, Any] = input_schema
 
     def format_for_llm(self) -> str:
         """Format tool information for LLM.
@@ -237,7 +218,7 @@ class LLMClient:
     def __init__(self, api_key: str) -> None:
         self.api_key: str = api_key
 
-    def get_response(self, messages: List[Dict[str, str]]) -> str:
+    def get_response(self, messages: list[dict[str, str]]) -> str:
        """Get a response from the LLM.
 
         Args:
@@ -247,7 +228,7 @@ def get_response(self, messages: List[Dict[str, str]]) -> str:
             The LLM's response as a string.
 
         Raises:
-            RequestException: If the request to the LLM fails.
+            httpx.RequestError: If the request to the LLM fails.
         """
         url = "https://api.groq.com/openai/v1/chat/completions"
 
@@ -266,16 +247,17 @@ def get_response(self, messages: List[Dict[str, str]]) -> str:
         }
 
         try:
-            response = requests.post(url, headers=headers, json=payload)
-            response.raise_for_status()
-            data = response.json()
-            return data["choices"][0]["message"]["content"]
+            with httpx.Client() as client:
+                response = client.post(url, headers=headers, json=payload)
+                response.raise_for_status()
+                data = response.json()
+                return data["choices"][0]["message"]["content"]
 
-        except requests.exceptions.RequestException as e:
+        except httpx.RequestError as e:
             error_message = f"Error getting LLM response: {str(e)}"
             logging.error(error_message)
 
-            if e.response is not None:
+            if hasattr(e, 'response'):
                 status_code = e.response.status_code
                 logging.error(f"Status code: {status_code}")
                 logging.error(f"Response details: {e.response.text}")
@@ -289,8 +271,8 @@ def get_response(self, messages: List[Dict[str, str]]) -> str:
 class ChatSession:
     """Orchestrates the interaction between user, LLM, and tools."""
 
-    def __init__(self, servers: List[Server], llm_client: LLMClient) -> None:
-        self.servers: List[Server] = servers
+    def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
+        self.servers: list[Server] = servers
         self.llm_client: LLMClient = llm_client
 
     async def cleanup_servers(self) -> None:
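The hunk ends at the `cleanup_servers()` signature; its body is outside this diff. Given the `Server.cleanup()` shown above, a plausible hedged sketch is a loop that closes each server and logs failures instead of propagating them:

```python
import logging
from typing import Any


async def cleanup_servers(servers: list[Any]) -> None:
    """Illustrative sketch: close every server, tolerating individual failures."""
    for server in reversed(servers):  # tear down in reverse start-up order
        try:
            await server.cleanup()
        except Exception as e:
            logging.warning(f"Warning during final cleanup: {e}")
```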
