Skip to content

Commit 203f483

Browse files
authored
REF: implement Dtype.index_class (#54511)
* REF: implement Dtype.index_class * property->cache_readonly
1 parent 16ccbca commit 203f483

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

pandas/core/dtypes/base.py

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pandas._libs import missing as libmissing
1717
from pandas._libs.hashtable import object_hash
18+
from pandas._libs.properties import cache_readonly
1819
from pandas.errors import AbstractMethodError
1920

2021
from pandas.core.dtypes.generic import (
@@ -32,6 +33,7 @@
3233
type_t,
3334
)
3435

36+
from pandas import Index
3537
from pandas.core.arrays import ExtensionArray
3638

3739
# To parameterize on same ExtensionDtype
@@ -406,6 +408,16 @@ def _is_immutable(self) -> bool:
406408
"""
407409
return False
408410

411+
@cache_readonly
412+
def index_class(self) -> type_t[Index]:
413+
"""
414+
The Index subclass to return from Index.__new__ when this dtype is
415+
encountered.
416+
"""
417+
from pandas import Index
418+
419+
return Index
420+
409421

410422
class StorageExtensionDtype(ExtensionDtype):
411423
"""ExtensionDtype that may be backed by more than one implementation."""

pandas/core/dtypes/dtypes.py

+28
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@
8181

8282
from pandas import (
8383
Categorical,
84+
CategoricalIndex,
85+
DatetimeIndex,
8486
Index,
87+
IntervalIndex,
88+
PeriodIndex,
8589
)
8690
from pandas.core.arrays import (
8791
BaseMaskedArray,
@@ -671,6 +675,12 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
671675

672676
return find_common_type(non_cat_dtypes)
673677

678+
@cache_readonly
679+
def index_class(self) -> type_t[CategoricalIndex]:
680+
from pandas import CategoricalIndex
681+
682+
return CategoricalIndex
683+
674684

675685
@register_extension_dtype
676686
class DatetimeTZDtype(PandasExtensionDtype):
@@ -911,6 +921,12 @@ def __setstate__(self, state) -> None:
911921
self._tz = state["tz"]
912922
self._unit = state["unit"]
913923

924+
@cache_readonly
925+
def index_class(self) -> type_t[DatetimeIndex]:
926+
from pandas import DatetimeIndex
927+
928+
return DatetimeIndex
929+
914930

915931
@register_extension_dtype
916932
class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype):
@@ -1121,6 +1137,12 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> PeriodArray:
11211137
return PeriodArray(np.array([], dtype="int64"), dtype=self, copy=False)
11221138
return PeriodArray._concat_same_type(results)
11231139

1140+
@cache_readonly
1141+
def index_class(self) -> type_t[PeriodIndex]:
1142+
from pandas import PeriodIndex
1143+
1144+
return PeriodIndex
1145+
11241146

11251147
@register_extension_dtype
11261148
class IntervalDtype(PandasExtensionDtype):
@@ -1384,6 +1406,12 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
13841406
return np.dtype(object)
13851407
return IntervalDtype(common, closed=closed)
13861408

1409+
@cache_readonly
1410+
def index_class(self) -> type_t[IntervalIndex]:
1411+
from pandas import IntervalIndex
1412+
1413+
return IntervalIndex
1414+
13871415

13881416
class NumpyEADtype(ExtensionDtype):
13891417
"""

pandas/core/indexes/base.py

+1-18
Original file line numberDiff line numberDiff line change
@@ -600,24 +600,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
600600
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
601601

602602
if isinstance(dtype, ExtensionDtype):
603-
if isinstance(dtype, DatetimeTZDtype):
604-
from pandas import DatetimeIndex
605-
606-
return DatetimeIndex
607-
elif isinstance(dtype, CategoricalDtype):
608-
from pandas import CategoricalIndex
609-
610-
return CategoricalIndex
611-
elif isinstance(dtype, IntervalDtype):
612-
from pandas import IntervalIndex
613-
614-
return IntervalIndex
615-
elif isinstance(dtype, PeriodDtype):
616-
from pandas import PeriodIndex
617-
618-
return PeriodIndex
619-
620-
return Index
603+
return dtype.index_class
621604

622605
if dtype.kind == "M":
623606
from pandas import DatetimeIndex

0 commit comments

Comments
 (0)