Skip to content

Commit 8090c58

Browse files
committed
Close all resources
1 parent 8d1c0c5 commit 8090c58

File tree

8 files changed

+61
-14
lines changed

8 files changed

+61
-14
lines changed

pyproject.toml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ packages = ["src/mcp"]
6767
include = ["src/mcp", "tests"]
6868
venvPath = "."
6969
venv = ".venv"
70+
strict = [
71+
"src/mcp/server/fastmcp/tools/base.py",
72+
]
7073

7174
[tool.ruff.lint]
7275
select = ["E", "F", "I"]
@@ -85,3 +88,10 @@ members = ["examples/servers/*"]
8588

8689
[tool.uv.sources]
8790
mcp = { workspace = true }
91+
92+
# TODO(Marcelo): This should be enabled!!! There are a lot of resource warnings.
93+
[tool.pytest.ini_options]
94+
xfail_strict = true
95+
filterwarnings = [
96+
"error",
97+
]

src/mcp/client/session.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,9 @@ async def _default_list_roots_callback(
4343
)
4444

4545

46-
ClientResponse = TypeAdapter(types.ClientResult | types.ErrorData)
46+
ClientResponse: TypeAdapter[types.ClientResult | types.ErrorData] = TypeAdapter(
47+
types.ClientResult | types.ErrorData
48+
)
4749

4850

4951
class ClientSession(
@@ -219,7 +221,7 @@ async def unsubscribe_resource(self, uri: AnyUrl) -> types.EmptyResult:
219221
)
220222

