Skip to content

Commit 64b3fef

Browse files
authored
Stateless Remote MCP support (#76)
The current implementation of the Python SDK is stateful. There is work underway to fix this: modelcontextprotocol/python-sdk#443. However, currently we need to eject from the MCP framework to use stateless requests. This is to avoid issues described in this issue: modelcontextprotocol/python-sdk#520, for example. I tested this in Claude by listing and making a request to a remote tool.
1 parent 077ed20 commit 64b3fef

File tree

2 files changed

+69
-44
lines changed

2 files changed

+69
-44
lines changed

src/dbt_mcp/config/config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Config:
1818
remote_enabled: bool
1919
dbt_command: str
2020
multicell_account_prefix: str | None
21-
remote_mcp_url: str
21+
remote_mcp_base_url: str
2222

2323

2424
def load_config() -> Config:
@@ -100,8 +100,8 @@ def load_config() -> Config:
100100
remote_enabled=not disable_remote,
101101
dbt_command=dbt_path,
102102
multicell_account_prefix=multicell_account_prefix,
103-
remote_mcp_url=(
103+
remote_mcp_base_url=(
104104
"http://" if host and host.startswith("localhost") else "https://"
105105
)
106-
+ f"{host}/mcp/sse",
106+
+ f"{host}/mcp",
107107
)

src/dbt_mcp/remote/tools.py

Lines changed: 66 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,27 @@
11
import logging
2-
from collections.abc import AsyncGenerator
3-
from contextlib import asynccontextmanager
42
from typing import (
53
Annotated,
64
Any,
75
)
86

9-
from mcp.client.session import ClientSession
10-
from mcp.client.sse import sse_client
7+
from httpx import Client
8+
from mcp import CallToolRequest, JSONRPCResponse, ListToolsResult
119
from mcp.server.fastmcp import FastMCP
1210
from mcp.server.fastmcp.tools.base import Tool
1311
from mcp.server.fastmcp.utilities.func_metadata import (
1412
ArgModelBase,
1513
FuncMetadata,
1614
_get_typed_annotation,
1715
)
18-
from mcp.types import EmbeddedResource, ImageContent, TextContent
16+
from mcp.types import (
17+
CallToolRequestParams,
18+
CallToolResult,
19+
EmbeddedResource,
20+
ImageContent,
21+
TextContent,
22+
)
1923
from mcp.types import Tool as RemoteTool
20-
from pydantic import Field, WithJsonSchema, create_model
24+
from pydantic import Field, ValidationError, WithJsonSchema, create_model
2125
from pydantic.fields import FieldInfo
2226
from pydantic_core import PydanticUndefined
2327

@@ -26,21 +30,6 @@
2630
logger = logging.getLogger(__name__)
2731

2832

29-
@asynccontextmanager
30-
async def sse_mcp_connection_context(
31-
url: str,
32-
headers: dict[str, Any] | None = None,
33-
timeout: float = 5,
34-
sse_read_timeout: float = 60 * 5,
35-
) -> AsyncGenerator[ClientSession, None]:
36-
async with (
37-
sse_client(url, headers, timeout, sse_read_timeout) as (read, write),
38-
ClientSession(read, write) as session,
39-
):
40-
await session.initialize()
41-
yield session
42-
43-
4433
# Based on this: https://github.com/modelcontextprotocol/python-sdk/blob/9ae4df85fbab97bf476ddd160b766ca4c208cd13/src/mcp/server/fastmcp/utilities/func_metadata.py#L105
4534
def get_remote_tool_fn_metadata(tool: RemoteTool) -> FuncMetadata:
4635
dynamic_pydantic_model_params: dict[str, Any] = {}
@@ -68,20 +57,15 @@ def get_remote_tool_fn_metadata(tool: RemoteTool) -> FuncMetadata:
6857
)
6958

7059

71-
async def list_remote_tools(
72-
url: str,
73-
headers: dict[str, Any],
74-
) -> list[RemoteTool]:
75-
result: list[RemoteTool] = []
60+
def _get_remote_tools(config: Config, headers: dict[str, str]) -> list[RemoteTool]:
7661
try:
77-
async with sse_mcp_connection_context(url, headers) as session:
78-
result = (await session.list_tools()).tools
62+
with Client(base_url=config.remote_mcp_base_url, headers=headers) as client:
63+
list_tools_response = JSONRPCResponse.model_validate_json(
64+
client.get("/tools/list").text
65+
)
66+
return ListToolsResult.model_validate(list_tools_response.result).tools
7967
except Exception:
80-
# TODO: uncomment this when remote tools are available
81-
# and this is actually an error.
82-
# logger.error(f"Connection error while listing remote tools: {e}")
83-
pass
84-
return result
68+
return []
8569

8670

8771
async def register_remote_tools(dbt_mcp: FastMCP, config: Config) -> None:
@@ -93,19 +77,60 @@ async def register_remote_tools(dbt_mcp: FastMCP, config: Config) -> None:
9377
"x-dbt-dev-environment-id": str(config.dev_environment_id),
9478
"x-dbt-user-id": str(config.user_id),
9579
}
96-
remote_tools = await list_remote_tools(config.remote_mcp_url, headers)
97-
for tool in remote_tools:
80+
for tool in _get_remote_tools(config=config, headers=headers):
9881
# Create a new function using a factory to avoid closure issues
9982
def create_tool_function(tool_name: str):
10083
async def tool_function(
10184
*args, **kwargs
10285
) -> list[TextContent | ImageContent | EmbeddedResource]:
103-
async with sse_mcp_connection_context(
104-
config.remote_mcp_url, headers
105-
) as session:
106-
return (
107-
await session.call_tool(name=tool_name, arguments=kwargs)
108-
).content
86+
with Client(
87+
base_url=config.remote_mcp_base_url, headers=headers
88+
) as client:
89+
tool_call_http_response = client.post(
90+
"/tools/call",
91+
json=CallToolRequest(
92+
method="tools/call",
93+
params=CallToolRequestParams(
94+
name=tool_name,
95+
arguments=kwargs,
96+
),
97+
).model_dump(),
98+
)
99+
if tool_call_http_response.status_code != 200:
100+
return [
101+
TextContent(
102+
type="text",
103+
text=f"Failed to call tool {tool_name} with "
104+
+ f"status code: {tool_call_http_response.status_code} "
105+
+ f"error message: {tool_call_http_response.text}",
106+
)
107+
]
108+
try:
109+
tool_call_jsonrpc_response = (
110+
JSONRPCResponse.model_validate_json(
111+
tool_call_http_response.text
112+
)
113+
)
114+
tool_call_result = CallToolResult.model_validate(
115+
tool_call_jsonrpc_response.result
116+
)
117+
except ValidationError as e:
118+
return [
119+
TextContent(
120+
type="text",
121+
text=f"Failed to parse tool response for {tool_name}: "
122+
+ f"{e}",
123+
)
124+
]
125+
if tool_call_result.isError:
126+
return [
127+
TextContent(
128+
type="text",
129+
text=f"Tool {tool_name} reported an error: "
130+
+ f"{tool_call_result.content}",
131+
)
132+
]
133+
return tool_call_result.content
109134

110135
return tool_function
111136

0 commit comments

Comments
 (0)