Skip to content

Commit 97885b4

Browse files
jbrockmendelmroeschke
authored andcommitted
REF: implement Dtype.index_class (pandas-dev#54511)
* REF: implement Dtype.index_class * property->cache_readonly
1 parent 298d439 commit 97885b4

File tree

3 files changed

+41
-18
lines changed

3 files changed

+41
-18
lines changed

pandas/core/dtypes/base.py

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from pandas._libs import missing as libmissing
1717
from pandas._libs.hashtable import object_hash
18+
from pandas._libs.properties import cache_readonly
1819
from pandas.errors import AbstractMethodError
1920

2021
from pandas.core.dtypes.generic import (
@@ -32,6 +33,7 @@
3233
type_t,
3334
)
3435

36+
from pandas import Index
3537
from pandas.core.arrays import ExtensionArray
3638

3739
# To parameterize on same ExtensionDtype
@@ -406,6 +408,16 @@ def _is_immutable(self) -> bool:
406408
"""
407409
return False
408410

411+
@cache_readonly
412+
def index_class(self) -> type_t[Index]:
413+
"""
414+
The Index subclass to return from Index.__new__ when this dtype is
415+
encountered.
416+
"""
417+
from pandas import Index
418+
419+
return Index
420+
409421

410422
class StorageExtensionDtype(ExtensionDtype):
411423
"""ExtensionDtype that may be backed by more than one implementation."""

pandas/core/dtypes/dtypes.py

+28
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,11 @@
8181

8282
from pandas import (
8383
Categorical,
84+
CategoricalIndex,
85+
DatetimeIndex,
8486
Index,
87+
IntervalIndex,
88+
PeriodIndex,
8589
)
8690
from pandas.core.arrays import (
8791
BaseMaskedArray,
@@ -671,6 +675,12 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
671675

672676
return find_common_type(non_cat_dtypes)
673677

678+
@cache_readonly
679+
def index_class(self) -> type_t[CategoricalIndex]:
680+
from pandas import CategoricalIndex
681+
682+
return CategoricalIndex
683+
674684

675685
@register_extension_dtype
676686
class DatetimeTZDtype(PandasExtensionDtype):
@@ -911,6 +921,12 @@ def __setstate__(self, state) -> None:
911921
self._tz = state["tz"]
912922
self._unit = state["unit"]
913923

924+
@cache_readonly
925+
def index_class(self) -> type_t[DatetimeIndex]:
926+
from pandas import DatetimeIndex
927+
928+
return DatetimeIndex
929+
914930

915931
@register_extension_dtype
916932
class PeriodDtype(PeriodDtypeBase, PandasExtensionDtype):
@@ -1121,6 +1137,12 @@ def __from_arrow__(self, array: pa.Array | pa.ChunkedArray) -> PeriodArray:
11211137
return PeriodArray(np.array([], dtype="int64"), dtype=self, copy=False)
11221138
return PeriodArray._concat_same_type(results)
11231139

1140+
@cache_readonly
1141+
def index_class(self) -> type_t[PeriodIndex]:
1142+
from pandas import PeriodIndex
1143+
1144+
return PeriodIndex
1145+
11241146

11251147
@register_extension_dtype
11261148
class IntervalDtype(PandasExtensionDtype):
@@ -1384,6 +1406,12 @@ def _get_common_dtype(self, dtypes: list[DtypeObj]) -> DtypeObj | None:
13841406
return np.dtype(object)
13851407
return IntervalDtype(common, closed=closed)
13861408

1409+
@cache_readonly
1410+
def index_class(self) -> type_t[IntervalIndex]:
1411+
from pandas import IntervalIndex
1412+
1413+
return IntervalIndex
1414+
13871415

13881416
class NumpyEADtype(ExtensionDtype):
13891417
"""

pandas/core/indexes/base.py

+1-18
Original file line numberDiff line numberDiff line change
@@ -599,24 +599,7 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
599599
# Delay import for perf. https://github.com/pandas-dev/pandas/pull/31423
600600

601601
if isinstance(dtype, ExtensionDtype):
602-
if isinstance(dtype, DatetimeTZDtype):
603-
from pandas import DatetimeIndex
604-
605-
return DatetimeIndex
606-
elif isinstance(dtype, CategoricalDtype):
607-
from pandas import CategoricalIndex
608-
609-
return CategoricalIndex
610-
elif isinstance(dtype, IntervalDtype):
611-
from pandas import IntervalIndex
612-
613-
return IntervalIndex
614-
elif isinstance(dtype, PeriodDtype):
615-
from pandas import PeriodIndex
616-
617-
return PeriodIndex
618-
619-
return Index
602+
return dtype.index_class
620603

621604
if dtype.kind == "M":
622605
from pandas import DatetimeIndex

0 commit comments

Comments
 (0)