
Commit b70c474

[MCP] Add local/remote endpoint inference support (#3121)

* allow endpoint url in tiny-agents
* nit
* explicitly fail if no model or base_url is provided

Co-authored-by: Lucain <[email protected]>

1 parent 5add979

File tree

3 files changed (+32 −10):

- src/huggingface_hub/inference/_mcp/agent.py
- src/huggingface_hub/inference/_mcp/cli.py
- src/huggingface_hub/inference/_mcp/mcp_client.py

src/huggingface_hub/inference/_mcp/agent.py

Lines changed: 6 additions & 3 deletions

@@ -20,14 +20,16 @@ class Agent(MCPClient):
     </Tip>

     Args:
-        model (`str`):
+        model (`str`, *optional*):
             The model to run inference with. Can be a model id hosted on the Hugging Face Hub, e.g. `meta-llama/Meta-Llama-3-8B-Instruct`
             or a URL to a deployed Inference Endpoint or other local or remote endpoint.
         servers (`Iterable[Dict]`):
             MCP servers to connect to. Each server is a dictionary containing a `type` key and a `config` key. The `type` key can be `"stdio"` or `"sse"`, and the `config` key is a dictionary of arguments for the server.
         provider (`str`, *optional*):
             Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
+        base_url (`str`, *optional*):
+            The base URL to run inference. Defaults to None.
         api_key (`str`, *optional*):
             Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service.
         prompt (`str`, *optional*):

@@ -37,13 +39,14 @@ class Agent(MCPClient):
     def __init__(
         self,
         *,
-        model: str,
+        model: Optional[str] = None,
         servers: Iterable[Dict],
         provider: Optional[PROVIDER_OR_POLICY_T] = None,
+        base_url: Optional[str] = None,
         api_key: Optional[str] = None,
         prompt: Optional[str] = None,
     ):
-        super().__init__(model=model, provider=provider, api_key=api_key)
+        super().__init__(model=model, provider=provider, base_url=base_url, api_key=api_key)
        self._servers_cfg = list(servers)
        self.messages: List[Union[Dict, ChatCompletionInputMessage]] = [
            {"role": "system", "content": prompt or DEFAULT_SYSTEM_PROMPT}

src/huggingface_hub/inference/_mcp/cli.py

Lines changed: 11 additions & 3 deletions

@@ -1,6 +1,7 @@
 import asyncio
 import os
 import signal
+import traceback
 from functools import partial
 from typing import Any, Dict, List, Optional

@@ -71,8 +72,9 @@ def _sigint_handler() -> None:
         # Windows (or any loop that doesn't support it) : fall back to sync
         signal.signal(signal.SIGINT, lambda *_: _sigint_handler())
     async with Agent(
-        provider=config["provider"],
-        model=config["model"],
+        provider=config.get("provider"),
+        model=config.get("model"),
+        base_url=config.get("endpointUrl"),
         servers=servers,
         prompt=prompt,
     ) as agent:

@@ -123,9 +125,15 @@ def _sigint_handler() -> None:
                 print()

             except Exception as e:
-                print(f"\n[bold red]Error during agent run: {e}[/bold red]", flush=True)
+                tb_str = traceback.format_exc()
+                print(f"\n[bold red]Error during agent run: {e}\n{tb_str}[/bold red]", flush=True)
                 first_sigint = True  # Allow graceful interrupt for the next command

+    except Exception as e:
+        tb_str = traceback.format_exc()
+        print(f"\n[bold red]An unexpected error occurred: {e}\n{tb_str}[/bold red]", flush=True)
+        raise e
+
     finally:
         if sigint_registered_in_loop:
             try:
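The switch from bracket access to `config.get(...)` matters for endpoint-only configs: with only `endpointUrl` set, the `"model"` and `"provider"` keys may be absent entirely, and `config["model"]` would raise `KeyError` before `MCPClient` could emit its clearer ValueError. A small sketch with a hypothetical endpoint-only config (`endpointUrl` is the key this diff reads; the URL is a placeholder):

config = {
    "endpointUrl": "http://localhost:8080/v1",  # hypothetical local endpoint
}

assert config.get("model") is None     # tolerated: `base_url` alone is enough
assert config.get("provider") is None  # likewise optional now
# config["model"]  # would raise KeyError and crash the CLI before any validation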

src/huggingface_hub/inference/_mcp/mcp_client.py

Lines changed: 15 additions & 4 deletions

@@ -69,24 +69,34 @@ class MCPClient:
         provider (`str`, *optional*):
             Name of the provider to use for inference. Defaults to "auto" i.e. the first of the providers available for the model, sorted by the user's order in https://hf.co/settings/inference-providers.
             If model is a URL or `base_url` is passed, then `provider` is not used.
+        base_url (`str`, *optional*):
+            The base URL to run inference. Defaults to None.
         api_key (`str`, `optional`):
             Token to use for authentication. Will default to the locally Hugging Face saved token if not provided. You can also use your own provider API key to interact directly with the provider's service.
     """

     def __init__(
         self,
         *,
-        model: str,
+        model: Optional[str] = None,
         provider: Optional[PROVIDER_OR_POLICY_T] = None,
+        base_url: Optional[str] = None,
         api_key: Optional[str] = None,
     ):
         # Initialize MCP sessions as a dictionary of ClientSession objects
         self.sessions: Dict[ToolName, "ClientSession"] = {}
         self.exit_stack = AsyncExitStack()
         self.available_tools: List[ChatCompletionInputTool] = []
-
-        # Initialize the AsyncInferenceClient
-        self.client = AsyncInferenceClient(model=model, provider=provider, api_key=api_key)
+        # To be able to send the model in the payload if `base_url` is provided
+        if model is None and base_url is None:
+            raise ValueError("At least one of `model` or `base_url` should be set in `MCPClient`.")
+        self.payload_model = model
+        self.client = AsyncInferenceClient(
+            model=None if base_url is not None else model,
+            provider=provider,
+            api_key=api_key,
+            base_url=base_url,
+        )

     async def __aenter__(self):
         """Enter the context manager"""

@@ -244,6 +254,7 @@ async def process_single_turn_with_tools(

         # Create the streaming request
         response = await self.client.chat.completions.create(
+            model=self.payload_model,
             messages=messages,
             tools=tools,
             tool_choice="auto",
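The split between `payload_model` and the `model` argument of `AsyncInferenceClient` is the heart of the change: when `base_url` is set, the client routes requests to that URL, while the model id still travels in each chat-completion payload via `model=self.payload_model` (second hunk), which many OpenAI-compatible servers expect. A sketch of the resulting behavior, assuming the `mcp` extra is installed; the endpoint URLs below are placeholders:

from huggingface_hub.inference._mcp.mcp_client import MCPClient  # module path from this diff

# Neither `model` nor `base_url`: fails fast with the new ValueError.
try:
    MCPClient()
except ValueError as err:
    print(err)  # At least one of `model` or `base_url` should be set in `MCPClient`.

# `base_url` alone is now valid: the inner AsyncInferenceClient gets model=None,
# and `payload_model` stays None, so the endpoint's default model is used.
local_client = MCPClient(base_url="http://localhost:8080/v1")  # hypothetical endpoint

# With both set, requests still go to `base_url`, but the model id is sent
# in every request payload through `payload_model`.
named_client = MCPClient(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    base_url="http://localhost:8080/v1",  # hypothetical endpoint
)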
