|
1 | 1 | import inspect
|
2 | 2 | import json
|
3 | 3 | import os
|
4 |
| -from dataclasses import asdict, dataclass, is_dataclass |
| 4 | +from dataclasses import Field, asdict, dataclass, is_dataclass |
5 | 5 | from pathlib import Path
|
6 |
| -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union |
| 6 | +from typing import Any, Callable, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, TypeVar, Union |
7 | 7 |
|
8 | 8 | import packaging.version
|
9 | 9 |
|
|
24 | 24 | )
|
25 | 25 |
|
26 | 26 |
|
27 |
| -if TYPE_CHECKING: |
28 |
| - from _typeshed import DataclassInstance |
29 |
| - |
30 | 27 | if is_torch_available():
|
31 | 28 | import torch # type: ignore
|
32 | 29 |
|
|
38 | 35 |
|
39 | 36 | logger = logging.get_logger(__name__)
|
40 | 37 |
|
| 38 | + |
| 39 | +# Type alias for dataclass instances, copied from https://github.com/python/typeshed/blob/9f28171658b9ca6c32a7cb93fbb99fc92b17858b/stdlib/_typeshed/__init__.pyi#L349 |
| 40 | +class DataclassInstance(Protocol): |
| 41 | + __dataclass_fields__: ClassVar[Dict[str, Field]] |
| 42 | + |
| 43 | + |
41 | 44 | # Generic variable that is either ModelHubMixin or a subclass thereof
|
42 | 45 | T = TypeVar("T", bound="ModelHubMixin")
|
43 | 46 | # Generic variable to represent an args type
|
@@ -175,7 +178,7 @@ class ModelHubMixin:
|
175 | 178 | ```
|
176 | 179 | """
|
177 | 180 |
|
178 |
| - _hub_mixin_config: Optional[Union[dict, "DataclassInstance"]] = None |
| 181 | + _hub_mixin_config: Optional[Union[dict, DataclassInstance]] = None |
179 | 182 | # ^ optional config attribute automatically set in `from_pretrained`
|
180 | 183 | _hub_mixin_info: MixinInfo
|
181 | 184 | # ^ information about the library integrating ModelHubMixin (used to generate model card)
|
@@ -366,7 +369,7 @@ def save_pretrained(
|
366 | 369 | self,
|
367 | 370 | save_directory: Union[str, Path],
|
368 | 371 | *,
|
369 |
| - config: Optional[Union[dict, "DataclassInstance"]] = None, |
| 372 | + config: Optional[Union[dict, DataclassInstance]] = None, |
370 | 373 | repo_id: Optional[str] = None,
|
371 | 374 | push_to_hub: bool = False,
|
372 | 375 | model_card_kwargs: Optional[Dict[str, Any]] = None,
|
@@ -618,7 +621,7 @@ def push_to_hub(
|
618 | 621 | self,
|
619 | 622 | repo_id: str,
|
620 | 623 | *,
|
621 |
| - config: Optional[Union[dict, "DataclassInstance"]] = None, |
| 624 | + config: Optional[Union[dict, DataclassInstance]] = None, |
622 | 625 | commit_message: str = "Push model using huggingface_hub.",
|
623 | 626 | private: Optional[bool] = None,
|
624 | 627 | token: Optional[str] = None,
|
@@ -825,7 +828,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
|
825 | 828 | return model
|
826 | 829 |
|
827 | 830 |
|
828 |
| -def _load_dataclass(datacls: Type["DataclassInstance"], data: dict) -> "DataclassInstance": |
| 831 | +def _load_dataclass(datacls: Type[DataclassInstance], data: dict) -> DataclassInstance: |
829 | 832 | """Load a dataclass instance from a dictionary.
|
830 | 833 |
|
831 | 834 | Fields not expected by the dataclass are ignored.
|
|
0 commit comments