Skip to content

REF: Changed ExtensionDtype inheritance #20363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 78 additions & 66 deletions pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,16 @@
from pandas.errors import AbstractMethodError


class ExtensionDtype(object):
"""A custom data type, to be paired with an ExtensionArray.

Notes
-----
The interface includes the following abstract methods that must
be implemented by subclasses:

* type
* name
* construct_from_string

This class does not inherit from 'abc.ABCMeta' for performance reasons.
Methods and properties required by the interface raise
``pandas.errors.AbstractMethodError`` and no ``register`` method is
provided for registering virtual subclasses.
"""

def __str__(self):
return self.name
class _DtypeOpsMixin(object):
# Not all of pandas' extension dtypes are compatible with
# the new ExtensionArray interface. This means PandasExtensionDtype
# can't subclass ExtensionDtype yet, as is_extension_array_dtype would
# incorrectly say that these types are extension types.
#
# In the interim, we put methods that are shared between the two base
# classes ExtensionDtype and PandasExtensionDtype here. Both those base
# classes will inherit from this Mixin. Once everything is compatible, this
# class's methods can be moved to ExtensionDtype and removed.

def __eq__(self, other):
"""Check whether 'other' is equal to self.
Expand Down Expand Up @@ -52,6 +42,74 @@ def __eq__(self, other):
def __ne__(self, other):
return not self.__eq__(other)

@property
def names(self):
# type: () -> Optional[List[str]]
"""Ordered list of field names, or None if there are no fields.

This is for compatibility with NumPy arrays, and may be removed in the
future.
"""
return None

@classmethod
def is_dtype(cls, dtype):
"""Check if we match 'dtype'.

Parameters
----------
dtype : object
The object to check.

Returns
-------
is_dtype : bool

Notes
-----
The default implementation is True if

1. ``cls.construct_from_string(dtype)`` is an instance
of ``cls``.
2. ``dtype`` is an object and is an instance of ``cls``
3. ``dtype`` has a ``dtype`` attribute, and any of the above
conditions is true for ``dtype.dtype``.
"""
dtype = getattr(dtype, 'dtype', dtype)

if isinstance(dtype, np.dtype):
return False
elif dtype is None:
return False
elif isinstance(dtype, cls):
return True
try:
return cls.construct_from_string(dtype) is not None
except TypeError:
return False


class ExtensionDtype(_DtypeOpsMixin):
"""A custom data type, to be paired with an ExtensionArray.

Notes
-----
The interface includes the following abstract methods that must
be implemented by subclasses:

* type
* name
* construct_from_string

This class does not inherit from 'abc.ABCMeta' for performance reasons.
Methods and properties required by the interface raise
``pandas.errors.AbstractMethodError`` and no ``register`` method is
provided for registering virtual subclasses.
"""

def __str__(self):
return self.name

@property
def type(self):
# type: () -> type
Expand Down Expand Up @@ -87,16 +145,6 @@ def name(self):
"""
raise AbstractMethodError(self)

@property
def names(self):
# type: () -> Optional[List[str]]
"""Ordered list of field names, or None if there are no fields.

This is for compatibility with NumPy arrays, and may be removed in the
future.
"""
return None

@classmethod
def construct_from_string(cls, string):
"""Attempt to construct this type from a string.
Expand Down Expand Up @@ -128,39 +176,3 @@ def construct_from_string(cls, string):
... "'{}'".format(cls, string))
"""
raise AbstractMethodError(cls)

@classmethod
def is_dtype(cls, dtype):
"""Check if we match 'dtype'.

Parameters
----------
dtype : object
The object to check.

Returns
-------
is_dtype : bool

Notes
-----
The default implementation is True if

1. ``cls.construct_from_string(dtype)`` is an instance
of ``cls``.
2. ``dtype`` is an object and is an instance of ``cls``
3. ``dtype`` has a ``dtype`` attribute, and any of the above
conditions is true for ``dtype.dtype``.
"""
dtype = getattr(dtype, 'dtype', dtype)

if isinstance(dtype, np.dtype):
return False
elif dtype is None:
return False
elif isinstance(dtype, cls):
return True
try:
return cls.construct_from_string(dtype) is not None
except TypeError:
return False
6 changes: 4 additions & 2 deletions pandas/core/dtypes/cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
_ensure_int32, _ensure_int64,
_NS_DTYPE, _TD_DTYPE, _INT64_DTYPE,
_POSSIBLY_CAST_DTYPES)
from .dtypes import ExtensionDtype, DatetimeTZDtype, PeriodDtype
from .dtypes import (ExtensionDtype, PandasExtensionDtype, DatetimeTZDtype,
PeriodDtype)
from .generic import (ABCDatetimeIndex, ABCPeriodIndex,
ABCSeries)
from .missing import isna, notna
Expand Down Expand Up @@ -1114,7 +1115,8 @@ def find_common_type(types):
if all(is_dtype_equal(first, t) for t in types[1:]):
return first

