Skip to content

Commit 69b4bb1

Browse files
committed
Add type hints to dtypes/dtypes.py (CategoricalDtype)
1 parent 17247ed commit 69b4bb1

File tree

2 files changed

+59
-44
lines changed

2 files changed

+59
-44
lines changed

pandas/core/arrays/categorical.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -429,7 +429,7 @@ def ordered(self):
429429
return self.dtype.ordered
430430

431431
@property
432-
def dtype(self):
432+
def dtype(self) -> 'CategoricalDtype':
433433
"""
434434
The :class:`~pandas.api.types.CategoricalDtype` for this instance
435435
"""

pandas/core/dtypes/dtypes.py

+58 -43
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
""" define extension dtypes """
22
import re
3-
from typing import Any, Dict, Optional, Tuple, Type
3+
import typing
4+
from typing import Any, Dict, List, Optional, Tuple, Type
45
import warnings
56

67
import numpy as np
@@ -15,10 +16,8 @@
1516
from .base import ExtensionDtype, _DtypeOpsMixin
1617
from .inference import is_list_like
1718

18-
str_type = str
1919

20-
21-
def register_extension_dtype(cls):
20+
def register_extension_dtype(cls: 'ExtensionDtype') -> 'ExtensionDtype':
2221
"""
2322
Register an ExtensionType with pandas as class decorator.
2423
@@ -60,20 +59,20 @@ class Registry:
6059
These are tried in order.
6160
"""
6261
def __init__(self):
63-
self.dtypes = []
62+
self.dtypes = [] # type: List[ExtensionDtype]
6463

65-
def register(self, dtype):
64+
def register(self, dtype: 'ExtensionDtype') -> None:
6665
"""
6766
Parameters
6867
----------
6968
dtype : ExtensionDtype
7069
"""
71-
if not issubclass(dtype, (PandasExtensionDtype, ExtensionDtype)):
70+
if not issubclass(dtype, ExtensionDtype):
7271
raise ValueError("can only register pandas extension dtypes")
7372

7473
self.dtypes.append(dtype)
7574

