Skip to content

Commit 8a58303

Browse files
REF: Changed ExtensionDtype inheritance (#20363)
* REF: Changed ExtensionDtype inheritance `is_extension_array_dtype(dtype)` was incorrect for dtypes that haven't implemented the new interface yet. This is because they indirectly subclassed ExtensionDtype. This PR changes the hierarchy so that PandasExtensionDtype doesn't subclass ExtensionDtype. As we implement the interface, like Categorical, we'll add ExtensionDtype as a base class. Before: ``` DatetimeTZDtype <- PandasExtensionDtype <- ExtensionDtype (wrong) CategoricalDtype <- PandasExtensionDtype <- ExtensionDtype (right) After: DatetimeTZDtype <- PandasExtensionDtype \ - _DtypeOpsMixin / ExtensionDtype ------ CategoricalDtype - PandasExtensionDtype - \ \ \ -_DtypeOpsMixin \ / ExtensionDtype ------- ``` Once all our extension dtypes have implemented the interface we can go back to the simple, linear inheritance structure.
1 parent 0368927 commit 8a58303

File tree

6 files changed

+111
-77
lines changed

6 files changed

+111
-77
lines changed

pandas/core/dtypes/base.py

+78-66
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,16 @@
55
from pandas.errors import AbstractMethodError
66

77

8-
class ExtensionDtype(object):
9-
"""A custom data type, to be paired with an ExtensionArray.
10-
11-
Notes
12-
-----
13-
The interface includes the following abstract methods that must
14-
be implemented by subclasses:
15-
16-
* type
17-
* name
18-
* construct_from_string
19-
20-
This class does not inherit from 'abc.ABCMeta' for performance reasons.
21-
Methods and properties required by the interface raise
22-
``pandas.errors.AbstractMethodError`` and no ``register`` method is
23-
provided for registering virtual subclasses.
24-
"""
25-
26-
def __str__(self):
27-
return self.name
8+
class _DtypeOpsMixin(object):
9+
# Not all of pandas' extension dtypes are compatibile with
10+
# the new ExtensionArray interface. This means PandasExtensionDtype
11+
# can't subclass ExtensionDtype yet, as is_extension_array_dtype would
12+
# incorrectly say that these types are extension types.
13+
#
14+
# In the interim, we put methods that are shared between the two base
15+
# classes ExtensionDtype and PandasExtensionDtype here. Both those base
16+
# classes will inherit from this Mixin. Once everything is compatible, this
17+
# class's methods can be moved to ExtensionDtype and removed.
2818

2919
def __eq__(self, other):
3020
"""Check whether 'other' is equal to self.
@@ -52,6 +42,74 @@ def __eq__(self, other):
5242
def __ne__(self, other):
5343
return not self.__eq__(other)
5444

45+
@property
46+
def names(self):
47+
# type: () -> Optional[List[str]]
48+
"""Ordered list of field names, or None if there are no fields.
49+
50+
This is for compatibility with NumPy arrays, and may be removed in the
51+
future.
52+
"""
53+
return None
54+
55+
@classmethod
56+
def is_dtype(cls, dtype):
57+
"""Check if we match 'dtype'.
58+
59+
Parameters
60+
----------
61+
dtype : object
62+
The object to check.
63+
64+
Returns
65+
-------
66+
is_dtype : bool
67+
68+
Notes
69+
-----
70+
The default implementation is True if
71+
72+
1. ``cls.construct_from_string(dtype)`` is an instance
73+
of ``cls``.
74+
2. ``dtype`` is an object and is an instance of ``cls``
75+
3. ``dtype`` has a ``dtype`` attribute, and any of the above
76+
conditions is true for ``dtype.dtype``.
77+
"""
78+
dtype = getattr(dtype, 'dtype', dtype)
79+
80+
if isinstance(dtype, np.dtype):
81+
return False
82+
elif dtype is None:
83+
return False
84+
elif isinstance(dtype, cls):
85+
return True
86+
try:
87+
return cls.construct_from_string(dtype) is not None
88+
except TypeError:
89+
return False
90+
91+
92+
class ExtensionDtype(_DtypeOpsMixin):
93+
"""A custom data type, to be paired with an ExtensionArray.
94+
95+
Notes
96+
-----
97+
The interface includes the following abstract methods that must
98+
be implemented by subclasses:
99+
100+
* type
101+
* name
102+
* construct_from_string
103+
104+
This class does not inherit from 'abc.ABCMeta' for performance reasons.
105+
Methods and properties required by the interface raise
106+
``pandas.errors.AbstractMethodError`` and no ``register`` method is
107+
provided for registering virtual subclasses.
108+
"""
109+
110+
def __str__(self):
111+
return self.name
112+
55113
@property
56114
def type(self):
57115
# type: () -> type
@@ -87,16 +145,6 @@ def name(self):
87145
"""
88146
raise AbstractMethodError(self)
89147

90-
@property
91-
def names(self):
92-
# type: () -> Optional[List[str]]
93-
"""Ordered list of field names, or None if there are no fields.
94-
95-
This is for compatibility with NumPy arrays, and may be removed in the
96-
future.
97-
"""
98-
return None
99-
100148
@classmethod
101149
def construct_from_string(cls, string):
102150
"""Attempt to construct this type from a string.
@@ -128,39 +176,3 @@ def construct_from_string(cls, string):
128176
... "'{}'".format(cls, string))
129177
"""
130178
raise AbstractMethodError(cls)
131-
132-
@classmethod
133-
def is_dtype(cls, dtype):
134-
"""Check if we match 'dtype'.
135-
136-
Parameters
137-
----------
138-
dtype : object
139-
The object to check.
140-
141-
Returns
142-
-------
143-
is_dtype : bool
144-
145-
Notes
146-
-----
147-
The default implementation is True if
148-
149-
1. ``cls.construct_from_string(dtype)`` is an instance
150-
of ``cls``.
151-
2. ``dtype`` is an object and is an instance of ``cls``
152-
3. ``dtype`` has a ``dtype`` attribute, and any of the above
153-
conditions is true for ``dtype.dtype``.
154-
"""
155-
dtype = getattr(dtype, 'dtype', dtype)
156-
157-
if isinstance(dtype, np.dtype):
158-
return False
159-
elif dtype is None:
160-
return False
161-
elif isinstance(dtype, cls):
162-
return True
163-
try:
164-
return cls.construct_from_string(dtype) is not None
165-
except TypeError:
166-
return False

pandas/core/dtypes/cast.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
_ensure_int32, _ensure_int64,
2727
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
2828
_POSSIBLY_CAST_DTYPES)
29-
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
29+
from .dtypes import (ExtensionDtype, PandasExtensionDtype, DatetimeTZDtype,
30+
PeriodDtype)
3031
from .generic import (ABCDatetimeIndex, ABCPeriodIndex,
3132
ABCSeries)
3233
from .missing import isna, notna
@@ -1114,7 +1115,8 @@ def find_common_type(types):
11141115
if all(is_dtype_equal(first, t) for t in types[1:]):
11151116
return first
11161117

