1
1
import logging
2
- from collections .abc import AsyncGenerator
3
- from contextlib import asynccontextmanager
4
2
from typing import (
5
3
Annotated ,
6
4
Any ,
7
5
)
8
6
9
- from mcp . client . session import ClientSession
10
- from mcp . client . sse import sse_client
7
+ from httpx import Client
8
+ from mcp import CallToolRequest , JSONRPCResponse , ListToolsResult
11
9
from mcp .server .fastmcp import FastMCP
12
10
from mcp .server .fastmcp .tools .base import Tool
13
11
from mcp .server .fastmcp .utilities .func_metadata import (
14
12
ArgModelBase ,
15
13
FuncMetadata ,
16
14
_get_typed_annotation ,
17
15
)
18
- from mcp .types import EmbeddedResource , ImageContent , TextContent
16
+ from mcp .types import (
17
+ CallToolRequestParams ,
18
+ CallToolResult ,
19
+ EmbeddedResource ,
20
+ ImageContent ,
21
+ TextContent ,
22
+ )
19
23
from mcp .types import Tool as RemoteTool
20
- from pydantic import Field , WithJsonSchema , create_model
24
+ from pydantic import Field , ValidationError , WithJsonSchema , create_model
21
25
from pydantic .fields import FieldInfo
22
26
from pydantic_core import PydanticUndefined
23
27
26
30
logger = logging .getLogger (__name__ )
27
31
28
32
29
- @asynccontextmanager
30
- async def sse_mcp_connection_context (
31
- url : str ,
32
- headers : dict [str , Any ] | None = None ,
33
- timeout : float = 5 ,
34
- sse_read_timeout : float = 60 * 5 ,
35
- ) -> AsyncGenerator [ClientSession , None ]:
36
- async with (
37
- sse_client (url , headers , timeout , sse_read_timeout ) as (read , write ),
38
- ClientSession (read , write ) as session ,
39
- ):
40
- await session .initialize ()
41
- yield session
42
-
43
-
44
33
# Based on this: https://github.com/modelcontextprotocol/python-sdk/blob/9ae4df85fbab97bf476ddd160b766ca4c208cd13/src/mcp/server/fastmcp/utilities/func_metadata.py#L105
45
34
def get_remote_tool_fn_metadata (tool : RemoteTool ) -> FuncMetadata :
46
35
dynamic_pydantic_model_params : dict [str , Any ] = {}
@@ -68,20 +57,15 @@ def get_remote_tool_fn_metadata(tool: RemoteTool) -> FuncMetadata:
68
57
)
69
58
70
59
71
- async def list_remote_tools (
72
- url : str ,
73
- headers : dict [str , Any ],
74
- ) -> list [RemoteTool ]:
75
- result : list [RemoteTool ] = []
60
+ def _get_remote_tools (config : Config , headers : dict [str , str ]) -> list [RemoteTool ]:
76
61
try :
77
- async with sse_mcp_connection_context (url , headers ) as session :
78
- result = (await session .list_tools ()).tools
62
+ with Client (base_url = config .remote_mcp_base_url , headers = headers ) as client :
63
+ list_tools_response = JSONRPCResponse .model_validate_json (
64
+ client .get ("/tools/list" ).text
65
+ )
66
+ return ListToolsResult .model_validate (list_tools_response .result ).tools
79
67
except Exception :
80
- # TODO: uncomment this when remote tools are available
81
- # and this is actually an error.
82
- # logger.error(f"Connection error while listing remote tools: {e}")
83
- pass
84
- return result
68
+ return []
85
69
86
70
87
71
async def register_remote_tools (dbt_mcp : FastMCP , config : Config ) -> None :
@@ -93,19 +77,60 @@ async def register_remote_tools(dbt_mcp: FastMCP, config: Config) -> None:
93
77
"x-dbt-dev-environment-id" : str (config .dev_environment_id ),
94
78
"x-dbt-user-id" : str (config .user_id ),
95
79
}
96
- remote_tools = await list_remote_tools (config .remote_mcp_url , headers )
97
- for tool in remote_tools :
80
+ for tool in _get_remote_tools (config = config , headers = headers ):
98
81
# Create a new function using a factory to avoid closure issues
99
82
def create_tool_function (tool_name : str ):
100
83
async def tool_function (
101
84
* args , ** kwargs
102
85
) -> list [TextContent | ImageContent | EmbeddedResource ]:
103
- async with sse_mcp_connection_context (
104
- config .remote_mcp_url , headers
105
- ) as session :
106
- return (
107
- await session .call_tool (name = tool_name , arguments = kwargs )
108
- ).content
86
+ with Client (
87
+ base_url = config .remote_mcp_base_url , headers = headers
88
+ ) as client :
89
+ tool_call_http_response = client .post (
90
+ "/tools/call" ,
91
+ json = CallToolRequest (
92
+ method = "tools/call" ,
93
+ params = CallToolRequestParams (
94
+ name = tool_name ,
95
+ arguments = kwargs ,
96
+ ),
97
+ ).model_dump (),
98
+ )
99
+ if tool_call_http_response .status_code != 200 :
100
+ return [
101
+ TextContent (
102
+ type = "text" ,
103
+ text = f"Failed to call tool { tool_name } with "
104
+ + f"status code: { tool_call_http_response .status_code } "
105
+ + f"error message: { tool_call_http_response .text } " ,
106
+ )
107
+ ]
108
+ try :
109
+ tool_call_jsonrpc_response = (
110
+ JSONRPCResponse .model_validate_json (
111
+ tool_call_http_response .text
112
+ )
113
+ )
114
+ tool_call_result = CallToolResult .model_validate (
115
+ tool_call_jsonrpc_response .result
116
+ )
117
+ except ValidationError as e :
118
+ return [
119
+ TextContent (
120
+ type = "text" ,
121
+ text = f"Failed to parse tool response for { tool_name } : "
122
+ + f"{ e } " ,
123
+ )
124
+ ]
125
+ if tool_call_result .isError :
126
+ return [
127
+ TextContent (
128
+ type = "text" ,
129
+ text = f"Tool { tool_name } reported an error: "
130
+ + f"{ tool_call_result .content } " ,
131
+ )
132
+ ]
133
+ return tool_call_result .content
109
134
110
135
return tool_function
111
136
0 commit comments