76-
def find(self, dtype):
75+
def find(self, dtype: 'ExtensionDtype') -> Optional[ExtensionDtype]:
7776
"""
7877
Parameters
7978
----------
@@ -117,25 +116,25 @@ class PandasExtensionDtype(_DtypeOpsMixin):
117116
# and ExtensionDtype's @properties in the subclasses below. The kind and
118117
# type variables in those subclasses are explicitly typed below.
119118
subdtype = None
120-
str = None # type: Optional[str_type]
119+
str = None # type: Optional[str]
121120
num = 100
122121
shape = tuple() # type: Tuple[int, ...]
123122
itemsize = 8
124123
base = None
125124
isbuiltin = 0
126125
isnative = 0
127-
_cache = {} # type: Dict[str_type, 'PandasExtensionDtype']
126+
_cache = {} # type: Dict[str, 'PandasExtensionDtype']
128127

129-
def __unicode__(self):
128+
def __unicode__(self) -> str:
130129
return self.name
131130

132-
def __str__(self):
131+
def __str__(self) -> str:
133132
"""
134133
Return a string representation for a particular Object
135134
"""
136135
return self.__unicode__()
137136

138-
def __bytes__(self):
137+
def __bytes__(self) -> bytes:
139138
"""
140139
Return a string representation for a particular object.
141140
"""
@@ -144,22 +143,22 @@ def __bytes__(self):
144143
encoding = get_option("display.encoding")
145144
return self.__unicode__().encode(encoding, 'replace')
146145

147-
def __repr__(self):
146+
def __repr__(self) -> str:
148147
"""
149148
Return a string representation for a particular object.
150149
"""
151150
return str(self)
152151

153-
def __hash__(self):
152+
def __hash__(self) -> int:
154153
raise NotImplementedError("sub-classes should implement an __hash__ "
155154
"method")
156155

157-
def __getstate__(self):
156+
def __getstate__(self) -> Dict[str, Any]:
158157
# pickle support; we don't want to pickle the cache
159158
return {k: getattr(self, k, None) for k in self._metadata}
160159

161160
@classmethod
162-
def reset_cache(cls):
161+
def reset_cache(cls) -> None:
163162
""" clear the cache """
164163
cls._cache = {}
165164

@@ -217,23 +216,30 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
217216
# TODO: Document public vs. private API
218217
name = 'category'
219218
type = CategoricalDtypeType # type: Type[CategoricalDtypeType]
220-
kind = 'O' # type: str_type
219+
kind = 'O' # type: str
221220
str = '|O08'
222221
base = np.dtype('O')
223222
_metadata = ('categories', 'ordered')
224-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
223+
_cache = {} # type: Dict[str, PandasExtensionDtype]
225224

226-
def __init__(self, categories=None, ordered=None):
225+
def __init__(self, categories=None, ordered: bool = None):
227226
self._finalize(categories, ordered, fastpath=False)
228227

229228
@classmethod
230-
def _from_fastpath(cls, categories=None, ordered=None):
229+
def _from_fastpath(cls,
230+
categories=None,
231+
ordered: bool = None
232+
) -> 'CategoricalDtype':
231233
self = cls.__new__(cls)
232234
self._finalize(categories, ordered, fastpath=True)
233235
return self
234236

235237
@classmethod
236-
def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
238+
def _from_categorical_dtype(cls,
239+
dtype: 'CategoricalDtype',
240+
categories=None,
241+
ordered: bool = None,
242+
) -> 'CategoricalDtype':
237243
if categories is ordered is None:
238244
return dtype
239245
if categories is None:
@@ -243,8 +249,12 @@ def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
243249
return cls(categories, ordered)
244250

245251
@classmethod
246-
def _from_values_or_dtype(cls, values=None, categories=None, ordered=None,
247-
dtype=None):
252+
def _from_values_or_dtype(cls,
253+
values=None,
254+
categories=None,
255+
ordered: bool = None,
256+
dtype: 'CategoricalDtype' = None,
257+
) -> 'CategoricalDtype':
248258
"""
249259
Construct dtype from the input parameters used in :class:`Categorical`.
250260
@@ -326,7 +336,11 @@ def _from_values_or_dtype(cls, values=None, categories=None, ordered=None,
326336

327337
return dtype
328338

329-
def _finalize(self, categories, ordered, fastpath=False):
339+
def _finalize(self,
340+
categories,
341+
ordered: Optional[bool],
342+
fastpath: bool = False,
343+
) -> None:
330344

331345
if ordered is not None:
332346
self.validate_ordered(ordered)
@@ -338,14 +352,14 @@ def _finalize(self, categories, ordered, fastpath=False):
338352
self._categories = categories
339353
self._ordered = ordered
340354

341-
def __setstate__(self, state):
355+
def __setstate__(self, state: 'Dict[str, Any]') -> None:
342356
# for pickle compat. __getstate__ is defined in the
343357
# PandasExtensionDtype superclass and uses the public properties to
344358
# pickle -> need to set the settable private ones here (see GH26067)
345359
self._categories = state.pop('categories', None)
346360
self._ordered = state.pop('ordered', False)
347361

348-
def __hash__(self):
362+
def __hash__(self) -> int:
349363
# _hash_categories returns a uint64, so use the negative
350364
# space for when we have unknown categories to avoid a conflict
351365
if self.categories is None:
@@ -356,7 +370,7 @@ def __hash__(self):
356370
# We *do* want to include the real self.ordered here
357371
return int(self._hash_categories(self.categories, self.ordered))
358372

359-
def __eq__(self, other):
373+
def __eq__(self, other: Any) -> bool:
360374
"""
361375
Rules for CDT equality:
362376
1) Any CDT is equal to the string 'category'
@@ -403,7 +417,7 @@ def __repr__(self):
403417
return tpl.format(data, self.ordered)
404418

405419
@staticmethod
406-
def _hash_categories(categories, ordered=True):
420+
def _hash_categories(categories, ordered: bool = True) -> int:
407421
from pandas.core.util.hashing import (
408422
hash_array, _combine_hash_arrays, hash_tuples
409423
)
@@ -453,7 +467,7 @@ def construct_array_type(cls):
453467
return Categorical
454468

455469
@classmethod
456-
def construct_from_string(cls, string):
470+
def construct_from_string(cls, string: str) -> 'CategoricalDtype':
457471
"""
458472
attempt to construct this type from a string, raise a TypeError if
459473
it's not possible """
@@ -466,7 +480,7 @@ def construct_from_string(cls, string):
466480
pass
467481

468482
@staticmethod
469-
def validate_ordered(ordered):
483+
def validate_ordered(ordered: bool) -> None:
470484
"""
471485
Validates that we have a valid ordered parameter. If
472486
it is not a boolean, a TypeError will be raised.
@@ -486,7 +500,7 @@ def validate_ordered(ordered):
486500
raise TypeError("'ordered' must either be 'True' or 'False'")
487501

488502
@staticmethod
489-
def validate_categories(categories, fastpath=False):
503+
def validate_categories(categories, fastpath: bool = False):
490504
"""
491505
Validates that we have good categories
492506
@@ -519,9 +533,9 @@ def validate_categories(categories, fastpath=False):
519533
if isinstance(categories, ABCCategoricalIndex):
520534
categories = categories.categories
521535

522-
return categories
536+
return typing.cast(Index, categories)
523537

524-
def update_dtype(self, dtype):
538+
def update_dtype(self, dtype: 'CategoricalDtype') -> 'CategoricalDtype':
525539
"""
526540
Returns a CategoricalDtype with categories and ordered taken from dtype
527541
if specified, otherwise falling back to self if unspecified
@@ -560,17 +574,18 @@ def categories(self):
560574
"""
561575
An ``Index`` containing the unique categories allowed.
562576
"""
563-
return self._categories
577+
from pandas import Index
578+
return typing.cast(Index, self._categories)
564579

565580
@property
566-
def ordered(self):
581+
def ordered(self) -> bool:
567582
"""
568583
Whether the categories have an ordered relationship.
569584
"""
570585
return self._ordered
571586

572587
@property
573-
def _is_boolean(self):
588+
def _is_boolean(self) -> bool:
574589
from pandas.core.dtypes.common import is_bool_dtype
575590

576591
return is_bool_dtype(self.categories)
@@ -614,14 +629,14 @@ class DatetimeTZDtype(PandasExtensionDtype, ExtensionDtype):
614629
datetime64[ns, tzfile('/usr/share/zoneinfo/US/Central')]
615630
"""
616631
type = Timestamp # type: Type[Timestamp]
617-
kind = 'M' # type: str_type
632+
kind = 'M' # type: str
618633
str = '|M8[ns]'
619634
num = 101
620635
base = np.dtype('M8[ns]')
621636
na_value = NaT
622637
_metadata = ('unit', 'tz')
623638
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
624-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
639+
_cache = {} # type: Dict[str, PandasExtensionDtype]
625640

626641
def __init__(self, unit="ns", tz=None):
627642
if isinstance(unit, DatetimeTZDtype):
@@ -765,13 +780,13 @@ class PeriodDtype(ExtensionDtype, PandasExtensionDtype):
765780
period[M]
766781
"""
767782
type = Period # type: Type[Period]
768-
kind = 'O' # type: str_type
783+
kind = 'O' # type: str
769784
str = '|O08'
770785
base = np.dtype('O')
771786
num = 102
772787
_metadata = ('freq',)
773788
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
774-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
789+
_cache = {} # type: Dict[str, PandasExtensionDtype]
775790

776791
def __new__(cls, freq=None):
777792
"""
@@ -919,13 +934,13 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
919934
interval[int64]
920935
"""
921936
name = 'interval'
922-
kind = None # type: Optional[str_type]
937+
kind = None # type: Optional[str]
923938
str = '|O08'
924939
base = np.dtype('O')
925940
num = 103
926941
_metadata = ('subtype',)
927942
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
928-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
943+
_cache = {} # type: Dict[str, PandasExtensionDtype]
929944

930945
def __new__(cls, subtype=None):
931946
from pandas.core.dtypes.common import (

0 commit comments

Comments
 (0)