
Commit 203eec8

Fix payload model name when model id is a URL (#2911)
* fix default model name when model id is a URL
* better
* Update test

Co-authored-by: Lucain <[email protected]>
1 parent 1c5f2f9 commit 203eec8

2 files changed: +52 −1 lines

src/huggingface_hub/inference/_providers/hf_inference.py (+5 −1)
@@ -84,7 +84,11 @@ def __init__(self):
         super().__init__("text-generation")
 
     def _prepare_payload_as_dict(self, inputs: Any, parameters: Dict, mapped_model: str) -> Optional[Dict]:
-        payload_model = "tgi" if mapped_model.startswith(("http://", "https://")) else mapped_model
+        payload_model = parameters.get("model") or mapped_model
+
+        if payload_model is None or payload_model.startswith(("http://", "https://")):
+            payload_model = "dummy"
+
         return {**filter_none(parameters), "model": payload_model, "messages": inputs}
 
     def _prepare_url(self, api_key: str, mapped_model: str) -> str:
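
For context, here is a minimal standalone sketch of the new fallback behavior. resolve_payload_model is a hypothetical helper written purely for illustration (it is not part of huggingface_hub); it mirrors the logic added to _prepare_payload_as_dict above: an explicit "model" parameter wins, otherwise the mapped model id is used, and if the resolved value is missing or is a URL, the payload falls back to the "dummy" placeholder, since the endpoint URL already identifies the model.

from typing import Any, Dict, Optional

def resolve_payload_model(parameters: Dict[str, Any], mapped_model: Optional[str]) -> str:
    # Hypothetical helper, for illustration only: mirrors the logic added above.
    # Prefer an explicit "model" parameter, then fall back to the mapped model id.
    payload_model = parameters.get("model") or mapped_model
    # A URL (or a missing value) is not a valid model name for the payload,
    # so use the "dummy" placeholder instead.
    if payload_model is None or payload_model.startswith(("http://", "https://")):
        payload_model = "dummy"
    return payload_model

# Regular repo id -> used as-is
assert resolve_payload_model({}, "username/repo_name") == "username/repo_name"
# URL endpoint with an explicit model parameter -> the parameter wins
assert resolve_payload_model({"model": "username/repo_name"}, "http://localhost:8000/v1/chat/completions") == "username/repo_name"
# URL endpoint without a model -> "dummy"
assert resolve_payload_model({}, "http://localhost:8000/v1/chat/completions") == "dummy"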

tests/test_inference_providers.py (+47)
@@ -305,6 +305,53 @@ def test_prepare_request_conversational(self):
             "messages": [{"role": "user", "content": "dummy text input"}],
         }
 
+    @pytest.mark.parametrize(
+        "mapped_model,parameters,expected_model",
+        [
+            (
+                "username/repo_name",
+                {},
+                "username/repo_name",
+            ),
+            # URL endpoint with model in parameters - use model from parameters
+            (
+                "http://localhost:8000/v1/chat/completions",
+                {"model": "username/repo_name"},
+                "username/repo_name",
+            ),
+            # URL endpoint without model - fallback to dummy
+            (
+                "http://localhost:8000/v1/chat/completions",
+                {},
+                "dummy",
+            ),
+            # HTTPS endpoint with model in parameters
+            (
+                "https://api.example.com/v1/chat/completions",
+                {"model": "username/repo_name"},
+                "username/repo_name",
+            ),
+            # URL endpoint with other parameters - should still use dummy
+            (
+                "http://localhost:8000/v1/chat/completions",
+                {"temperature": 0.7, "max_tokens": 100},
+                "dummy",
+            ),
+        ],
+    )
+    def test_prepare_payload_as_dict_conversational(self, mapped_model, parameters, expected_model):
+        helper = HFInferenceConversational()
+        messages = [{"role": "user", "content": "Hello!"}]
+
+        payload = helper._prepare_payload_as_dict(
+            inputs=messages,
+            parameters=parameters,
+            mapped_model=mapped_model,
+        )
+
+        assert payload["model"] == expected_model
+        assert payload["messages"] == messages
+
 
 class TestHyperbolicProvider:
     def test_prepare_route(self):
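
To check just this case locally, the new parametrized test can be selected by name with pytest's -k filter, e.g. pytest tests/test_inference_providers.py -k test_prepare_payload_as_dict_conversational (assuming the repository's usual test dependencies are installed).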
