Skip to content

Commit b8bc510

Browse files
BUG: Hash and compare tuple subclasses as builtin tuples (#59286)
* cast all tuple subclass index keys to tuple * fix docs typo * add multi-index namedtuple test * hash and compare all tuple subclasses as tuples * test hashtable with namedtuples * remove redundant index key conversion * add comments * update whatsnew * check key error message * fix whatsnew section * test namedtuple and tuple interchangeable in hashtable * Update doc/source/whatsnew/v3.0.0.rst Co-authored-by: Matthew Roeschke <[email protected]> * use pytest.raises regexp instead of str eq --------- Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 7c0ee27 commit b8bc510

File tree

4 files changed

+70
-8
lines changed

4 files changed

+70
-8
lines changed

doc/source/whatsnew/v3.0.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Other enhancements
3333
- :func:`DataFrame.to_excel` now raises a ``UserWarning`` when the character count in a cell exceeds Excel's limitation of 32767 characters (:issue:`56954`)
3434
- :func:`read_stata` now returns ``datetime64`` resolutions better matching those natively stored in the stata format (:issue:`55642`)
3535
- :meth:`DataFrame.agg` called with ``axis=1`` and a ``func`` which relabels the result index now raises a ``NotImplementedError`` (:issue:`58807`).
36+
- :meth:`Index.get_loc` now also accepts subclasses of ``tuple`` as keys (:issue:`57922`)
3637
- :meth:`Styler.set_tooltips` provides an alternative method of storing tooltips by using the title attribute of td elements. (:issue:`56981`)
3738
- Allow dictionaries to be passed to :meth:`pandas.Series.str.replace` via ``pat`` parameter (:issue:`51748`)
3839
- Support passing a :class:`Series` input to :func:`json_normalize` that retains the :class:`Series` :class:`Index` (:issue:`51452`)
@@ -231,6 +232,7 @@ Other API changes
231232
^^^^^^^^^^^^^^^^^
232233
- 3rd party ``py.path`` objects are no longer explicitly supported in IO methods. Use :py:class:`pathlib.Path` objects instead (:issue:`57091`)
233234
- :func:`read_table`'s ``parse_dates`` argument defaults to ``None`` to improve consistency with :func:`read_csv` (:issue:`57476`)
235+
- All classes inheriting from builtin ``tuple`` (including types created with :func:`collections.namedtuple`) are now hashed and compared as builtin ``tuple`` during indexing operations (:issue:`57922`)
234236
- Made ``dtype`` a required argument in :meth:`ExtensionArray._from_sequence_of_strings` (:issue:`56519`)
235237
- Passing a :class:`Series` input to :func:`json_normalize` will now retain the :class:`Series` :class:`Index`; previously, the output had a new :class:`RangeIndex` (:issue:`51452`)
236238
- Updated :meth:`DataFrame.to_excel` so that the output spreadsheet has no styling. Custom styling can still be done using :meth:`Styler.to_excel` (:issue:`54154`)

pandas/_libs/include/pandas/vendored/klib/khash_python.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,8 @@ static inline int pyobject_cmp(PyObject *a, PyObject *b) {
207207
if (PyComplex_CheckExact(a)) {
208208
return complexobject_cmp((PyComplexObject *)a, (PyComplexObject *)b);
209209
}
210-
if (PyTuple_CheckExact(a)) {
210+
if (PyTuple_Check(a)) {
211+
// compare tuple subclasses as builtin tuples
211212
return tupleobject_cmp((PyTupleObject *)a, (PyTupleObject *)b);
212213
}
213214
// frozenset isn't yet supported
@@ -311,7 +312,8 @@ static inline khuint32_t kh_python_hash_func(PyObject *key) {
311312
// because complex(k,0) == k holds for any int-object k
312313
// and kh_complex128_hash_func doesn't respect it
313314
hash = complexobject_hash((PyComplexObject *)key);
314-
} else if (PyTuple_CheckExact(key)) {
315+
} else if (PyTuple_Check(key)) {
316+
// hash tuple subclasses as builtin tuples
315317
hash = tupleobject_hash((PyTupleObject *)key);
316318
} else {
317319
hash = PyObject_Hash(key);

pandas/tests/indexes/multi/test_indexing.py

+24
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import namedtuple
12
from datetime import timedelta
23
import re
34

@@ -1006,3 +1007,26 @@ def test_get_indexer_for_multiindex_with_nans(nulls_fixture):
10061007
result = idx1.get_indexer(idx2)
10071008
expected = np.array([-1, 1], dtype=np.intp)
10081009
tm.assert_numpy_array_equal(result, expected)
1010+
1011+
1012+
def test_get_loc_namedtuple_behaves_like_tuple():
1013+
# GH57922
1014+
NamedIndex = namedtuple("NamedIndex", ("a", "b"))
1015+
multi_idx = MultiIndex.from_tuples(
1016+
[NamedIndex("i1", "i2"), NamedIndex("i3", "i4"), NamedIndex("i5", "i6")]
1017+
)
1018+
for idx in (multi_idx, multi_idx.to_flat_index()):
1019+
assert idx.get_loc(NamedIndex("i1", "i2")) == 0
1020+
assert idx.get_loc(NamedIndex("i3", "i4")) == 1
1021+
assert idx.get_loc(NamedIndex("i5", "i6")) == 2
1022+
assert idx.get_loc(("i1", "i2")) == 0
1023+
assert idx.get_loc(("i3", "i4")) == 1
1024+
assert idx.get_loc(("i5", "i6")) == 2
1025+
multi_idx = MultiIndex.from_tuples([("i1", "i2"), ("i3", "i4"), ("i5", "i6")])
1026+
for idx in (multi_idx, multi_idx.to_flat_index()):
1027+
assert idx.get_loc(NamedIndex("i1", "i2")) == 0
1028+
assert idx.get_loc(NamedIndex("i3", "i4")) == 1
1029+
assert idx.get_loc(NamedIndex("i5", "i6")) == 2
1030+
assert idx.get_loc(("i1", "i2")) == 0
1031+
assert idx.get_loc(("i3", "i4")) == 1
1032+
assert idx.get_loc(("i5", "i6")) == 2

pandas/tests/libs/test_hashtable.py

+40-6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import namedtuple
12
from collections.abc import Generator
23
from contextlib import contextmanager
34
import re
@@ -405,9 +406,8 @@ def test_nan_complex_real(self):
405406
table = ht.PyObjectHashTable()
406407
table.set_item(nan1, 42)
407408
assert table.get_item(nan2) == 42
408-
with pytest.raises(KeyError, match=None) as error:
409+
with pytest.raises(KeyError, match=re.escape(repr(other))):
409410
table.get_item(other)
410-
assert str(error.value) == str(other)
411411

412412
def test_nan_complex_imag(self):
413413
nan1 = complex(1, float("nan"))
@@ -417,9 +417,8 @@ def test_nan_complex_imag(self):
417417
table = ht.PyObjectHashTable()
418418
table.set_item(nan1, 42)
419419
assert table.get_item(nan2) == 42
420-
with pytest.raises(KeyError, match=None) as error:
420+
with pytest.raises(KeyError, match=re.escape(repr(other))):
421421
table.get_item(other)
422-
assert str(error.value) == str(other)
423422

424423
def test_nan_in_tuple(self):
425424
nan1 = (float("nan"),)
@@ -436,9 +435,28 @@ def test_nan_in_nested_tuple(self):
436435
table = ht.PyObjectHashTable()
437436
table.set_item(nan1, 42)
438437
assert table.get_item(nan2) == 42
439-
with pytest.raises(KeyError, match=None) as error:
438+
with pytest.raises(KeyError, match=re.escape(repr(other))):
439+
table.get_item(other)
440+
441+
def test_nan_in_namedtuple(self):
442+
T = namedtuple("T", ["x"])
443+
nan1 = T(float("nan"))
444+
nan2 = T(float("nan"))
445+
assert nan1.x is not nan2.x
446+
table = ht.PyObjectHashTable()
447+
table.set_item(nan1, 42)
448+
assert table.get_item(nan2) == 42
449+
450+
def test_nan_in_nested_namedtuple(self):
451+
T = namedtuple("T", ["x", "y"])
452+
nan1 = T(1, (2, (float("nan"),)))
453+
nan2 = T(1, (2, (float("nan"),)))
454+
other = T(1, 2)
455+
table = ht.PyObjectHashTable()
456+
table.set_item(nan1, 42)
457+
assert table.get_item(nan2) == 42
458+
with pytest.raises(KeyError, match=re.escape(repr(other))):
440459
table.get_item(other)
441-
assert str(error.value) == str(other)
442460

443461

444462
def test_hash_equal_tuple_with_nans():
@@ -448,6 +466,22 @@ def test_hash_equal_tuple_with_nans():
448466
assert ht.objects_are_equal(a, b)
449467

450468

469+
def test_hash_equal_namedtuple_with_nans():
470+
T = namedtuple("T", ["x", "y"])
471+
a = T(float("nan"), (float("nan"), float("nan")))
472+
b = T(float("nan"), (float("nan"), float("nan")))
473+
assert ht.object_hash(a) == ht.object_hash(b)
474+
assert ht.objects_are_equal(a, b)
475+
476+
477+
def test_hash_equal_namedtuple_and_tuple():
478+
T = namedtuple("T", ["x", "y"])
479+
a = T(1, (2, 3))
480+
b = (1, (2, 3))
481+
assert ht.object_hash(a) == ht.object_hash(b)
482+
assert ht.objects_are_equal(a, b)
483+
484+
451485
def test_get_labels_groupby_for_Int64(writable):
452486
table = ht.Int64HashTable()
453487
vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64)

0 commit comments

Comments
 (0)