@@ -28,6 +28,7 @@
 from itertools import islice
 from pathlib import Path
 from typing import (
+    TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
@@ -135,6 +136,9 @@
 from .utils.endpoint_helpers import _is_emission_within_threshold
 
 
+if TYPE_CHECKING:
+    from .inference._providers import PROVIDER_T
+
 R = TypeVar("R")  # Return type
 CollectionItemType_T = Literal["model", "dataset", "space", "paper", "collection"]
 
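The `TYPE_CHECKING` guard keeps `PROVIDER_T` out of the runtime import graph (it is only needed for annotations), which is the usual way to break a circular import; the name is then referenced as a string, as in `"PROVIDER_T"` below. A minimal sketch of the pattern, with a hypothetical module name:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Resolved by static type checkers only; never executed at runtime,
    # so a module that imports us back cannot create an import cycle.
    from ._heavy_or_circular_module import SomeType  # hypothetical module

def handle(value: "SomeType") -> None:  # quoted annotation, never evaluated at runtime
    print(value)
```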
@@ -709,21 +713,26 @@ def __init__(self, **kwargs):
 
 @dataclass
 class InferenceProviderMapping:
-    hf_model_id: str
+    provider: "PROVIDER_T"  # Provider name
+    hf_model_id: str  # ID of the model on the Hugging Face Hub
+    provider_id: str  # ID of the model on the provider's side
     status: Literal["live", "staging"]
-    provider_id: str
     task: str
 
     adapter: Optional[str] = None
     adapter_weights_path: Optional[str] = None
+    type: Optional[Literal["single-model", "tag-filter"]] = None
 
     def __init__(self, **kwargs):
+        self.provider = kwargs.pop("provider")
         self.hf_model_id = kwargs.pop("hf_model_id")
-        self.status = kwargs.pop("status")
         self.provider_id = kwargs.pop("providerId")
+        self.status = kwargs.pop("status")
         self.task = kwargs.pop("task")
+
         self.adapter = kwargs.pop("adapter", None)
         self.adapter_weights_path = kwargs.pop("adapterWeightsPath", None)
+        self.type = kwargs.pop("type", None)
         self.__dict__.update(**kwargs)
 
 
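For illustration, a rough sketch of how the reworked constructor consumes one raw mapping entry. The payload below is hypothetical but shaped after the keys popped above; note the camelCase `providerId` from the API surfacing as the snake_case `provider_id` attribute:

```python
from huggingface_hub.hf_api import InferenceProviderMapping

# Hypothetical raw entry, shaped after the keys the constructor pops.
raw = {
    "provider": "together",
    "hf_model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "providerId": "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
    "status": "live",
    "task": "conversational",
    "type": "single-model",
}
m = InferenceProviderMapping(**raw)
print(m.provider, m.provider_id)  # together meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
```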
@@ -765,12 +774,10 @@ class ModelInfo:
             If so, whether there is manual or automatic approval.
         gguf (`Dict`, *optional*):
             GGUF information of the model.
-        inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-            Status of the model on the inference API.
-            Warm models are available for immediate use. Cold models will be loaded on first inference call.
-            Frozen models are not available in Inference API.
-        inference_provider_mapping (`Dict`, *optional*):
-            Model's inference provider mapping.
+        inference (`Literal["warm"]`, *optional*):
+            Status of the model on Inference Providers. `"warm"` if the model is served by at least one provider.
+        inference_provider_mapping (`List[InferenceProviderMapping]`, *optional*):
+            A list of [`InferenceProviderMapping`] entries, ordered according to the user's provider preferences.
         likes (`int`):
             Number of likes of the model.
         library_name (`str`, *optional*):
@@ -815,8 +822,8 @@ class ModelInfo:
     downloads_all_time: Optional[int]
     gated: Optional[Literal["auto", "manual", False]]
     gguf: Optional[Dict]
-    inference: Optional[Literal["warm", "cold", "frozen"]]
-    inference_provider_mapping: Optional[Dict[str, InferenceProviderMapping]]
+    inference: Optional[Literal["warm"]]
+    inference_provider_mapping: Optional[List[InferenceProviderMapping]]
     likes: Optional[int]
     library_name: Optional[str]
     tags: Optional[List[str]]
@@ -852,14 +859,25 @@ def __init__(self, **kwargs):
         self.gguf = kwargs.pop("gguf", None)
 
         self.inference = kwargs.pop("inference", None)
-        self.inference_provider_mapping = kwargs.pop("inferenceProviderMapping", None)
-        if self.inference_provider_mapping:
-            self.inference_provider_mapping = {
-                provider: InferenceProviderMapping(
-                    **{**value, "hf_model_id": self.id}
-                )  # little hack to simplify Inference Providers logic
-                for provider, value in self.inference_provider_mapping.items()
-            }
+
+        # Little hack to simplify Inference Providers logic and keep it backward and forward compatible.
+        # Right now, the API returns a dict on model_info and a list on list_models. Let's harmonize to a list.
+        mapping = kwargs.pop("inferenceProviderMapping", None)
+        if isinstance(mapping, list):
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**{**value, "hf_model_id": self.id}) for value in mapping
+            ]
+        elif isinstance(mapping, dict):
+            self.inference_provider_mapping = [
+                InferenceProviderMapping(**{**value, "hf_model_id": self.id, "provider": provider})
+                for provider, value in mapping.items()
+            ]
+        elif mapping is None:
+            self.inference_provider_mapping = None
+        else:
+            raise ValueError(
+                f"Unexpected type for `inferenceProviderMapping`. Expecting `dict` or `list`. Got {mapping}."
+            )
 
         self.tags = kwargs.pop("tags", None)
         self.pipeline_tag = kwargs.pop("pipeline_tag", None)
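A quick sketch of what this harmonization buys: whether the server returns the dict shape (`model_info`) or the list shape (`list_models`), callers see the same `List[InferenceProviderMapping]`. The model ID and payload here are made up, and the sketch assumes `ModelInfo` tolerates partial payloads, which its `**kwargs`-based constructor is designed to do:

```python
from huggingface_hub.hf_api import ModelInfo

entry = {"providerId": "acme/model-v1", "status": "live", "task": "conversational"}

# Dict shape, keyed by provider name (what /api/models/{id} currently returns)...
info_from_dict = ModelInfo(id="acme/model", inferenceProviderMapping={"acme": entry})
# ...and list shape, with the provider inlined (what /api/models currently returns).
info_from_list = ModelInfo(id="acme/model", inferenceProviderMapping=[{"provider": "acme", **entry}])

# Both normalize to the same one-element list of InferenceProviderMapping.
assert info_from_dict.inference_provider_mapping[0].provider == "acme"
assert info_from_list.inference_provider_mapping[0].provider == "acme"
```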
@@ -1836,7 +1854,8 @@ def list_models(
         filter: Union[str, Iterable[str], None] = None,
         author: Optional[str] = None,
         gated: Optional[bool] = None,
-        inference: Optional[Literal["cold", "frozen", "warm"]] = None,
+        inference: Optional[Literal["warm"]] = None,
+        inference_provider: Optional[Union[Literal["all"], "PROVIDER_T", List["PROVIDER_T"]]] = None,
         library: Optional[Union[str, List[str]]] = None,
         language: Optional[Union[str, List[str]]] = None,
         model_name: Optional[str] = None,
@@ -1870,10 +1889,11 @@ def list_models(
                 A boolean to filter models on the Hub that are gated or not. By default, all models are returned.
                 If `gated=True` is passed, only gated models are returned.
                 If `gated=False` is passed, only non-gated models are returned.
-            inference (`Literal["cold", "frozen", "warm"]`, *optional*):
-                A string to filter models on the Hub by their state on the Inference API.
-                Warm models are available for immediate use. Cold models will be loaded on first inference call.
-                Frozen models are not available in Inference API.
+            inference (`Literal["warm"]`, *optional*):
+                If `"warm"`, filter models on the Hub that are currently served by at least one provider.
+            inference_provider (`Literal["all"]`, `str` or `List[str]`, *optional*):
+                A provider name (or list of names) to filter models on the Hub served by those providers.
+                Pass `"all"` to get all models served by at least one provider.
             library (`str` or `List`, *optional*):
                 A string or list of strings of foundational libraries models were
                 originally trained from, such as pytorch, tensorflow, or allennlp.
@@ -1933,7 +1953,7 @@ def list_models(
         Returns:
             `Iterable[ModelInfo]`: an iterable of [`huggingface_hub.hf_api.ModelInfo`] objects.
 
-        Example usage with the `filter` argument:
+        Example:
 
         ```python
         >>> from huggingface_hub import HfApi
@@ -1943,24 +1963,19 @@ def list_models(
         # List all models
         >>> api.list_models()
 
-        # List only the text classification models
+        # List text classification models
         >>> api.list_models(filter="text-classification")
 
-        # List only models from the AllenNLP library
-        >>> api.list_models(filter="allennlp")
-        ```
-
-        Example usage with the `search` argument:
+        # List models from the KerasHub library
+        >>> api.list_models(filter="keras-hub")
 
-        ```python
-        >>> from huggingface_hub import HfApi
-
-        >>> api = HfApi()
+        # List models served by Cohere
+        >>> api.list_models(inference_provider="cohere")
 
-        # List all models with "bert" in their name
+        # List models with "bert" in their name
         >>> api.list_models(search="bert")
 
-        # List all models with "bert" in their name made by google
+        # List models with "bert" in their name and pushed by google
        >>> api.list_models(search="bert", author="google")
         ```
         """
@@ -2003,6 +2018,8 @@ def list_models(
             params["gated"] = gated
         if inference is not None:
             params["inference"] = inference
+        if inference_provider is not None:
+            params["inference_provider"] = inference_provider
         if pipeline_tag:
             params["pipeline_tag"] = pipeline_tag
         search_list = []
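Since the new argument is forwarded as-is into `params`, a single provider name, a list of names, or the `"all"` sentinel all end up in the `inference_provider` query parameter (the HTTP backend encodes a list as a repeated parameter). A short usage sketch; the provider names are illustrative:

```python
from huggingface_hub import HfApi

api = HfApi()

# Models served by one specific provider.
for model in api.list_models(inference_provider="cohere", limit=3):
    print(model.id, model.inference)

# Models served by any of several providers.
for model in api.list_models(inference_provider=["cohere", "together"], limit=3):
    print(model.id)
```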