Commit b4c01ed

Fix inference search (#3022)
* Fix inference search
* forward compatible inference_provider_mapping
* styling
1 parent 2855402 commit b4c01ed
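
In short: `ModelInfo.inference_provider_mapping` becomes a `List[InferenceProviderMapping]` (each entry now carries its provider name) instead of a `Dict[str, InferenceProviderMapping]`, and `list_models()` gains an `inference_provider` filter. A minimal sketch of the new surface (model ID borrowed from the tests below; requires network access):

```python
from huggingface_hub import HfApi

api = HfApi()

# The mapping is now a list; each entry knows which provider serves the model.
model = api.model_info("deepseek-ai/DeepSeek-R1-0528", expand="inferenceProviderMapping")
for m in model.inference_provider_mapping or []:
    print(m.provider, m.provider_id, m.status)
```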

8 files changed: +160 −63 lines changed


docs/source/en/package_reference/hf_api.md

Lines changed: 4 additions & 0 deletions

@@ -57,6 +57,10 @@ models = hf_api.list_models()
 
 [[autodoc]] huggingface_hub.hf_api.GitRefs
 
+### InferenceProviderMapping
+
+[[autodoc]] huggingface_hub.hf_api.InferenceProviderMapping
+
 ### LFSFileInfo
 
 [[autodoc]] huggingface_hub.hf_api.LFSFileInfo

src/huggingface_hub/hf_api.py

Lines changed: 54 additions & 37 deletions

@@ -28,6 +28,7 @@
 from itertools import islice
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,

@@ -135,6 +136,9 @@
 from .utils.endpoint_helpers import _is_emission_within_threshold
 
 
+if TYPE_CHECKING:
+    from .inference._providers import PROVIDER_T
+
 R = TypeVar("R")  # Return type
 CollectionItemType_T = Literal["model", "dataset", "space", "paper", "collection"]
 

@@ -709,21 +713,26 @@ def __init__(self, **kwargs):
 
 @dataclass
 class InferenceProviderMapping:
-    hf_model_id: str
+    provider: "PROVIDER_T"  # Provider name
+    hf_model_id: str  # ID of the model on the Hugging Face Hub
+    provider_id: str  # ID of the model on the provider's side
     status: Literal["live", "staging"]
-    provider_id: str
     task: str
 
     adapter: Optional[str] = None
     adapter_weights_path: Optional[str] = None
+    type: Optional[Literal["single-model", "tag-filter"]] = None
 
     def __init__(self, **kwargs):
+        self.provider = kwargs.pop("provider")
         self.hf_model_id = kwargs.pop("hf_model_id")
-        self.status = kwargs.pop("status")
         self.provider_id = kwargs.pop("providerId")
+        self.status = kwargs.pop("status")
         self.task = kwargs.pop("task")
+
         self.adapter = kwargs.pop("adapter", None)
         self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None)
+        self.type = kwargs.pop("type", None)
         self.__dict__.update(**kwargs)
 
 
@@ -765,12 +774,10 @@ class ModelInfo:
             If so, whether there is manual or automatic approval.
         gguf (`Dict`, *optional*):
             GGUF information of the model.
-        inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-            Status of the model on the inference API.
-            Warm models are available for immediate use. Cold models will be loaded on first inference call.
-            Frozen models are not available in Inference API.
-        inference_provider_mapping (`Dict`, *optional*):
-            Model's inference provider mapping.
+        inference (`Literal["warm"]`, *optional*):
+            Status of the model on Inference Providers. Warm if the model is served by at least one provider.
+        inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*):
+            A list of [`InferenceProviderMapping`] ordered after the user's provider order.
         likes (`int`):
             Number of likes of the model.
         library_name (`str`, *optional*):

@@ -815,8 +822,8 @@ class ModelInfo:
     downloads_all_time: Optional[int]
     gated: Optional[Literal["auto", "manual", False]]
     gguf: Optional[Dict]
-    inference: Optional[Literal["warm", "cold", "frozen"]]
-    inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
+    inference: Optional[Literal["warm"]]
+    inference_provider_mapping: Optional[List[InferenceProviderMapping]]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]

@@ -852,14 +859,25 @@ def __init__(self, **kwargs):
         self.gguf = kwargs.pop("gguf", None)
 
         self.inference = kwargs.pop("inference", None)
-        self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
-        if self.inference_provider_mapping:
-            self.inference_provider_mapping = {
-                provider: InferenceProviderMapping(
-                    **{**value, "hf_model_id": self.id}
-                )  # little hack to simplify Inference Providers logic
-                for provider, value in self.inference_provider_mapping.items()
-            }
+
+        # little hack to simplify Inference Providers logic and make it backward and forward compatible
+        # right now, API returns a dict on model_info and a list on list_models. Let's harmonize to list.
+        mapping = kwargs.pop("inferenceProviderMapping", None)
+        if isinstance(mapping, list):
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**{**value, "hf_model_id": self.id}) for value in mapping
+            ]
+        elif isinstance(mapping, dict):
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**{**value, "hf_model_id": self.id, "provider": provider})
+                for provider, value in mapping.items()
+            ]
+        elif mapping is None:
+            self.inference_provider_mapping = None
+        else:
+            raise ValueError(
+                f"Unexpected type for `inferenceProviderMapping`. Expecting `dict` or `list`. Got {mapping}."
+            )
 
         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
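To see the harmonization above in action, here is a small sketch feeding `ModelInfo` both payload shapes (payloads and the `acme` provider are illustrative; per the code comment, the Hub currently returns a dict on `model_info` and a list on `list_models`):

```python
from huggingface_hub.hf_api import ModelInfo

legacy_payload = {"acme": {"providerId": "acme/bert", "status": "live", "task": "text-classification"}}
modern_payload = [{"provider": "acme", "providerId": "acme/bert", "status": "live", "task": "text-classification"}]

legacy = ModelInfo(id="acme/bert", inferenceProviderMapping=legacy_payload)
modern = ModelInfo(id="acme/bert", inferenceProviderMapping=modern_payload)

# Both shapes normalize to the same list of InferenceProviderMapping objects.
assert [m.provider for m in legacy.inference_provider_mapping] == ["acme"]
assert [m.provider for m in modern.inference_provider_mapping] == ["acme"]
```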
@@ -1836,7 +1854,8 @@ def list_models(
         filter: Union[str, Iterable[str], None] = None,
         author: Optional[str] = None,
         gated: Optional[bool] = None,
-        inference: Optional[Literal["cold", "frozen", "warm"]] = None,
+        inference: Optional[Literal["warm"]] = None,
+        inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", List["PROVIDER_T"]]] = None,
         library: Optional[Union[str, List[str]]] = None,
         language: Optional[Union[str, List[str]]] = None,
         model_name: Optional[str] = None,

@@ -1870,10 +1889,11 @@
                 A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
                 If `gated=True` is passed, only gated models are returned.
                 If `gated=False` is passed, only non-gated models are returned.
-            inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-                A string to filter models on the Hub by their state on the Inference API.
-                Warm models are available for immediate use. Cold models will be loaded on first inference call.
-                Frozen models are not available in Inference API.
+            inference (`Literal["warm"]`, *optional*):
+                If "warm", filter models on the Hub currently served by at least one provider.
+            inference_provider (`Literal["all"]` or `str`, *optional*):
+                A string to filter models on the Hub that are served by a specific provider.
+                Pass `"all"` to get all models served by at least one provider.
             library (`str` or `List`, *optional*):
                 A string or list of strings of foundational libraries models were
                 originally trained from, such as pytorch, tensorflow, or allennlp.

@@ -1933,7 +1953,7 @@ def list_models(
         Returns:
             `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects.
 
-        Example usage with the `filter` argument:
+        Example:
 
         ```python
         >>> from huggingface_hub import HfApi

@@ -1943,24 +1963,19 @@ def list_models(
 
         # List all models
         >>> api.list_models()
 
-        # List only the text classification models
+        # List text classification models
         >>> api.list_models(filter="text-classification")
 
-        # List only models from the AllenNLP library
-        >>> api.list_models(filter="allennlp")
-        ```
-
-        Example usage with the `search` argument:
+        # List models from the KerasHub library
+        >>> api.list_models(filter="keras-hub")
 
-        ```python
-        >>> from huggingface_hub import HfApi
-
-        >>> api = HfApi()
+        # List models served by Cohere
+        >>> api.list_models(inference_provider="cohere")
 
-        # List all models with "bert" in their name
+        # List models with "bert" in their name
         >>> api.list_models(search="bert")
 
-        # List all models with "bert" in their name made by google
+        # List models with "bert" in their name and pushed by google
         >>> api.list_models(search="bert", author="google")
         ```
         """

@@ -2003,6 +2018,8 @@ def list_models(
             params["gated"] = gated
         if inference is not None:
             params["inference"] = inference
+        if inference_provider is not None:
+            params["inference_provider"] = inference_provider
         if pipeline_tag:
             params["pipeline_tag"] = pipeline_tag
         search_list = []
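Usage of the new filter, following the docstring plus the `"all"` sentinel (provider name is an example; requires network access):

```python
from huggingface_hub import HfApi

api = HfApi()

# Warm models served by one specific provider.
for model in api.list_models(inference_provider="cohere", inference="warm", limit=5):
    print(model.id)

# Models served by at least one provider, whichever it is.
for model in api.list_models(inference_provider="all", limit=5):
    print(model.id)
```

The annotation also admits a list of providers; how the backend combines multiple values is not covered by the docstring, so treat that as untested here.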

src/huggingface_hub/inference/_providers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ def get_provider_helper(
         if model is None:
             raise ValueError("Specifying a model is required when provider is 'auto'")
         provider_mapping = _fetch_inference_provider_mapping(model)
-        provider = next(iter(provider_mapping))
+        provider = next(iter(provider_mapping)).provider
 
     provider_tasks = PROVIDERS.get(provider)  # type: ignore
     if provider_tasks is None:
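Since `_fetch_inference_provider_mapping` now returns a list, `next(iter(...))` yields the first `InferenceProviderMapping` object rather than a dict key, hence the added `.provider` attribute access. The selection logic in isolation (mapping objects are illustrative):

```python
from huggingface_hub.hf_api import InferenceProviderMapping

mappings = [
    InferenceProviderMapping(provider="together", hf_model_id="m", providerId="m-1", status="live", task="conversational"),
    InferenceProviderMapping(provider="cohere", hf_model_id="m", providerId="m-2", status="live", task="conversational"),
]

# provider="auto" picks the first entry, i.e. the highest-priority provider for the model.
assert next(iter(mappings)).provider == "together"
```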

src/huggingface_hub/inference/_providers/_common.py

Lines changed: 9 additions & 3 deletions

@@ -1,5 +1,5 @@
 from functools import lru_cache
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 from huggingface_hub import constants
 from huggingface_hub.hf_api import InferenceProviderMapping

@@ -9,6 +9,7 @@
 
 logger = logging.get_logger(__name__)
 
+
 # Dev purposes only.
 # If you want to try to run inference for a new model locally before it's registered on huggingface.co
 # for a given Inference Provider, you can add it to the following dictionary.

@@ -124,7 +125,12 @@ def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
         if HARDCODED_MODEL_INFERENCE_MAPPING.get(self.provider, {}).get(model):
             return HARDCODED_MODEL_INFERENCE_MAPPING[self.provider][model]
 
-        provider_mapping = _fetch_inference_provider_mapping(model).get(self.provider)
+        provider_mapping = None
+        for mapping in _fetch_inference_provider_mapping(model):
+            if mapping.provider == self.provider:
+                provider_mapping = mapping
+                break
+
         if provider_mapping is None:
             raise ValueError(f"Model {model} is not supported by provider {self.provider}.")
 

@@ -236,7 +242,7 @@ def _prepare_payload_as_dict(
 
 
 @lru_cache(maxsize=None)
-def _fetch_inference_provider_mapping(model: str) -> Dict:
+def _fetch_inference_provider_mapping(model: str) -> List["InferenceProviderMapping"]:
     """
     Fetch provider mappings for a model from the Hub.
     """

src/huggingface_hub/inference/_providers/hf_inference.py

Lines changed: 6 additions & 2 deletions

@@ -26,15 +26,19 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str:
 
     def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
         if model is not None and model.startswith(("http://", "https://")):
-            return InferenceProviderMapping(providerId=model, hf_model_id=model, task=self.task, status="live")
+            return InferenceProviderMapping(
+                provider="hf-inference", providerId=model, hf_model_id=model, task=self.task, status="live"
+            )
         model_id = model if model is not None else _fetch_recommended_models().get(self.task)
         if model_id is None:
             raise ValueError(
                 f"Task {self.task} has no recommended model for HF Inference. Please specify a model"
                 " explicitly. Visit https://huggingface.co/tasks for more info."
             )
         _check_supported_task(model_id, self.task)
-        return InferenceProviderMapping(providerId=model_id, hf_model_id=model_id, task=self.task, status="live")
+        return InferenceProviderMapping(
+            provider="hf-inference", providerId=model_id, hf_model_id=model_id, task=self.task, status="live"
+        )
 
     def _prepare_url(self, api_key: str, mapped_model: str) -> str:
         # hf-inference provider can handle URLs (e.g. Inference Endpoints or TGI deployment)
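The URL branch means `hf-inference` can target a self-hosted deployment directly, now with the provider name stamped on the mapping. A usage sketch (the endpoint URL is a placeholder):

```python
from huggingface_hub import InferenceClient

# Point hf-inference at an Inference Endpoints / TGI URL instead of a Hub model ID.
client = InferenceClient(model="https://my-endpoint.example.com", provider="hf-inference")
# client.text_generation("Hello") then goes straight to that URL, with no Hub mapping lookup.
```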

src/huggingface_hub/inference/_providers/openai.py

Lines changed: 3 additions & 1 deletion

@@ -20,4 +20,6 @@ def _prepare_api_key(self, api_key: Optional[str]) -> str:
     def _prepare_mapping_info(self, model: Optional[str]) -> InferenceProviderMapping:
         if model is None:
             raise ValueError("Please provide an OpenAI model ID, e.g. `gpt-4o` or `o1`.")
-        return InferenceProviderMapping(providerId=model, task="conversational", status="live", hf_model_id=model)
+        return InferenceProviderMapping(
+            provider="openai", providerId=model, task="conversational", status="live", hf_model_id=model
+        )
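The OpenAI passthrough builds its mapping locally (no Hub lookup), now tagged with `provider="openai"`. Client usage is unchanged; a brief sketch (model ID and key are placeholders):

```python
from huggingface_hub import InferenceClient

client = InferenceClient(provider="openai", api_key="sk-...")  # placeholder key
# client.chat_completion(model="gpt-4o", messages=[{"role": "user", "content": "Hi"}])
```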

tests/test_hf_api.py

Lines changed: 33 additions & 0 deletions

@@ -60,6 +60,7 @@
     ExpandModelProperty_T,
     ExpandSpaceProperty_T,
     InferenceEndpoint,
+    InferenceProviderMapping,
     ModelInfo,
     RepoSibling,
     RepoUrl,

@@ -2511,6 +2512,38 @@ def test_not_a_safetensors_file(self) -> None:
             "HuggingFaceH4/zephyr-7b-beta", "pytorch_model-00001-of-00008.bin"
         )
 
+    def test_inference_provider_mapping_model_info(self):
+        model = self._api.model_info("deepseek-ai/DeepSeek-R1-0528", expand="inferenceProviderMapping")
+        mapping = model.inference_provider_mapping
+        assert isinstance(mapping, list)
+        assert len(mapping) > 0
+        for item in mapping:
+            assert isinstance(item, InferenceProviderMapping)
+            assert item.provider is not None
+            assert item.hf_model_id == "deepseek-ai/DeepSeek-R1-0528"
+            assert item.provider_id is not None
+
+    def test_inference_provider_mapping_list_models(self):
+        models = list(self._api.list_models(author="deepseek-ai", expand="inferenceProviderMapping", limit=1))
+        assert len(models) > 0
+        mapping = models[0].inference_provider_mapping
+        assert isinstance(mapping, list)
+        assert len(mapping) > 0
+        for item in mapping:
+            assert isinstance(item, InferenceProviderMapping)
+            assert item.provider is not None
+            assert item.hf_model_id is not None
+            assert item.provider_id is not None
+
+    def test_filter_models_by_inference_provider(self):
+        models = list(
+            self._api.list_models(inference_provider="hf-inference", expand=["inferenceProviderMapping"], limit=10)
+        )
+        assert len(models) > 0
+        for model in models:
+            assert model.inference_provider_mapping is not None
+            assert any(mapping.provider == "hf-inference" for mapping in model.inference_provider_mapping)
+
 
 class HfApiPrivateTest(HfApiCommonTest):
     def setUp(self) -> None:
