@@ -1087,3 +1087,79 @@ def test_warning_if_bill_to_with_direct_calls(self):
1087
1087
match = "You've provided an external provider's API key, so requests will be billed directly by the provider." ,
1088
1088
):
1089
1089
InferenceClient (bill_to = "openai" , token = "replicate_key" , provider = "replicate" )
1090
+
1091
+
1092
@pytest.mark.parametrize(
    "client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model",
    [
        # Custom endpoint given via `model`: the chat route is appended to it.
        pytest.param(
            "https://my-custom-endpoint.com/custom_path",
            "model",
            "https://my-custom-endpoint.com/custom_path/v1/chat/completions",
            "dummy",
            id="client_model_is_url",
        ),
        # Custom endpoint given via `base_url` (trailing slash must not duplicate).
        pytest.param(
            "https://another-endpoint.com/v1/",
            "base_url",
            "https://another-endpoint.com/v1/chat/completions",
            "dummy",
            id="client_base_url_is_url",
        ),
        # Plain model ID: resolved to the HF router URL; the ID stays in the payload.
        pytest.param(
            "username/repo_name",
            "model",
            "https://router.huggingface.co/hf-inference/models/username/repo_name/v1/chat/completions",
            "username/repo_name",
            id="client_model_is_id",
        ),
        # URL that already targets the chat-completions route: used verbatim.
        pytest.param(
            "https://specific-chat-endpoint.com/v1/chat/completions",
            "model",
            "https://specific-chat-endpoint.com/v1/chat/completions",
            "dummy",
            id="client_model_is_full_chat_url",
        ),
        # Localhost URL given via `model`.
        pytest.param(
            "http://localhost:8080",
            "model",
            "http://localhost:8080/v1/chat/completions",
            "dummy",
            id="client_model_is_localhost_url",
        ),
        # Localhost IP with an extra path given via `base_url`.
        pytest.param(
            "http://127.0.0.1:8000/custom/path/v1",
            "base_url",
            "http://127.0.0.1:8000/custom/path/v1/chat/completions",
            "dummy",
            id="client_base_url_is_localhost_ip_with_path",
        ),
    ],
)
def test_chat_completion_url_resolution(
    mocker, client_init_arg, init_kwarg_name, expected_request_url, expected_payload_model
):
    """Check that `chat_completion` builds the right request URL and payload "model" field.

    The client is initialized with either `model` or `base_url` (driven by
    `init_kwarg_name`); `_inner_post` is mocked so no HTTP request is sent.
    """
    client = InferenceClient(**{init_kwarg_name: client_init_arg, "provider": "hf-inference"})

    # Bypass the provider's task-support check — it would otherwise hit the network.
    mocker.patch(
        "huggingface_hub.inference._providers.hf_inference._check_supported_task",
        return_value=None,
    )

    fake_response_body = b'{"choices": [{"message": {"content": "Mock response"}}]}'
    with patch.object(InferenceClient, "_inner_post", return_value=fake_response_body) as inner_post_mock:
        client.chat_completion(messages=[{"role": "user", "content": "Hello?"}], stream=False)

    inner_post_mock.assert_called_once()

    # First positional arg of `_inner_post` is the prepared request object.
    sent_request = inner_post_mock.call_args[0][0]
    assert sent_request.url == expected_request_url
    assert sent_request.json is not None
    assert sent_request.json.get("model") == expected_payload_model
0 commit comments