Skip to content

Commit cae328e

Browse files
TomAugspurger authored and
dberenbaum committed
REF/API: Stricter extension checking. (pandas-dev#22031)
1 parent 7e94b7e commit cae328e

File tree

3 files changed

+32
-26
lines changed

3 files changed

+32
-26
lines changed

pandas/core/dtypes/common.py

+7 -14
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99
from pandas.core.dtypes.dtypes import (
1010
registry, CategoricalDtype, CategoricalDtypeType, DatetimeTZDtype,
1111
DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype,
12-
IntervalDtypeType, ExtensionDtype)
12+
IntervalDtypeType, PandasExtensionDtype, ExtensionDtype,
13+
_pandas_registry)
1314
from pandas.core.dtypes.generic import (
1415
ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries,
1516
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, ABCIndexClass,
@@ -1709,17 +1710,9 @@ def is_extension_array_dtype(arr_or_dtype):
17091710
Third-party libraries may implement arrays or types satisfying
17101711
this interface as well.
17111712
"""
1712-
from pandas.core.arrays import ExtensionArray
1713-
1714-
if isinstance(arr_or_dtype, (ABCIndexClass, ABCSeries)):
1715-
arr_or_dtype = arr_or_dtype._values
1716-
1717-
try:
1718-
arr_or_dtype = pandas_dtype(arr_or_dtype)
1719-
except TypeError:
1720-
pass
1721-
1722-
return isinstance(arr_or_dtype, (ExtensionDtype, ExtensionArray))
1713+
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
1714+
return (isinstance(dtype, ExtensionDtype) or
1715+
registry.find(dtype) is not None)
17231716

17241717

17251718
def is_complex_dtype(arr_or_dtype):
@@ -1999,12 +1992,12 @@ def pandas_dtype(dtype):
19991992
return dtype
20001993

20011994
# registered extension types
2002-
result = registry.find(dtype)
1995+
result = _pandas_registry.find(dtype) or registry.find(dtype)
20031996
if result is not None:
20041997
return result
20051998

20061999
# un-registered extension types
2007-
elif isinstance(dtype, ExtensionDtype):
2000+
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
20082001
return dtype
20092002

20102003
# try a numpy dtype

pandas/core/dtypes/dtypes.py

+8 -5
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,9 @@ class Registry(object):
2222
--------
2323
registry.register(MyExtensionDtype)
2424
"""
25-
dtypes = []
25+
def __init__(self):
26+
self.dtypes = []
2627

27-
@classmethod
2828
def register(self, dtype):
2929
"""
3030
Parameters
@@ -50,7 +50,7 @@ def find(self, dtype):
5050
dtype_type = dtype
5151
if not isinstance(dtype, type):
5252
dtype_type = type(dtype)
53-
if issubclass(dtype_type, (PandasExtensionDtype, ExtensionDtype)):
53+
if issubclass(dtype_type, ExtensionDtype):
5454
return dtype
5555

5656
return None
@@ -65,6 +65,9 @@ def find(self, dtype):
6565

6666

6767
registry = Registry()
68+
# TODO(Extension): remove the second registry once all internal extension
69+
# dtypes are real extension dtypes.
70+
_pandas_registry = Registry()
6871

6972

7073
class PandasExtensionDtype(_DtypeOpsMixin):
@@ -822,7 +825,7 @@ def is_dtype(cls, dtype):
822825

823826

824827
# register the dtypes in search order
825-
registry.register(DatetimeTZDtype)
826-
registry.register(PeriodDtype)
827828
registry.register(IntervalDtype)
828829
registry.register(CategoricalDtype)
830+
_pandas_registry.register(DatetimeTZDtype)
831+
_pandas_registry.register(PeriodDtype)

pandas/tests/dtypes/test_dtypes.py

+17 -7
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99

1010
from pandas.core.dtypes.dtypes import (
1111
DatetimeTZDtype, PeriodDtype,
12-
IntervalDtype, CategoricalDtype, registry)
12+
IntervalDtype, CategoricalDtype, registry, _pandas_registry)
1313
from pandas.core.dtypes.common import (
1414
is_categorical_dtype, is_categorical,
1515
is_datetime64tz_dtype, is_datetimetz,
@@ -775,21 +775,31 @@ def test_update_dtype_errors(self, bad_dtype):
775775

776776
@pytest.mark.parametrize(
777777
'dtype',
778-
[DatetimeTZDtype, CategoricalDtype,
779-
PeriodDtype, IntervalDtype])
778+
[CategoricalDtype, IntervalDtype])
780779
def test_registry(dtype):
781780
assert dtype in registry.dtypes
782781

783782

783+
@pytest.mark.parametrize('dtype', [DatetimeTZDtype, PeriodDtype])
784+
def test_pandas_registry(dtype):
785+
assert dtype not in registry.dtypes
786+
assert dtype in _pandas_registry.dtypes
787+
788+
784789
@pytest.mark.parametrize(
785790
'dtype, expected',
786791
[('int64', None),
787792
('interval', IntervalDtype()),
788793
('interval[int64]', IntervalDtype()),
789794
('interval[datetime64[ns]]', IntervalDtype('datetime64[ns]')),
790-
('category', CategoricalDtype()),
791-
('period[D]', PeriodDtype('D')),
792-
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
795+
('category', CategoricalDtype())])
793796
def test_registry_find(dtype, expected):
794-
795797
assert registry.find(dtype) == expected
798+
799+
800+
@pytest.mark.parametrize(
801+
'dtype, expected',
802+
[('period[D]', PeriodDtype('D')),
803+
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
804+
def test_pandas_registry_find(dtype, expected):
805+
assert _pandas_registry.find(dtype) == expected

0 commit comments

Comments
 (0)