1117-
if any(isinstance(t, ExtensionDtype) for t in types):
1118+
if any(isinstance(t, (PandasExtensionDtype, ExtensionDtype))
1119+
for t in types):
11181120
return np.object
11191121

11201122
# take lowest unit

pandas/core/dtypes/common.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
DatetimeTZDtype, DatetimeTZDtypeType,
1010
PeriodDtype, PeriodDtypeType,
1111
IntervalDtype, IntervalDtypeType,
12-
ExtensionDtype)
12+
ExtensionDtype, PandasExtensionDtype)
1313
from .generic import (ABCCategorical, ABCPeriodIndex,
1414
ABCDatetimeIndex, ABCSeries,
1515
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex,
@@ -2006,7 +2006,7 @@ def pandas_dtype(dtype):
20062006
return CategoricalDtype.construct_from_string(dtype)
20072007
except TypeError:
20082008
pass
2009-
elif isinstance(dtype, ExtensionDtype):
2009+
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
20102010
return dtype
20112011

20122012
try:

pandas/core/dtypes/dtypes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
from pandas import compat
66
from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex
77

8-
from .base import ExtensionDtype
8+
from .base import ExtensionDtype, _DtypeOpsMixin
99

1010

11-
class PandasExtensionDtype(ExtensionDtype):
11+
class PandasExtensionDtype(_DtypeOpsMixin):
1212
"""
1313
A np.dtype duck-typed class, suitable for holding a custom dtype.
1414
@@ -83,7 +83,7 @@ class CategoricalDtypeType(type):
8383
pass
8484

8585

86-
class CategoricalDtype(PandasExtensionDtype):
86+
class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
8787
"""
8888
Type for categorical data with the categories and orderedness
8989

pandas/core/internals.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from pandas.core.dtypes.dtypes import (
1919
ExtensionDtype, DatetimeTZDtype,
20+
PandasExtensionDtype,
2021
CategoricalDtype)
2122
from pandas.core.dtypes.common import (
2223
_TD_DTYPE, _NS_DTYPE,
@@ -598,7 +599,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
598599
list(errors_legal_values), errors))
599600
raise ValueError(invalid_arg)
600601

601-
if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
602+
if (inspect.isclass(dtype) and
603+
issubclass(dtype, (PandasExtensionDtype, ExtensionDtype))):
602604
msg = ("Expected an instance of {}, but got the class instead. "
603605
"Try instantiating 'dtype'.".format(dtype.__name__))
604606
raise TypeError(msg)
@@ -5005,7 +5007,7 @@ def _interleaved_dtype(blocks):
50055007
dtype = find_common_type([b.dtype for b in blocks])
50065008

50075009
# only numpy compat
5008-
if isinstance(dtype, ExtensionDtype):
5010+
if isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
50095011
dtype = np.object
50105012

50115013
return dtype

pandas/tests/extension/test_common.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@
55
import pandas.util.testing as tm
66
from pandas.core.arrays import ExtensionArray
77
from pandas.core.dtypes.common import is_extension_array_dtype
8-
from pandas.core.dtypes.dtypes import ExtensionDtype
8+
from pandas.core.dtypes import dtypes
99

1010

11-
class DummyDtype(ExtensionDtype):
11+
class DummyDtype(dtypes.ExtensionDtype):
1212
pass
1313

1414

@@ -65,3 +65,21 @@ def test_astype_no_copy():
6565

6666
result = arr.astype(arr.dtype)
6767
assert arr.data is not result
68+
69+
70+
@pytest.mark.parametrize('dtype', [
71+
dtypes.DatetimeTZDtype('ns', 'US/Central'),
72+
dtypes.PeriodDtype("D"),
73+
dtypes.IntervalDtype(),
74+
])
75+
def test_is_not_extension_array_dtype(dtype):
76+
assert not isinstance(dtype, dtypes.ExtensionDtype)
77+
assert not is_extension_array_dtype(dtype)
78+
79+
80+
@pytest.mark.parametrize('dtype', [
81+
dtypes.CategoricalDtype(),
82+
])
83+
def test_is_extension_array_dtype(dtype):
84+
assert isinstance(dtype, dtypes.ExtensionDtype)
85+
assert is_extension_array_dtype(dtype)

0 commit comments

Comments
 (0)