Skip to content

Backport PR #42135 on branch 1.3.x (BUG: use hash-function which takes nans correctly into account for ExtensionDtype) #42183

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/_libs/hashtable.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,5 @@ def ismember(
arr: np.ndarray,
values: np.ndarray,
) -> np.ndarray: ... # np.ndarray[bool]
# Hash of ``obj`` as computed by khash's ``kh_python_hash_func``.
def object_hash(obj: object) -> int: ...
# Equality of ``a`` and ``b`` as decided by khash's ``kh_python_hash_equal``.
def objects_are_equal(a: object, b: object) -> bool: ...
10 changes: 10 additions & 0 deletions pandas/_libs/hashtable.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ from pandas._libs.khash cimport (
are_equivalent_khcomplex64_t,
are_equivalent_khcomplex128_t,
kh_needed_n_buckets,
kh_python_hash_equal,
kh_python_hash_func,
kh_str_t,
khcomplex64_t,
khcomplex128_t,
Expand All @@ -46,6 +48,14 @@ def get_hashtable_trace_domain():
return KHASH_TRACE_DOMAIN


def object_hash(obj):
    """
    Return the hash of ``obj`` as computed by ``kh_python_hash_func``.

    Thin Python-visible wrapper around the khash C helper, so that Python
    code can hash an object the same way the object hashtables do.
    """
    hash_value = kh_python_hash_func(obj)
    return hash_value


def objects_are_equal(a, b):
    """
    Return whether ``a`` and ``b`` are equal according to
    ``kh_python_hash_equal``.

    Thin Python-visible wrapper around the khash C helper, so that Python
    code can compare objects the same way the object hashtables do.
    """
    are_equal = kh_python_hash_equal(a, b)
    return are_equal


cdef int64_t NPY_NAT = util.get_nat()
SIZE_HINT_LIMIT = (1 << 20) + 7

Expand Down
3 changes: 3 additions & 0 deletions pandas/_libs/khash.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ cdef extern from "khash_python.h":
bint are_equivalent_float32_t \
"kh_floats_hash_equal" (float32_t a, float32_t b) nogil

uint32_t kh_python_hash_func(object key)
bint kh_python_hash_equal(object a, object b)

ctypedef struct kh_pymap_t:
khuint_t n_buckets, size, n_occupied, upper_bound
uint32_t *flags
Expand Down
4 changes: 2 additions & 2 deletions pandas/_libs/src/klib/khash_python.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ Py_hash_t PANDAS_INLINE complexobject_hash(PyComplexObject* key) {
}


khint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key);
khuint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key);

//we could use any hashing algorithm, this is the original CPython's for tuples

Expand Down Expand Up @@ -328,7 +328,7 @@ Py_hash_t PANDAS_INLINE tupleobject_hash(PyTupleObject* key) {
}


khint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key) {
khuint32_t PANDAS_INLINE kh_python_hash_func(PyObject* key) {
Py_hash_t hash;
// For PyObject_Hash holds:
// hash(0.0) == 0 == hash(-0.0)
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/dtypes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import numpy as np

from pandas._libs.hashtable import object_hash
from pandas._typing import (
DtypeObj,
type_t,
Expand Down Expand Up @@ -128,7 +129,9 @@ def __eq__(self, other: Any) -> bool:
return False

def __hash__(self) -> int:
return hash(tuple(getattr(self, attr) for attr in self._metadata))
# for python>=3.10, different nan objects have different hashes
# we need to avoid that and thus use a hash function with the old behavior
return object_hash(tuple(getattr(self, attr) for attr in self._metadata))

def __ne__(self, other: Any) -> bool:
return not self.__eq__(other)
Expand Down
7 changes: 7 additions & 0 deletions pandas/tests/libs/test_hashtable.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ def test_nan_in_nested_tuple(self):
assert str(error.value) == str(other)


def test_hash_equal_tuple_with_nans():
    """Nested tuples built from separate NaN objects hash and compare equal."""
    # Every float("nan") call yields a new object, so the two tuples share
    # no NaN instances.
    left = (float("nan"), (float("nan"), float("nan")))
    right = (float("nan"), (float("nan"), float("nan")))

    left_hash = ht.object_hash(left)
    right_hash = ht.object_hash(right)
    assert left_hash == right_hash

    assert ht.objects_are_equal(left, right)


def test_get_labels_groupby_for_Int64(writable):
table = ht.Int64HashTable()
vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64)
Expand Down