221223
async def call_tool(
222-
self, name: str, arguments: dict | None = None
224+
self, name: str, arguments: dict[str, Any] | None = None
223225
) -> types.CallToolResult:
224226
"""Send a tools/call request."""
225227
return await self.send_request(
@@ -258,7 +260,9 @@ async def get_prompt(
258260
)
259261

260262
async def complete(
261-
self, ref: types.ResourceReference | types.PromptReference, argument: dict
263+
self,
264+
ref: types.ResourceReference | types.PromptReference,
265+
argument: dict[str, str],
262266
) -> types.CompleteResult:
263267
"""Send a completion/complete request."""
264268
return await self.send_request(

src/mcp/server/fastmcp/tools/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
class Tool(BaseModel):
1919
"""Internal tool registration info."""
2020

21-
fn: Callable = Field(exclude=True)
21+
fn: Callable[..., Any] = Field(exclude=True)
2222
name: str = Field(description="Name of the tool")
2323
description: str = Field(description="Description of what the tool does")
24-
parameters: dict = Field(description="JSON schema for tool parameters")
24+
parameters: dict[str, Any] = Field(description="JSON schema for tool parameters")
2525
fn_metadata: FuncMetadata = Field(
2626
description="Metadata about the function including a pydantic model for tool"
2727
" arguments"
@@ -34,7 +34,7 @@ class Tool(BaseModel):
3434
@classmethod
3535
def from_function(
3636
cls,
37-
fn: Callable,
37+
fn: Callable[..., Any],
3838
name: str | None = None,
3939
description: str | None = None,
4040
context_kwarg: str | None = None,

src/mcp/server/fastmcp/utilities/func_metadata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,9 @@ def pre_parse_json(self, data: dict[str, Any]) -> dict[str, Any]:
102102
)
103103

104104

105-
def func_metadata(func: Callable, skip_names: Sequence[str] = ()) -> FuncMetadata:
105+
def func_metadata(
106+
func: Callable[..., Any], skip_names: Sequence[str] = ()
107+
) -> FuncMetadata:
106108
"""Given a function, return metadata including a pydantic model representing its
107109
signature.
108110

src/mcp/shared/session.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from contextlib import AsyncExitStack
23
from datetime import timedelta
34
from typing import Any, Callable, Generic, TypeVar
45

@@ -180,13 +181,20 @@ def __init__(
180181
self._read_timeout_seconds = read_timeout_seconds
181182
self._in_flight = {}
182183

184+
self._exit_stack = AsyncExitStack()
183185
self._incoming_message_stream_writer, self._incoming_message_stream_reader = (
184186
anyio.create_memory_object_stream[
185187
RequestResponder[ReceiveRequestT, SendResultT]
186188
| ReceiveNotificationT
187189
| Exception
188190
]()
189191
)
192+
self._exit_stack.push_async_callback(
193+
lambda: self._incoming_message_stream_reader.aclose()
194+
)
195+
self._exit_stack.push_async_callback(
196+
lambda: self._incoming_message_stream_writer.aclose()
197+
)
190198

191199
async def __aenter__(self) -> Self:
192200
self._task_group = anyio.create_task_group()
@@ -195,6 +203,7 @@ async def __aenter__(self) -> Self:
195203
return self
196204

197205
async def __aexit__(self, exc_type, exc_val, exc_tb):
206+
await self._exit_stack.aclose()
198207
# Using BaseSession as a context manager should not block on exit (this
199208
# would be very surprising behavior), so make sure to cancel the tasks
200209
# in the task group.
@@ -222,6 +231,9 @@ async def send_request(
222231
](1)
223232
self._response_streams[request_id] = response_stream
224233

234+
self._exit_stack.push_async_callback(lambda: response_stream.aclose())
235+
self._exit_stack.push_async_callback(lambda: response_stream_reader.aclose())
236+
225237
jsonrpc_request = JSONRPCRequest(
226238
jsonrpc="2.0",
227239
id=request_id,
@@ -255,9 +267,6 @@ async def send_request(
255267
raise McpError(response_or_error.error)
256268
else:
257269
return result_type.model_validate(response_or_error.result)
258-
finally:
259-
await response_stream.aclose()
260-
await response_stream_reader.aclose()
261270

262271
async def send_notification(self, notification: SendNotificationT) -> None:
263272
"""

tests/client/test_session.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,10 @@ async def listen_session():
8383
async with (
8484
ClientSession(server_to_client_receive, client_to_server_send) as session,
8585
anyio.create_task_group() as tg,
86+
client_to_server_send,
87+
client_to_server_receive,
88+
server_to_client_send,
89+
server_to_client_receive,
8690
):
8791
tg.start_soon(mock_server)
8892
tg.start_soon(listen_session)

tests/issues/test_192_request_id.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,13 @@ async def run_server():
4343
)
4444

4545
# Start server task
46-
async with anyio.create_task_group() as tg:
46+
async with (
47+
anyio.create_task_group() as tg,
48+
client_writer,
49+
client_reader,
50+
server_writer,
51+
server_reader,
52+
):
4753
tg.start_soon(run_server)
4854

4955
# Send initialize request

tests/server/test_lifespan.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ async def test_lowlevel_server_lifespan():
2525
"""Test that lifespan works in low-level server."""
2626

2727
@asynccontextmanager
28-
async def test_lifespan(server: Server) -> AsyncIterator[dict]:
28+
async def test_lifespan(server: Server) -> AsyncIterator[dict[str, bool]]:
2929
"""Test lifespan context that tracks startup/shutdown."""
3030
context = {"started": False, "shutdown": False}
3131
try:
@@ -50,7 +50,13 @@ async def check_lifespan(name: str, arguments: dict) -> list:
5050
return [{"type": "text", "text": "true"}]
5151

5252
# Run server in background task
53-
async with anyio.create_task_group() as tg:
53+
async with (
54+
anyio.create_task_group() as tg,
55+
send_stream1,
56+
receive_stream1,
57+
send_stream2,
58+
receive_stream2,
59+
):
5460

5561
async def run_server():
5662
await server.run(
@@ -147,7 +153,13 @@ def check_lifespan(ctx: Context) -> bool:
147153
return True
148154

149155
# Run server in background task
150-
async with anyio.create_task_group() as tg:
156+
async with (
157+
anyio.create_task_group() as tg,
158+
send_stream1,
159+
receive_stream1,
160+
send_stream2,
161+
receive_stream2,
162+
):
151163

152164
async def run_server():
153165
await server._mcp_server.run(

0 commit comments

Comments
 (0)