Skip to content

Commit e6cc042

Browse files
TomAugspurger authored and
tm9k1 committed
API: ExtensionDtype Equality and Hashability (pandas-dev#22996)
* API: ExtensionDtype Equality and Hashability Closes pandas-dev#22476
1 parent 60cac71 commit e6cc042

File tree

7 files changed

+64
-18
lines changed

7 files changed

+64
-18
lines changed

doc/source/whatsnew/v0.24.0.txt

+9
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,15 @@ Previous Behavior:
527527
ExtensionType Changes
528528
^^^^^^^^^^^^^^^^^^^^^
529529

530+
**:class:`pandas.api.extensions.ExtensionDtype` Equality and Hashability**
531+
532+
Pandas now requires that extension dtypes be hashable. The base class implements
533+
a default ``__eq__`` and ``__hash__``. If you have a parametrized dtype, you should
534+
update the ``ExtensionDtype._metadata`` tuple to match the signature of your
535+
``__init__`` method. See :class:`pandas.api.extensions.ExtensionDtype` for more (:issue:`22476`).
536+
537+
**Other changes**
538+
530539
- ``ExtensionArray`` has gained the abstract method ``.dropna()`` (:issue:`21185`)
531540
- ``ExtensionDtype`` has gained the ability to instantiate from string dtypes, e.g. ``decimal`` would instantiate a registered ``DecimalDtype``; furthermore
532541
the ``ExtensionDtype`` has gained the method ``construct_array_type`` (:issue:`21185`)

pandas/core/dtypes/base.py

+38-7
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,17 @@ class _DtypeOpsMixin(object):
2222
# of the NA value, not the physical NA value for storage.
2323
# e.g. for JSONArray, this is an empty dictionary.
2424
na_value = np.nan
25+
_metadata = ()
2526

2627
def __eq__(self, other):
2728
"""Check whether 'other' is equal to self.
2829
29-
By default, 'other' is considered equal if
30+
By default, 'other' is considered equal if either
3031
3132
* it's a string matching 'self.name'.
32-
* it's an instance of this type.
33+
* it's an instance of this type and all of
34+
the attributes in ``self._metadata`` are equal between
35+
`self` and `other`.
3336
3437
Parameters
3538
----------
@@ -40,11 +43,19 @@ def __eq__(self, other):
4043
bool
4144
"""
4245
if isinstance(other, compat.string_types):
43-
return other == self.name
44-
elif isinstance(other, type(self)):
45-
return True
46-
else:
47-
return False
46+
try:
47+
other = self.construct_from_string(other)
48+
except TypeError:
49+
return False
50+
if isinstance(other, type(self)):
51+
return all(
52+
getattr(self, attr) == getattr(other, attr)
53+
for attr in self._metadata
54+
)
55+
return False
56+
57+
def __hash__(self):
58+
return hash(tuple(getattr(self, attr) for attr in self._metadata))
4859

4960
def __ne__(self, other):
5061
return not self.__eq__(other)
@@ -161,6 +172,26 @@ class ExtensionDtype(_DtypeOpsMixin):
161172
The `na_value` class attribute can be used to set the default NA value
162173
for this type. :attr:`numpy.nan` is used by default.
163174
175+
ExtensionDtypes are required to be hashable. The base class provides
176+
a default implementation, which relies on the ``_metadata`` class
177+
attribute. ``_metadata`` should be a tuple containing the names of the
178+
attributes that define your data type. For example, with ``PeriodDtype`` that's
179+
the ``freq`` attribute.
180+
181+
**If you have a parametrized dtype you should set the ``_metadata``
182+
class attribute**.
183+
184+
Ideally, the attributes in ``_metadata`` will match the
185+
parameters to your ``ExtensionDtype.__init__`` (if any). If any of
186+
the attributes in ``_metadata`` don't implement the standard
187+
``__eq__`` or ``__hash__``, the default implementations here will not
188+
work.
189+
190+
.. versionchanged:: 0.24.0
191+
192+
Added ``_metadata``, ``__hash__``, and changed the default definition
193+
of ``__eq__``.
194+
164195
This class does not inherit from 'abc.ABCMeta' for performance reasons.
165196
Methods and properties required by the interface raise
166197
``pandas.errors.AbstractMethodError`` and no ``register`` method is

pandas/core/dtypes/dtypes.py

+4-5
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ class PandasExtensionDtype(_DtypeOpsMixin):
101101
base = None
102102
isbuiltin = 0
103103
isnative = 0
104-
_metadata = []
105104
_cache = {}
106105

107106
def __unicode__(self):
@@ -209,7 +208,7 @@ class CategoricalDtype(PandasExtensionDtype, ExtensionDtype):
209208
kind = 'O'
210209
str = '|O08'
211210
base = np.dtype('O')
212-
_metadata = ['categories', 'ordered']
211+
_metadata = ('categories', 'ordered')
213212
_cache = {}
214213

215214
def __init__(self, categories=None, ordered=None):
@@ -485,7 +484,7 @@ class DatetimeTZDtype(PandasExtensionDtype):
485484
str = '|M8[ns]'
486485
num = 101
487486
base = np.dtype('M8[ns]')
488-
_metadata = ['unit', 'tz']
487+
_metadata = ('unit', 'tz')
489488
_match = re.compile(r"(datetime64|M8)\[(?P<unit>.+), (?P<tz>.+)\]")
490489
_cache = {}
491490

@@ -589,7 +588,7 @@ class PeriodDtype(PandasExtensionDtype):
589588
str = '|O08'
590589
base = np.dtype('O')
591590
num = 102
592-
_metadata = ['freq']
591+
_metadata = ('freq',)
593592
_match = re.compile(r"(P|p)eriod\[(?P<freq>.+)\]")
594593
_cache = {}
595594

@@ -709,7 +708,7 @@ class IntervalDtype(PandasExtensionDtype, ExtensionDtype):
709708
str = '|O08'
710709
base = np.dtype('O')
711710
num = 103
712-
_metadata = ['subtype']
711+
_metadata = ('subtype',)
713712
_match = re.compile(r"(I|i)nterval\[(?P<subtype>.+)\]")
714713
_cache = {}
715714

pandas/tests/extension/base/dtype.py

+7
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def test_eq_with_str(self, dtype):
4949
def test_eq_with_numpy_object(self, dtype):
5050
assert dtype != np.dtype('object')
5151

52+
def test_eq_with_self(self, dtype):
53+
assert dtype == dtype
54+
assert dtype != object()
55+
5256
def test_array_type(self, data, dtype):
5357
assert dtype.construct_array_type() is type(data)
5458

@@ -81,3 +85,6 @@ def test_check_dtype(self, data):
8185
index=list('ABCD'))
8286
result = df.dtypes.apply(str) == str(dtype)
8387
self.assert_series_equal(result, expected)
88+
89+
def test_hashable(self, dtype):
90+
hash(dtype) # no error