if any(isinstance(t, ExtensionDtype) for t in types):
if any(isinstance(t, (PandasExtensionDtype, ExtensionDtype))
for t in types):
return np.object

# take lowest unit
Expand Down
4 changes: 2 additions & 2 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
DatetimeTZDtype, DatetimeTZDtypeType,
PeriodDtype, PeriodDtypeType,
IntervalDtype, IntervalDtypeType,
ExtensionDtype)
ExtensionDtype, PandasExtensionDtype)
from .generic import (ABCCategorical, ABCPeriodIndex,
ABCDatetimeIndex, ABCSeries,
ABCSparseArray, ABCSparseSeries, ABCCategoricalIndex,
Expand Down Expand Up @@ -2006,7 +2006,7 @@ def pandas_dtype(dtype):
return CategoricalDtype.construct_from_string(dtype)
except TypeError:
pass
elif isinstance(dtype, ExtensionDtype):
elif isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
return dtype

try:
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/dtypes/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
from pandas import compat
from pandas.core.dtypes.generic import ABCIndexClass, ABCCategoricalIndex

from .base import ExtensionDtype
from .base import ExtensionDtype, _DtypeOpsMixin


class PandasExtensionDtype(ExtensionDtype):
class PandasExtensionDtype(_DtypeOpsMixin):
"""
A np.dtype duck-typed class, suitable for holding a custom dtype.

Expand Down Expand Up @@ -83,7 +83,7 @@ class CategoricalDtypeType(type):
pass


class CategoricalDtype(PandasExtensionDtype):
class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
"""
Type for categorical data with the categories and orderedness

Expand Down
6 changes: 4 additions & 2 deletions pandas/core/internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from pandas.core.dtypes.dtypes import (
ExtensionDtype, DatetimeTZDtype,
PandasExtensionDtype,
CategoricalDtype)
from pandas.core.dtypes.common import (
_TD_DTYPE, _NS_DTYPE,
Expand Down Expand Up @@ -598,7 +599,8 @@ def _astype(self, dtype, copy=False, errors='raise', values=None,
list(errors_legal_values), errors))
raise ValueError(invalid_arg)

if inspect.isclass(dtype) and issubclass(dtype, ExtensionDtype):
if (inspect.isclass(dtype) and
issubclass(dtype, (PandasExtensionDtype, ExtensionDtype))):
msg = ("Expected an instance of {}, but got the class instead. "
"Try instantiating 'dtype'.".format(dtype.__name__))
raise TypeError(msg)
Expand Down Expand Up @@ -5005,7 +5007,7 @@ def _interleaved_dtype(blocks):
dtype = find_common_type([b.dtype for b in blocks])

# only numpy compat
if isinstance(dtype, ExtensionDtype):
if isinstance(dtype, (PandasExtensionDtype, ExtensionDtype)):
dtype = np.object

return dtype
Expand Down
22 changes: 20 additions & 2 deletions pandas/tests/extension/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
import pandas.util.testing as tm
from pandas.core.arrays import ExtensionArray
from pandas.core.dtypes.common import is_extension_array_dtype
from pandas.core.dtypes.dtypes import ExtensionDtype
from pandas.core.dtypes import dtypes


class DummyDtype(ExtensionDtype):
class DummyDtype(dtypes.ExtensionDtype):
pass


Expand Down Expand Up @@ -65,3 +65,21 @@ def test_astype_no_copy():

result = arr.astype(arr.dtype)
assert arr.data is not result


@pytest.mark.parametrize('dtype', [
dtypes.DatetimeTZDtype('ns', 'US/Central'),
dtypes.PeriodDtype("D"),
dtypes.IntervalDtype(),
])
def test_is_not_extension_array_dtype(dtype):
assert not isinstance(dtype, dtypes.ExtensionDtype)
assert not is_extension_array_dtype(dtype)


@pytest.mark.parametrize('dtype', [
dtypes.CategoricalDtype(),
])
def test_is_extension_array_dtype(dtype):
assert isinstance(dtype, dtypes.ExtensionDtype)
assert is_extension_array_dtype(dtype)