Skip to content

Commit 2c32937

Browse files
committed
Add type hints to dtypes/dtypes.py (CategoricalDtype)
1 parent 17247ed commit 2c32937

File tree

2 files changed

+54
-43
lines changed

2 files changed

+54
-43
lines changed

pandas/core/arrays/categorical.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def ordered(self):
429429
return self.dtype.ordered
430430

431431
@property
432-
def dtype(self):
432+
def dtype(self) -> 'CategoricalDtype':
433433
"""
434434
The :class:`~pandas.api.types.CategoricalDtype` for this instance
435435
"""

pandas/core/dtypes/dtypes.py

+53-42
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
""" define extension dtypes """
22
import re
3-
from typing import Any, Dict, Optional, Tuple, Type
3+
import typing
4+
from typing import Any, Dict, List, Optional, Tuple, Type
45
import warnings
56

67
import numpy as np
@@ -15,10 +16,8 @@
1516
from .base import ExtensionDtype, _DtypeOpsMixin
1617
from .inference import is_list_like
1718

18-
str_type = str
1919

20-
21-
def register_extension_dtype(cls):
20+
def register_extension_dtype(cls: 'ExtensionDtype') -> 'ExtensionDtype':
2221
"""
2322
Register an ExtensionType with pandas as class decorator.
2423
@@ -60,20 +59,20 @@ class Registry:
6059
These are tried in order.
6160
"""
6261
def __init__(self):
63-
self.dtypes = []
62+
self.dtypes = [] # type: List[ExtensionDtype]
6463

65-
def register(self, dtype):
64+
def register(self, dtype: 'ExtensionDtype') -> None:
6665
"""
6766
Parameters
6867
----------
6968
dtype : ExtensionDtype
7069
"""
71-
if not issubclass(dtype, (PandasExtensionDtype, ExtensionDtype)):
70+
if not issubclass(dtype, ExtensionDtype):
7271
raise ValueError("can only register pandas extension dtypes")
7372

7473
self.dtypes.append(dtype)
7574

