diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index d54d980d02ffa..6dbed5f138d5d 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -5,26 +5,16 @@ from pandas.errors import AbstractMethodError -class ExtensionDtype(object): - """A custom data type, to be paired with an ExtensionArray. - - Notes - ----- - The interface includes the following abstract methods that must - be implemented by subclasses: - - * type - * name - * construct_from_string - - This class does not inherit from 'abc.ABCMeta' for performance reasons. - Methods and properties required by the interface raise - ``pandas.errors.AbstractMethodError`` and no ``register`` method is - provided for registering virtual subclasses. - """ - - def __str__(self): - return self.name +class _DtypeOpsMixin(object): + # Not all of pandas' extension dtypes are compatible with + # the new ExtensionArray interface. This means PandasExtensionDtype + # can't subclass ExtensionDtype yet, as is_extension_array_dtype would + # incorrectly say that these types are extension types. + # + # In the interim, we put methods that are shared between the two base + # classes ExtensionDtype and PandasExtensionDtype here. Both those base + # classes will inherit from this Mixin. Once everything is compatible, this + # class's methods can be moved to ExtensionDtype and removed. def __eq__(self, other): """Check whether 'other' is equal to self. @@ -52,6 +42,74 @@ def __eq__(self, other): def __ne__(self, other): return not self.__eq__(other) + @property + def names(self): + # type: () -> Optional[List[str]] + """Ordered list of field names, or None if there are no fields. + + This is for compatibility with NumPy arrays, and may be removed in the + future. + """ + return None + + @classmethod + def is_dtype(cls, dtype): + """Check if we match 'dtype'. + + Parameters + ---------- + dtype : object + The object to check. 
+ + Returns + ------- + is_dtype : bool + + Notes + ----- + The default implementation is True if + + 1. ``cls.construct_from_string(dtype)`` is an instance + of ``cls``. + 2. ``dtype`` is an object and is an instance of ``cls`` + 3. ``dtype`` has a ``dtype`` attribute, and any of the above + conditions is true for ``dtype.dtype``. + """ + dtype = getattr(dtype, 'dtype', dtype) + + if isinstance(dtype, np.dtype): + return False + elif dtype is None: + return False + elif isinstance(dtype, cls): + return True + try: + return cls.construct_from_string(dtype) is not None + except TypeError: + return False + + +class ExtensionDtype(_DtypeOpsMixin): + """A custom data type, to be paired with an ExtensionArray. + + Notes + ----- + The interface includes the following abstract methods that must + be implemented by subclasses: + + * type + * name + * construct_from_string + + This class does not inherit from 'abc.ABCMeta' for performance reasons. + Methods and properties required by the interface raise + ``pandas.errors.AbstractMethodError`` and no ``register`` method is + provided for registering virtual subclasses. + """ + + def __str__(self): + return self.name + @property def type(self): # type: () -> type @@ -87,16 +145,6 @@ def name(self): """ raise AbstractMethodError(self) - @property - def names(self): - # type: () -> Optional[List[str]] - """Ordered list of field names, or None if there are no fields. - - This is for compatibility with NumPy arrays, and may be removed in the - future. - """ - return None - @classmethod def construct_from_string(cls, string): """Attempt to construct this type from a string. @@ -128,39 +176,3 @@ def construct_from_string(cls, string): ... "'{}'".format(cls, string)) """ raise AbstractMethodError(cls) - - @classmethod - def is_dtype(cls, dtype): - """Check if we match 'dtype'. - - Parameters - ---------- - dtype : object - The object to check. 
- - Returns - ------- - is_dtype : bool - - Notes - ----- - The default implementation is True if - - 1. ``cls.construct_from_string(dtype)`` is an instance - of ``cls``. - 2. ``dtype`` is an object and is an instance of ``cls`` - 3. ``dtype`` has a ``dtype`` attribute, and any of the above - conditions is true for ``dtype.dtype``. - """ - dtype = getattr(dtype, 'dtype', dtype) - - if isinstance(dtype, np.dtype): - return False - elif dtype is None: - return False - elif isinstance(dtype, cls): - return True - try: - return cls.construct_from_string(dtype) is not None - except TypeError: - return False diff --git a/pandas/core/dtypes/cast.py b/pandas/core/dtypes/cast.py index b1d0dc2a2442e..74aaa2c4f00aa 100644 --- a/pandas/core/dtypes/cast.py +++ b/pandas/core/dtypes/cast.py @@ -26,7 +26,8 @@ _ensure_int32, _ensure_int64, _NS_DTYPE, _TD_DTYPE, _INT64_DTYPE, _POSSIBLY_CAST_DTYPES) -from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype +from .dtypes import (ExtensionDtype, PandasExtensionDtype, DatetimeTZDtype, + PeriodDtype) from .generic import (ABCDatetimeIndex, ABCPeriodIndex, ABCSeries) from .missing import isna, notna @@ -1114,7 +1115,8 @@ def find_common_type(types): if all(is_dtype_equal(first, t) for t in types[1:]): return first - if any(isinstance(t, ExtensionDtype) for t in types): + if any(isinstance(t, (PandasExtensionDtype, ExtensionDtype)) + for t in types): return np.object # take lowest unit diff --git a/pandas/core/dtypes/common.py b/pandas/core/dtypes/common.py index 197b35de88896..3a90feb7ccd7d 100644 --- a/pandas/core/dtypes/common.py +++ b/pandas/core/dtypes/common.py @@ -9,7 +9,7 @@ DatetimeTZDtype, DatetimeTZDtypeType, PeriodDtype, PeriodDtypeType, IntervalDtype, IntervalDtypeType, - ExtensionDtype) + ExtensionDtype, PandasExtensionDtype) from .generic import (ABCCategorical, ABCPeriodIndex, ABCDatetimeIndex, ABCSeries, ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex, @@ -2006,7 +2006,7 @@ def pandas_dtype(dtype): return 
CategoricalDtype.construct_from_string(dtype) except TypeError: pass - elif isinstance(dtype, ExtensionDtype): + elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)): return dtype try: diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index d262a71933915..708f54f5ca75b 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -5,10 +5,10 @@ from pandas import compat from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex -from .base import ExtensionDtype +from .base import ExtensionDtype, _DtypeOpsMixin -class PandasExtensionDtype(ExtensionDtype): +class PandasExtensionDtype(_DtypeOpsMixin): """ A np.dtype duck-typed class, suitable for holding a custom dtype. @@ -83,7 +83,7 @@ class CategoricalDtypeType(type): pass -class CategoricalDtype(PandasExtensionDtype): +class CategoricalDtype(PandasExtensionDtype, ExtensionDtype): """ Type for categorical data with the categories and orderedness diff --git a/pandas/core/internals.py b/pandas/core/internals.py index 240c9b1f3377c..47db1b0d5383a 100644 --- a/pandas/core/internals.py +++ b/pandas/core/internals.py @@ -17,6 +17,7 @@ from pandas.core.dtypes.dtypes import ( ExtensionDtype, DatetimeTZDtype, + PandasExtensionDtype, CategoricalDtype) from pandas.core.dtypes.common import ( _TD_DTYPE, _NS_DTYPE, @@ -598,7 +599,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None, list(errors_legal_values), errors)) raise ValueError(invalid_arg) - if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype): + if (inspect.isclass(dtype) and + issubclass(dtype, (PandasExtensionDtype, ExtensionDtype))): msg = ("Expected an instance of {}, but got the class instead. 
" "Try instantiating 'dtype'.".format(dtype.__name__)) raise TypeError(msg) @@ -5005,7 +5007,7 @@ def _interleaved_dtype(blocks): dtype = find_common_type([b.dtype for b in blocks]) # only numpy compat - if isinstance(dtype, ExtensionDtype): + if isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)): dtype = np.object return dtype diff --git a/pandas/tests/extension/test_common.py b/pandas/tests/extension/test_common.py index 1f4582f687415..589134632c7e9 100644 --- a/pandas/tests/extension/test_common.py +++ b/pandas/tests/extension/test_common.py @@ -5,10 +5,10 @@ import pandas.util.testing as tm from pandas.core.arrays import ExtensionArray from pandas.core.dtypes.common import is_extension_array_dtype -from pandas.core.dtypes.dtypes import ExtensionDtype +from pandas.core.dtypes import dtypes -class DummyDtype(ExtensionDtype): +class DummyDtype(dtypes.ExtensionDtype): pass @@ -65,3 +65,21 @@ def test_astype_no_copy(): result = arr.astype(arr.dtype) assert arr.data is not result + + +@pytest.mark.parametrize('dtype', [ + dtypes.DatetimeTZDtype('ns', 'US/Central'), + dtypes.PeriodDtype("D"), + dtypes.IntervalDtype(), +]) +def test_is_not_extension_array_dtype(dtype): + assert not isinstance(dtype, dtypes.ExtensionDtype) + assert not is_extension_array_dtype(dtype) + + +@pytest.mark.parametrize('dtype', [ + dtypes.CategoricalDtype(), +]) +def test_is_extension_array_dtype(dtype): + assert isinstance(dtype, dtypes.ExtensionDtype) + assert is_extension_array_dtype(dtype)