Skip to content

Commit d3dec49

Browse files
Fix typing.get_type_hints call on a ModelHubMixin (#2729)
* Add DataclassInstance for runtime type_checking * Add suggestion from code review * Fix type annotation * Add test * Actually fix type annotation (3.8) --------- Co-authored-by: Celina Hanouti <[email protected]>
1 parent 6bfa5dd commit d3dec49

File tree

2 files changed

+34
-10
lines changed

2 files changed

+34
-10
lines changed

src/huggingface_hub/hub_mixin.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import inspect
22
import json
33
import os
4-
from dataclasses import asdict, dataclass, is_dataclass
4+
from dataclasses import Field, asdict, dataclass, is_dataclass
55
from pathlib import Path
6-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union
6+
from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union
77

88
import packaging.version
99

@@ -24,9 +24,6 @@
2424
)
2525

2626

27-
if TYPE_CHECKING:
28-
from _typeshed import DataclassInstance
29-
3027
if is_torch_available():
3128
import torch # type: ignore
3229

@@ -38,6 +35,12 @@
3835

3936
logger = logging.get_logger(__name__)
4037

38+
39+
# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349
40+
class DataclassInstance(Protocol):
41+
__dataclass_fields__: ClassVar[Dict[str, Field]]
42+
43+
4144
# Generic variable that is either ModelHubMixin or a subclass thereof
4245
T = TypeVar("T", bound="ModelHubMixin")
4346
# Generic variable to represent an args type
@@ -175,7 +178,7 @@ class ModelHubMixin:
175178
```
176179
"""
177180

178-
_hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None
181+
_hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None
179182
# ^ optional config attribute automatically set in `from_pretrained`
180183
_hub_mixin_info: MixinInfo
181184
# ^ information about the library integrating ModelHubMixin (used to generate model card)
@@ -366,7 +369,7 @@ def save_pretrained(
366369
self,
367370
save_directory: Union[str, Path],
368371
*,
369-
config: Optional[Union[dict, "DataclassInstance"]] = None,
372+
config: Optional[Union[dict, DataclassInstance]] = None,
370373
repo_id: Optional[str] = None,
371374
push_to_hub: bool = False,
372375
model_card_kwargs: Optional[Dict[str, Any]] = None,
@@ -618,7 +621,7 @@ def push_to_hub(
618621
self,
619622
repo_id: str,
620623
*,
621-
config: Optional[Union[dict, "DataclassInstance"]] = None,
624+
config: Optional[Union[dict, DataclassInstance]] = None,
622625
commit_message: str = "Push model using huggingface_hub.",
623626
private: Optional[bool] = None,
624627
token: Optional[str] = None,
@@ -825,7 +828,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
825828
return model
826829

827830

828-
def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance":
831+
def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance:
829832
"""Load a dataclass instance from a dictionary.
830833
831834
Fields not expected by the dataclass are ignored.

tests/test_hub_mixin.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import unittest
55
from dataclasses import dataclass
66
from pathlib import Path
7-
from typing import Dict, Optional, Union
7+
from typing import Dict, Optional, Union, get_type_hints
88
from unittest.mock import Mock, patch
99

1010
import jedi
@@ -474,3 +474,24 @@ def dummy_example_for_test(self, x: str) -> str:
474474
source_lines = source.split("\n")
475475
completions = script.complete(len(source_lines), len(source_lines[-1]))
476476
assert any(completion.name == "dummy_example_for_test" for completion in completions)
477+
478+
def test_get_type_hints_works_as_expected(self):
479+
"""
480+
Ensure that `typing.get_type_hints` works as expected when inheriting from `ModelHubMixin`.
481+
482+
See https://github.com/huggingface/huggingface_hub/issues/2727.
483+
"""
484+
485+
class ModelWithHints(ModelHubMixin):
486+
def method_with_hints(self, x: int) -> str:
487+
return str(x)
488+
489+
assert get_type_hints(ModelWithHints) != {}
490+
491+
# Test method type hints on class
492+
hints = get_type_hints(ModelWithHints.method_with_hints)
493+
assert hints == {"x": int, "return": str}
494+
495+
# Test method type hints on instance
496+
model = ModelWithHints()
497+
assert get_type_hints(model.method_with_hints) == {"x": int, "return": str}

0 commit comments

Comments
 (0)