Skip to content

Commit 3bb31da

Browse files
authored
BUG: use hash-function which takes nans correctly into account for ExtensionDtype (#42135)
1 parent 1963c0e commit 3bb31da

File tree

6 files changed

+28
-3
lines changed

6 files changed

+28
-3
lines changed

pandas/_libs/hashtable.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,5 @@ def ismember(
228228
arr: np.ndarray,
229229
values: np.ndarray,
230230
) -> np.ndarray: ... # np.ndarray[bool]
231+
def object_hash(obj) -> int: ...
232+
def objects_are_equal(a, b) -> bool: ...

pandas/_libs/hashtable.pyx

+10
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ from pandas._libs.khash cimport (
3434
are_equivalent_khcomplex64_t,
3535
are_equivalent_khcomplex128_t,
3636
kh_needed_n_buckets,
37+
kh_python_hash_equal,
38+
kh_python_hash_func,
3739
kh_str_t,
3840
khcomplex64_t,
3941
khcomplex128_t,
@@ -46,6 +48,14 @@ def get_hashtable_trace_domain():
4648
return KHASH_TRACE_DOMAIN
4749

4850

51+
def object_hash(obj):
52+
return kh_python_hash_func(obj)
53+
54+
55+
def objects_are_equal(a, b):
56+
return kh_python_hash_equal(a, b)
57+
58+
4959
cdef int64_t NPY_NAT = util.get_nat()
5060
SIZE_HINT_LIMIT = (1 << 20) + 7
5161

pandas/_libs/khash.pxd

+3
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ cdef extern from "khash_python.h":
4141
bint are_equivalent_float32_t \
4242
"kh_floats_hash_equal" (float32_t a, float32_t b) nogil
4343

44+
uint32_t kh_python_hash_func(object key)
45+
bint kh_python_hash_equal(object a, object b)
46+
4447
ctypedef struct kh_pymap_t:
4548
khuint_t n_buckets, size, n_occupied, upper_bound
4649
uint32_t *flags

pandas/_libs/src/klib/khash_python.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ Py_hash_t PANDAS_INLINE complexobject_hash(PyComplexObject* key) {
287287
}
288288

289289

290-
khint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key);
290+
khuint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key);
291291

292292
//we could use any hashing algorithm, this is the original CPython's for tuples
293293

@@ -328,7 +328,7 @@ Py_hash_t PANDAS_INLINE tupleobject_hash(PyTupleObject* key) {
328328
}
329329

330330

331-
khint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key) {
331+
khuint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key) {
332332
Py_hash_t hash;
333333
// For PyObject_Hash holds:
334334
// hash(0.0) == 0 == hash(-0.0)

pandas/core/dtypes/base.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import numpy as np
1414

15+
from pandas._libs.hashtable import object_hash
1516
from pandas._typing import (
1617
DtypeObj,
1718
type_t,
@@ -128,7 +129,9 @@ def __eq__(self, other: Any) -> bool:
128129
return False
129130

130131
def __hash__(self) -> int:
131-
return hash(tuple(getattr(self, attr) for attr in self._metadata))
132+
# for python>=3.10, different nan objects have different hashes
133+
# we need to avoid that and thus use a hash function with the old behavior
134+
return object_hash(tuple(getattr(self, attr) for attr in self._metadata))
132135

133136
def __ne__(self, other: Any) -> bool:
134137
return not self.__eq__(other)

pandas/tests/libs/test_hashtable.py

+7
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,13 @@ def test_nan_in_nested_tuple(self):
240240
assert str(error.value) == str(other)
241241

242242

243+
def test_hash_equal_tuple_with_nans():
244+
a = (float("nan"), (float("nan"), float("nan")))
245+
b = (float("nan"), (float("nan"), float("nan")))
246+
assert ht.object_hash(a) == ht.object_hash(b)
247+
assert ht.objects_are_equal(a, b)
248+
249+
243250
def test_get_labels_groupby_for_Int64(writable):
244251
table = ht.Int64HashTable()
245252
vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64)

0 commit comments

Comments
 (0)