From 95ae2758f22c02fde1751be02feb3ebdd1c663b7 Mon Sep 17 00:00:00 2001 From: Junpei Kawamoto Date: Mon, 24 Mar 2025 02:52:59 -0600 Subject: [PATCH 1/2] Fix context detection by checking for subclass relationship Previously, the code only checked for exact matches with `Context`. This update ensures that subclasses of `Context` are also correctly identified, improving flexibility and reliability. --- src/mcp/server/fastmcp/tools/base.py | 6 ++++-- tests/server/fastmcp/test_tool_manager.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/src/mcp/server/fastmcp/tools/base.py b/src/mcp/server/fastmcp/tools/base.py index e137e8456..92a216f56 100644 --- a/src/mcp/server/fastmcp/tools/base.py +++ b/src/mcp/server/fastmcp/tools/base.py @@ -2,7 +2,7 @@ import inspect from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, get_origin from pydantic import BaseModel, Field @@ -53,7 +53,9 @@ def from_function( if context_kwarg is None: sig = inspect.signature(fn) for param_name, param in sig.parameters.items(): - if param.annotation is Context: + if get_origin(param.annotation) is not None: + continue + if issubclass(param.annotation, Context): context_kwarg = param_name break diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index d2067583e..560dbbda9 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -242,6 +242,8 @@ def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" from mcp.server.fastmcp import Context + from mcp.server.session import ServerSessionT + from mcp.shared.context import LifespanContextT def tool_with_context(x: int, ctx: Context) -> str: return str(x) @@ -256,6 +258,14 @@ def tool_without_context(x: int) -> str: tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None + def tool_with_specialized_context( + x: int, ctx: Context[ServerSessionT, LifespanContextT] + ) -> str: + return str(x) + + tool = manager.add_tool(tool_with_specialized_context) + assert tool.context_kwarg == "ctx" + @pytest.mark.anyio async def test_context_injection(self): """Test that context is properly injected during tool execution.""" From 1287ac4c5047a628f9a6b7df4abffeb7b9e67f2c Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 10 Apr 2025 08:34:27 +0200 Subject: [PATCH 2/2] Drop imports --- tests/server/fastmcp/test_tool_manager.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/tests/server/fastmcp/test_tool_manager.py b/tests/server/fastmcp/test_tool_manager.py index 560dbbda9..8f52e3d85 100644 --- a/tests/server/fastmcp/test_tool_manager.py +++ b/tests/server/fastmcp/test_tool_manager.py @@ -4,8 +4,11 @@ import pytest from pydantic import BaseModel +from mcp.server.fastmcp import Context, FastMCP from mcp.server.fastmcp.exceptions import ToolError from mcp.server.fastmcp.tools import ToolManager +from mcp.server.session import ServerSessionT +from mcp.shared.context import LifespanContextT class TestAddTools: @@ -194,8 +197,6 @@ def concat_strs(vals: list[str] | str) -> str: @pytest.mark.anyio async def test_call_tool_with_complex_model(self): - from mcp.server.fastmcp import Context - class MyShrimpTank(BaseModel): class Shrimp(BaseModel): name: str @@ -223,8 +224,6 @@ def name_shrimp(tank: MyShrimpTank, ctx: Context) -> list[str]: class TestToolSchema: @pytest.mark.anyio async def test_context_arg_excluded_from_schema(self): - from mcp.server.fastmcp import Context - def something(a: int, ctx: Context) -> int: return a @@ -241,9 +240,6 @@ class TestContextHandling: def test_context_parameter_detection(self): """Test that context parameters are properly detected in Tool.from_function().""" - from mcp.server.fastmcp import Context - from mcp.server.session import ServerSessionT - from mcp.shared.context import LifespanContextT def tool_with_context(x: int, ctx: Context) -> str: return str(x) @@ -258,18 +254,17 @@ def tool_without_context(x: int) -> str: tool = manager.add_tool(tool_without_context) assert tool.context_kwarg is None - def tool_with_specialized_context( + def tool_with_parametrized_context( x: int, ctx: Context[ServerSessionT, LifespanContextT] ) -> str: return str(x) - tool = manager.add_tool(tool_with_specialized_context) + tool = manager.add_tool(tool_with_parametrized_context) assert tool.context_kwarg == "ctx" @pytest.mark.anyio async def test_context_injection(self): """Test that context is properly injected during tool execution.""" - from mcp.server.fastmcp import Context, FastMCP def tool_with_context(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -286,7 +281,6 @@ def tool_with_context(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_injection_async(self): """Test that context is properly injected in async tools.""" - from mcp.server.fastmcp import Context, FastMCP async def async_tool(x: int, ctx: Context) -> str: assert isinstance(ctx, Context) @@ -303,7 +297,6 @@ async def async_tool(x: int, ctx: Context) -> str: @pytest.mark.anyio async def test_context_optional(self): """Test that context is optional when calling tools.""" - from mcp.server.fastmcp import Context def tool_with_context(x: int, ctx: Context | None = None) -> str: return str(x) @@ -317,7 +310,6 @@ def tool_with_context(x: int, ctx: Context | None = None) -> str: @pytest.mark.anyio async def test_context_error_handling(self): """Test error handling when context injection fails.""" - from mcp.server.fastmcp import Context, FastMCP def tool_with_context(x: int, ctx: Context) -> str: raise ValueError("Test error")