76-
def find(self, dtype):
75+
def find(self, dtype: 'ExtensionDtype') -> Optional[ExtensionDtype]:
7776
"""
7877
Parameters
7978
----------
@@ -117,25 +116,25 @@ class PandasExtensionDtype(_DtypeOpsMixin):
117116
# and ExtensionDtype's @properties in the subclasses below. The kind and
118117
# type variables in those subclasses are explicitly typed below.
119118
subdtype = None
120-
str = None # type: Optional[str_type]
119+
str = None # type: Optional[str]
121120
num = 100
122121
shape = tuple() # type: Tuple[int, ...]
123122
itemsize = 8
124123
base = None
125124
isbuiltin = 0
126125
isnative = 0
127-
_cache = {} # type: Dict[str_type, 'PandasExtensionDtype']
126+
_cache = {} # type: Dict[str, 'PandasExtensionDtype']
128127

129-
def __unicode__(self):
128+
def __unicode__(self) -> str:
130129
return self.name
131130

132-
def __str__(self):
131+
def __str__(self) -> str:
133132
"""
134133
Return a string representation for a particular Object
135134
"""
136135
return self.__unicode__()
137136

138-
def __bytes__(self):
137+
def __bytes__(self) -> bytes:
139138
"""
140139
Return a string representation for a particular object.
141140
"""
@@ -144,22 +143,22 @@ def __bytes__(self):
144143
encoding = get_option("display.encoding")
145144
return self.__unicode__().encode(encoding, 'replace')
146145

147-
def __repr__(self):
146+
def __repr__(self) -> str:
148147
"""
149148
Return a string representation for a particular object.
150149
"""
151150
return str(self)
152151

153-
def __hash__(self):
152+
def __hash__(self) -> int:
154153
raise NotImplementedError("sub-classes should implement an __hash__ "
155154
"method")
156155

157-
def __getstate__(self):
156+
def __getstate__(self) -> Dict[str, Any]:
158157
# pickle support; we don't want to pickle the cache
159158
return {k: getattr(self, k, None) for k in self._metadata}
160159

161160
@classmethod
162-
def reset_cache(cls):
161+
def reset_cache(cls) -> None:
163162
""" clear the cache """
164163
cls._cache = {}
165164

@@ -217,23 +216,27 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
217216
# TODO: Document public vs. private API
218217
name = 'category'
219218
type = CategoricalDtypeType # type: Type[CategoricalDtypeType]
220-
kind = 'O' # type: str_type
219+
kind = 'O' # type: str
221220
str = '|O08'
222221
base = np.dtype('O')
223222
_metadata = ('categories', 'ordered')
224-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
223+
_cache = {} # type: Dict[str, PandasExtensionDtype]
225224

226-
def __init__(self, categories=None, ordered=None):
225+
def __init__(self, categories=None, ordered: bool = None):
227226
self._finalize(categories, ordered, fastpath=False)
228227

229228
@classmethod
230-
def _from_fastpath(cls, categories=None, ordered=None):
229+
def _from_fastpath(cls, categories=None, ordered: bool = None):
231230
self = cls.__new__(cls)
232231
self._finalize(categories, ordered, fastpath=True)
233232
return self
234233

235234
@classmethod
236-
def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
235+
def _from_categorical_dtype(cls,
236+
dtype: 'CategoricalDtype',
237+
categories=None,
238+
ordered: bool = None,
239+
) -> 'CategoricalDtype':
237240
if categories is ordered is None:
238241
return dtype
239242
if categories is None:
@@ -243,8 +246,11 @@ def _from_categorical_dtype(cls, dtype, categories=None, ordered=None):
243246
return cls(categories, ordered)
244247

245248
@classmethod
246-
def _from_values_or_dtype(cls, values=None, categories=None, ordered=None,
247-
dtype=None):
249+
def _from_values_or_dtype(cls,
250+
values=None,
251+
categories=None,
252+
ordered: bool = None,
253+
dtype: 'CategoricalDtype' = None):
248254
"""
249255
Construct dtype from the input parameters used in :class:`Categorical`.
250256
@@ -326,7 +332,11 @@ def _from_values_or_dtype(cls, values=None, categories=None, ordered=None,
326332

327333
return dtype
328334

329-
def _finalize(self, categories, ordered, fastpath=False):
335+
def _finalize(self,
336+
categories,
337+
ordered: Optional[bool],
338+
fastpath: bool = False,
339+
) -> None:
330340

331341
if ordered is not None:
332342
self.validate_ordered(ordered)
@@ -338,14 +348,14 @@ def _finalize(self, categories, ordered, fastpath=False):
338348
self._categories = categories
339349
self._ordered = ordered
340350

341-
def __setstate__(self, state):
351+
def __setstate__(self, state: 'Dict[str, Any]') -> None:
342352
# for pickle compat. __get_state__ is defined in the
343353
# PandasExtensionDtype superclass and uses the public properties to
344354
# pickle -> need to set the settable private ones here (see GH26067)
345355
self._categories = state.pop('categories', None)
346356
self._ordered = state.pop('ordered', False)
347357

348-
def __hash__(self):
358+
def __hash__(self) -> int:
349359
# _hash_categories returns a uint64, so use the negative
350360
# space for when we have unknown categories to avoid a conflict
351361
if self.categories is None:
@@ -356,7 +366,7 @@ def __hash__(self):
356366
# We *do* want to include the real self.ordered here
357367
return int(self._hash_categories(self.categories, self.ordered))
358368

359-
def __eq__(self, other):
369+
def __eq__(self, other: Any) -> bool:
360370
"""
361371
Rules for CDT equality:
362372
1) Any CDT is equal to the string 'category'
@@ -403,7 +413,7 @@ def __repr__(self):
403413
return tpl.format(data, self.ordered)
404414

405415
@staticmethod
406-
def _hash_categories(categories, ordered=True):
416+
def _hash_categories(categories, ordered: bool = True) -> int:
407417
from pandas.core.util.hashing import (
408418
hash_array, _combine_hash_arrays, hash_tuples
409419
)
@@ -453,7 +463,7 @@ def construct_array_type(cls):
453463
return Categorical
454464

455465
@classmethod
456-
def construct_from_string(cls, string):
466+
def construct_from_string(cls, string: str) -> 'CategoricalDtype':
457467
"""
458468
attempt to construct this type from a string, raise a TypeError if
459469
it's not possible """
@@ -466,7 +476,7 @@ def construct_from_string(cls, string):
466476
pass
467477

468478
@staticmethod
469-
def validate_ordered(ordered):
479+
def validate_ordered(ordered: bool) -> None:
470480
"""
471481
Validates that we have a valid ordered parameter. If
472482
it is not a boolean, a TypeError will be raised.
@@ -486,7 +496,7 @@ def validate_ordered(ordered):
486496
raise TypeError("'ordered' must either be 'True' or 'False'")
487497

488498
@staticmethod
489-
def validate_categories(categories, fastpath=False):
499+
def validate_categories(categories, fastpath: bool = False):
490500
"""
491501
Validates that we have good categories
492502
@@ -521,7 +531,7 @@ def validate_categories(categories, fastpath=False):
521531

522532
return categories
523533

524-
def update_dtype(self, dtype):
534+
def update_dtype(self, dtype: 'CategoricalDtype') -> 'CategoricalDtype':
525535
"""
526536
Returns a CategoricalDtype with categories and ordered taken from dtype
527537
if specified, otherwise falling back to self if unspecified
@@ -560,17 +570,18 @@ def categories(self):
560570
"""
561571
An ``Index`` containing the unique categories allowed.
562572
"""
563-
return self._categories
573+
from pandas import Index
574+
return typing.cast(Index, self._categories)
564575

565576
@property
566-
def ordered(self):
577+
def ordered(self) -> bool:
567578
"""
568579
Whether the categories have an ordered relationship.
569580
"""
570581
return self._ordered
571582

572583
@property
573-
def _is_boolean(self):
584+
def _is_boolean(self) -> bool:
574585
from pandas.core.dtypes.common import is_bool_dtype
575586

576587
return is_bool_dtype(self.categories)
@@ -614,14 +625,14 @@ class DatetimeTZDtype(PandasExtensionDtype, ExtensionDtype):
614625
datetime64[ns, tzfile('/usr/share/zoneinfo/US/Central')]
615626
"""
616627
type = Timestamp # type: Type[Timestamp]
617-
kind = 'M' # type: str_type
628+
kind = 'M' # type: str
618629
str = '|M8[ns]'
619630
num = 101
620631
base = np.dtype('M8[ns]')
621632
na_value = NaT
622633
_metadata = ('unit', 'tz')
623634
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
624-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
635+
_cache = {} # type: Dict[str, PandasExtensionDtype]
625636

626637
def __init__(self, unit="ns", tz=None):
627638
if isinstance(unit, DatetimeTZDtype):
@@ -765,13 +776,13 @@ class PeriodDtype(ExtensionDtype, PandasExtensionDtype):
765776
period[M]
766777
"""
767778
type = Period # type: Type[Period]
768-
kind = 'O' # type: str_type
779+
kind = 'O' # type: str
769780
str = '|O08'
770781
base = np.dtype('O')
771782
num = 102
772783
_metadata = ('freq',)
773784
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
774-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
785+
_cache = {} # type: Dict[str, PandasExtensionDtype]
775786

776787
def __new__(cls, freq=None):
777788
"""
@@ -919,13 +930,13 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
919930
interval[int64]
920931
"""
921932
name = 'interval'
922-
kind = None # type: Optional[str_type]
933+
kind = None # type: Optional[str]
923934
str = '|O08'
924935
base = np.dtype('O')
925936
num = 103
926937
_metadata = ('subtype',)
927938
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
928-
_cache = {} # type: Dict[str_type, PandasExtensionDtype]
939+
_cache = {} # type: Dict[str, PandasExtensionDtype]
929940

930941
def __new__(cls, subtype=None):
931942
from pandas.core.dtypes.common import (

0 commit comments

Comments
 (0)