diff --git a/doc/source/whatsnew/v0.24.0.txt b/doc/source/whatsnew/v0.24.0.txt index c9874b4dd03d6..a1467cbca963a 100644 --- a/doc/source/whatsnew/v0.24.0.txt +++ b/doc/source/whatsnew/v0.24.0.txt @@ -492,6 +492,15 @@ Previous Behavior: ExtensionType Changes ^^^^^^^^^^^^^^^^^^^^^ +**:class:`pandas.api.extensions.ExtensionDtype` Equality and Hashability** + +Pandas now requires that extension dtypes be hashable. The base class implements +a default ``__eq__`` and ``__hash__``. If you have a parametrized dtype, you should +update the ``ExtensionDtype._metadata`` tuple to match the signature of your +``__init__`` method. See :class:`pandas.api.extensions.ExtensionDtype` for more (:issue:`22476`). + +**Other changes** + - ``ExtensionArray`` has gained the abstract methods ``.dropna()`` (:issue:`21185`) - ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`) diff --git a/pandas/core/dtypes/base.py b/pandas/core/dtypes/base.py index b0fa55e346613..ac4d6d1590f38 100644 --- a/pandas/core/dtypes/base.py +++ b/pandas/core/dtypes/base.py @@ -22,14 +22,17 @@ class _DtypeOpsMixin(object): # of the NA value, not the physical NA vaalue for storage. # e.g. for JSONArray, this is an empty dictionary. na_value = np.nan + _metadata = () def __eq__(self, other): """Check whether 'other' is equal to self. - By default, 'other' is considered equal if + By default, 'other' is considered equal if either * it's a string matching 'self.name'. - * it's an instance of this type. + * it's an instance of this type and all of the + the attributes in ``self._metadata`` are equal between + `self` and `other`. Parameters ---------- @@ -40,11 +43,19 @@ def __eq__(self, other): bool """ if isinstance(other, compat.string_types): - return other == self.name - elif isinstance(other, type(self)): - return True - else: - return False + try: + other = self.construct_from_string(other) + except TypeError: + return False + if isinstance(other, type(self)): + return all( + getattr(self, attr) == getattr(other, attr) + for attr in self._metadata + ) + return False + + def __hash__(self): + return hash(tuple(getattr(self, attr) for attr in self._metadata)) def __ne__(self, other): return not self.__eq__(other) @@ -161,6 +172,26 @@ class ExtensionDtype(_DtypeOpsMixin): The `na_value` class attribute can be used to set the default NA value for this type. :attr:`numpy.nan` is used by default. + ExtensionDtypes are required to be hashable. The base class provides + a default implementation, which relies on the ``_metadata`` class + attribute. ``_metadata`` should be a tuple containing the strings + that define your data type. For example, with ``PeriodDtype`` that's + the ``freq`` attribute. + + **If you have a parametrized dtype you should set the ``_metadata`` + class property**. + + Ideally, the attributes in ``_metadata`` will match the + parameters to your ``ExtensionDtype.__init__`` (if any). If any of + the attributes in ``_metadata`` don't implement the standard + ``__eq__`` or ``__hash__``, the default implementations here will not + work. + + .. versionchanged:: 0.24.0 + + Added ``_metadata``, ``__hash__``, and changed the default definition + of ``__eq__``. + This class does not inherit from 'abc.ABCMeta' for performance reasons. Methods and properties required by the interface raise ``pandas.errors.AbstractMethodError`` and no ``register`` method is diff --git a/pandas/core/dtypes/dtypes.py b/pandas/core/dtypes/dtypes.py index beda9bc02f4d5..611cae28877c3 100644 --- a/pandas/core/dtypes/dtypes.py +++ b/pandas/core/dtypes/dtypes.py @@ -101,7 +101,6 @@ class PandasExtensionDtype(_DtypeOpsMixin): base = None isbuiltin = 0 isnative = 0 - _metadata = [] _cache = {} def __unicode__(self): @@ -209,7 +208,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype): kind = 'O' str = '|O08' base = np.dtype('O') - _metadata = ['categories', 'ordered'] + _metadata = ('categories', 'ordered') _cache = {} def __init__(self, categories=None, ordered=None): @@ -485,7 +484,7 @@ class DatetimeTZDtype(PandasExtensionDtype): str = '|M8[ns]' num = 101 base = np.dtype('M8[ns]') - _metadata = ['unit', 'tz'] + _metadata = ('unit', 'tz') _match = re.compile(r"(datetime64|M8)\[(?P.+), (?P.+)\]") _cache = {} @@ -589,7 +588,7 @@ class PeriodDtype(PandasExtensionDtype): str = '|O08' base = np.dtype('O') num = 102 - _metadata = ['freq'] + _metadata = ('freq',) _match = re.compile(r"(P|p)eriod\[(?P.+)\]") _cache = {} @@ -709,7 +708,7 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype): str = '|O08' base = np.dtype('O') num = 103 - _metadata = ['subtype'] + _metadata = ('subtype',) _match = re.compile(r"(I|i)nterval\[(?P.+)\]") _cache = {} diff --git a/pandas/tests/extension/base/dtype.py b/pandas/tests/extension/base/dtype.py index 8d1f1cadcc23f..d5cf9571e3622 100644 --- a/pandas/tests/extension/base/dtype.py +++ b/pandas/tests/extension/base/dtype.py @@ -49,6 +49,10 @@ def test_eq_with_str(self, dtype): def test_eq_with_numpy_object(self, dtype): assert dtype != np.dtype('object') + def test_eq_with_self(self, dtype): + assert dtype == dtype + assert dtype != object() + def test_array_type(self, data, dtype): assert dtype.construct_array_type() is type(data) @@ -81,3 +85,6 @@ def test_check_dtype(self, data): index=list('ABCD')) result = df.dtypes.apply(str) == str(dtype) self.assert_series_equal(result, expected) + + def test_hashable(self, dtype): + hash(dtype) # no error diff --git a/pandas/tests/extension/decimal/array.py b/pandas/tests/extension/decimal/array.py index 79e1a692f744a..a1ee3a4fefef2 100644 --- a/pandas/tests/extension/decimal/array.py +++ b/pandas/tests/extension/decimal/array.py @@ -15,15 +15,11 @@ class DecimalDtype(ExtensionDtype): type = decimal.Decimal name = 'decimal' na_value = decimal.Decimal('NaN') + _metadata = ('context',) def __init__(self, context=None): self.context = context or decimal.getcontext() - def __eq__(self, other): - if isinstance(other, type(self)): - return self.context == other.context - return super(DecimalDtype, self).__eq__(other) - def __repr__(self): return 'DecimalDtype(context={})'.format(self.context) diff --git a/pandas/tests/extension/decimal/test_decimal.py b/pandas/tests/extension/decimal/test_decimal.py index dd625d6e1eb3c..d65b4dde832e4 100644 --- a/pandas/tests/extension/decimal/test_decimal.py +++ b/pandas/tests/extension/decimal/test_decimal.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd +from pandas import compat import pandas.util.testing as tm import pytest @@ -93,7 +94,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs): class TestDtype(BaseDecimal, base.BaseDtypeTests): - pass + @pytest.mark.skipif(compat.PY2, reason="Context not hashable.") + def test_hashable(self, dtype): + pass class TestInterface(BaseDecimal, base.BaseInterfaceTests): diff --git a/pandas/tests/extension/json/array.py b/pandas/tests/extension/json/array.py index 87876d84bef99..976511941042d 100644 --- a/pandas/tests/extension/json/array.py +++ b/pandas/tests/extension/json/array.py @@ -27,6 +27,7 @@ class JSONDtype(ExtensionDtype): type = compat.Mapping name = 'json' + try: na_value = collections.UserDict() except AttributeError: