Skip to content

REF/API: Stricter extension checking. #22031

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
21 changes: 7 additions & 14 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from pandas.core.dtypes.dtypes import (
registry, CategoricalDtype, CategoricalDtypeType, DatetimeTZDtype,
DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype,
IntervalDtypeType, ExtensionDtype)
IntervalDtypeType, PandasExtensionDtype, ExtensionDtype,
_pandas_registry)
from pandas.core.dtypes.generic import (
ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, ABCIndexClass,
Expand Down Expand Up @@ -1709,17 +1710,9 @@ def is_extension_array_dtype(arr_or_dtype):
Third-party libraries may implement arrays or types satisfying
this interface as well.
"""
from pandas.core.arrays import ExtensionArray

if isinstance(arr_or_dtype, (ABCIndexClass, ABCSeries)):
arr_or_dtype = arr_or_dtype._values

try:
arr_or_dtype = pandas_dtype(arr_or_dtype)
except TypeError:
pass

return isinstance(arr_or_dtype, (ExtensionDtype, ExtensionArray))
dtype = getattr(arr_or_dtype, 'dtype', arr_or_dtype)
return (isinstance(dtype, ExtensionDtype) or
registry.find(dtype) is not None)


def is_complex_dtype(arr_or_dtype):
Expand Down Expand Up @@ -1999,12 +1992,12 @@ def pandas_dtype(dtype):
return dtype

# registered extension types
result = registry.find(dtype)
result = _pandas_registry.find(dtype) or registry.find(dtype)
if result is not None:
return result

# un-registered extension types
elif isinstance(dtype, ExtensionDtype):
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
return dtype

# try a numpy dtype
Expand Down
13 changes: 8 additions & 5 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ class Registry(object):
--------
registry.register(MyExtensionDtype)
"""
dtypes = []
def __init__(self):
self.dtypes = []

@classmethod
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This wasn't actually a classmethod.

def register(self, dtype):
"""
Parameters
Expand All @@ -50,7 +50,7 @@ def find(self, dtype):
dtype_type = dtype
if not isinstance(dtype, type):
dtype_type = type(dtype)
if issubclass(dtype_type, (PandasExtensionDtype, ExtensionDtype)):
if issubclass(dtype_type, ExtensionDtype):
return dtype

return None
Expand All @@ -65,6 +65,9 @@ def find(self, dtype):


registry = Registry()
# TODO(Extension): remove the second registry once all internal extension
# dtypes are real extension dtypes.
_pandas_registry = Registry()


class PandasExtensionDtype(_DtypeOpsMixin):
Expand Down Expand Up @@ -822,7 +825,7 @@ def is_dtype(cls, dtype):


# register the dtypes in search order
registry.register(DatetimeTZDtype)
registry.register(PeriodDtype)
registry.register(IntervalDtype)
registry.register(CategoricalDtype)
_pandas_registry.register(DatetimeTZDtype)
_pandas_registry.register(PeriodDtype)
24 changes: 17 additions & 7 deletions pandas/tests/dtypes/test_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from pandas.core.dtypes.dtypes import (
DatetimeTZDtype, PeriodDtype,
IntervalDtype, CategoricalDtype, registry)
IntervalDtype, CategoricalDtype, registry, _pandas_registry)
from pandas.core.dtypes.common import (
is_categorical_dtype, is_categorical,
is_datetime64tz_dtype, is_datetimetz,
Expand Down Expand Up @@ -775,21 +775,31 @@ def test_update_dtype_errors(self, bad_dtype):

@pytest.mark.parametrize(
'dtype',
[DatetimeTZDtype, CategoricalDtype,
PeriodDtype, IntervalDtype])
[CategoricalDtype, IntervalDtype])
def test_registry(dtype):
assert dtype in registry.dtypes


@pytest.mark.parametrize('dtype', [DatetimeTZDtype, PeriodDtype])
def test_pandas_registry(dtype):
assert dtype not in registry.dtypes
assert dtype in _pandas_registry.dtypes


@pytest.mark.parametrize(
'dtype, expected',
[('int64', None),
('interval', IntervalDtype()),
('interval[int64]', IntervalDtype()),
('interval[datetime64[ns]]', IntervalDtype('datetime64[ns]')),
('category', CategoricalDtype()),
('period[D]', PeriodDtype('D')),
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
('category', CategoricalDtype())])
def test_registry_find(dtype, expected):

assert registry.find(dtype) == expected


@pytest.mark.parametrize(
'dtype, expected',
[('period[D]', PeriodDtype('D')),
('datetime64[ns, US/Eastern]', DatetimeTZDtype('ns', 'US/Eastern'))])
def test_pandas_registry_find(dtype, expected):
assert _pandas_registry.find(dtype) == expected