pandas/tests/extension/decimal/array.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,11 @@ class DecimalDtype(ExtensionDtype):
1515
type = decimal.Decimal
1616
name = 'decimal'
1717
na_value = decimal.Decimal('NaN')
18+
_metadata = ('context',)
1819

1920
def __init__(self, context=None):
2021
self.context = context or decimal.getcontext()
2122

22-
def __eq__(self, other):
23-
if isinstance(other, type(self)):
24-
return self.context == other.context
25-
return super(DecimalDtype, self).__eq__(other)
26-
2723
def __repr__(self):
2824
return 'DecimalDtype(context={})'.format(self.context)
2925

pandas/tests/extension/decimal/test_decimal.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import pandas as pd
6+
from pandas import compat
67
import pandas.util.testing as tm
78
import pytest
89

@@ -93,7 +94,9 @@ def assert_frame_equal(self, left, right, *args, **kwargs):
9394

9495

9596
class TestDtype(BaseDecimal, base.BaseDtypeTests):
96-
pass
97+
@pytest.mark.skipif(compat.PY2, reason="Context not hashable.")
98+
def test_hashable(self, dtype):
99+
pass
97100

98101

99102
class TestInterface(BaseDecimal, base.BaseInterfaceTests):

pandas/tests/extension/json/array.py

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
class JSONDtype(ExtensionDtype):
2828
type = compat.Mapping
2929
name = 'json'
30+
3031
try:
3132
na_value = collections.UserDict()
3233
except AttributeError:

0 commit comments

Comments
 (0)