Skip to content

Commit 1e40c62

Browse files
[Inference Providers] fix inference with URL endpoints (#3041)
* fix inference with url endpoints * style * parentheses * add test * Update tests/test_inference_client.py Co-authored-by: Lucain <[email protected]> * Update tests/test_inference_client.py Co-authored-by: Lucain <[email protected]> * Update tests/test_inference_client.py --------- Co-authored-by: Lucain <[email protected]>
1 parent caeaeeb commit 1e40c62

File tree

4 files changed

+93
-3
lines changed

4 files changed

+93
-3
lines changed

src/huggingface_hub/inference/_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,13 @@ def chat_completion(
         payload_model = model or self.model

         # Get the provider helper
-        provider_helper = get_provider_helper(self.provider, task="conversational", model=payload_model)
+        provider_helper = get_provider_helper(
+            self.provider,
+            task="conversational",
+            model=model_id_or_url
+            if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+            else payload_model,
+        )

         # Prepare the payload
         parameters = {

src/huggingface_hub/inference/_generated/_async_client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -923,7 +923,13 @@ async def chat_completion(
         payload_model = model or self.model

         # Get the provider helper
-        provider_helper = get_provider_helper(self.provider, task="conversational", model=payload_model)
+        provider_helper = get_provider_helper(
+            self.provider,
+            task="conversational",
+            model=model_id_or_url
+            if model_id_or_url is not None and model_id_or_url.startswith(("http://", "https://"))
+            else payload_model,
+        )

         # Prepare the payload
         parameters = {

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def get_provider_helper(
         ValueError: If provider or task is not supported
     """

-    if model is None and provider in (None, "auto"):
+    if (model is None and provider in (None, "auto")) or (
+        model is not None and model.startswith(("http://", "https://"))
+    ):
         provider = "hf-inference"

     if provider is None:

tests/test_inference_client.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,3 +1087,79 @@ def test_warning_if_bill_to_with_direct_calls(self):
             match="You've provided an external provider's API key, so requests will be billed directly by the provider.",
         ):
             InferenceClient(bill_to="openai", token="replicate_key", provider="replicate")
+
+
+@pytest.mark.parametrize(
+    "client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model",
+    [
+        # passing a custom endpoint in the model argument
+        pytest.param(
+            "https://my-custom-endpoint.com/custom_path",
+            "model",
+            "https://my-custom-endpoint.com/custom_path/v1/chat/completions",
+            "dummy",
+            id="client_model_is_url",
+        ),
+        # passing a custom endpoint in the base_url argument
+        pytest.param(
+            "https://another-endpoint.com/v1/",
+            "base_url",
+            "https://another-endpoint.com/v1/chat/completions",
+            "dummy",
+            id="client_base_url_is_url",
+        ),
+        # passing a model ID
+        pytest.param(
+            "username/repo_name",
+            "model",
+            "https://router.huggingface.co/hf-inference/models/username/repo_name/v1/chat/completions",
+            "username/repo_name",
+            id="client_model_is_id",
+        ),
+        # passing a custom endpoint in the model argument
+        pytest.param(
+            "https://specific-chat-endpoint.com/v1/chat/completions",
+            "model",
+            "https://specific-chat-endpoint.com/v1/chat/completions",
+            "dummy",
+            id="client_model_is_full_chat_url",
+        ),
+        # passing a localhost URL in the model argument
+        pytest.param(
+            "http://localhost:8080",
+            "model",
+            "http://localhost:8080/v1/chat/completions",
+            "dummy",
+            id="client_model_is_localhost_url",
+        ),
+        # passing a localhost URL in the base_url argument
+        pytest.param(
+            "http://127.0.0.1:8000/custom/path/v1",
+            "base_url",
+            "http://127.0.0.1:8000/custom/path/v1/chat/completions",
+            "dummy",
+            id="client_base_url_is_localhost_ip_with_path",
+        ),
+    ],
+)
+def test_chat_completion_url_resolution(
+    mocker, client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model
+):
+    init_kwargs = {init_kwarg_name: client_init_arg, "provider": "hf-inference"}
+    client = InferenceClient(**init_kwargs)
+
+    mock_response_content = b'{"choices": [{"message": {"content": "Mock response"}}]}'
+    mocker.patch(
+        "huggingface_hub.inference._providers.hf_inference._check_supported_task",
+        return_value=None,
+    )
+
+    with patch.object(InferenceClient, "_inner_post", return_value=mock_response_content) as mock_inner_post:
+        client.chat_completion(messages=[{"role": "user", "content": "Hello?"}], stream=False)
+
+    mock_inner_post.assert_called_once()
+
+    request_params = mock_inner_post.call_args[0][0]
+    assert request_params.url == expected_request_url
+    assert request_params.json is not None
+    assert request_params.json.get("model") == expected_payload_model

0 commit comments

Comments (0)