Skip to content

Commit 796e205

Browse files
Backport PR pandas-dev#41952: BUG: take nans correctly into consideration in complex and tuple (pandas-dev#42058)
Co-authored-by: realead <[email protected]>
1 parent b756601 commit 796e205

File tree

4 files changed

+151
-19
lines changed

4 files changed

+151
-19
lines changed

doc/source/whatsnew/v1.3.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1033,6 +1033,7 @@ Missing
10331033
- Bug in :meth:`DataFrame.fillna` not accepting a dictionary for the ``downcast`` keyword (:issue:`40809`)
10341034
- Bug in :func:`isna` not returning a copy of the mask for nullable types, causing any subsequent mask modification to change the original array (:issue:`40935`)
10351035
- Bug in :class:`DataFrame` construction with float data containing ``NaN`` and an integer ``dtype`` casting instead of retaining the ``NaN`` (:issue:`26919`)
1036+
- Bug in :meth:`Series.isin` and :meth:`MultiIndex.isin` not treating all ``NaN`` values as equivalent when they are inside tuples (:issue:`41836`)
10361037

10371038
MultiIndex
10381039
^^^^^^^^^^

pandas/_libs/src/klib/khash_python.h

+78-6
Original file line numberDiff line numberDiff line change
@@ -163,18 +163,90 @@ KHASH_MAP_INIT_COMPLEX128(complex128, size_t)
163163
#define kh_exist_complex128(h, k) (kh_exist(h, k))
164164

165165

166+
// NaN-floats should be in the same equivalency class, see GH 22119
167+
int PANDAS_INLINE floatobject_cmp(PyFloatObject* a, PyFloatObject* b){
168+
return (
169+
Py_IS_NAN(PyFloat_AS_DOUBLE(a)) &&
170+
Py_IS_NAN(PyFloat_AS_DOUBLE(b))
171+
)
172+
||
173+
( PyFloat_AS_DOUBLE(a) == PyFloat_AS_DOUBLE(b) );
174+
}
175+
176+
177+
// NaNs should be in the same equivalency class, see GH 41836
178+
// PyObject_RichCompareBool for complexobjects has a different behavior
179+
// needs to be replaced
180+
int PANDAS_INLINE complexobject_cmp(PyComplexObject* a, PyComplexObject* b){
181+
return (
182+
Py_IS_NAN(a->cval.real) &&
183+
Py_IS_NAN(b->cval.real) &&
184+
Py_IS_NAN(a->cval.imag) &&
185+
Py_IS_NAN(b->cval.imag)
186+
)
187+
||
188+
(
189+
Py_IS_NAN(a->cval.real) &&
190+
Py_IS_NAN(b->cval.real) &&
191+
a->cval.imag == b->cval.imag
192+
)
193+
||
194+
(
195+
a->cval.real == b->cval.real &&
196+
Py_IS_NAN(a->cval.imag) &&
197+
Py_IS_NAN(b->cval.imag)
198+
)
199+
||
200+
(
201+
a->cval.real == b->cval.real &&
202+
a->cval.imag == b->cval.imag
203+
);
204+
}
205+
206+
int PANDAS_INLINE pyobject_cmp(PyObject* a, PyObject* b);
207+
208+
209+
// replacing PyObject_RichCompareBool (NaN!=NaN) with pyobject_cmp (NaN==NaN),
210+
// which treats NaNs as equivalent
211+
// see GH 41836
212+
int PANDAS_INLINE tupleobject_cmp(PyTupleObject* a, PyTupleObject* b){
213+
Py_ssize_t i;
214+
215+
if (Py_SIZE(a) != Py_SIZE(b)) {
216+
return 0;
217+
}
218+
219+
for (i = 0; i < Py_SIZE(a); ++i) {
220+
if (!pyobject_cmp(PyTuple_GET_ITEM(a, i), PyTuple_GET_ITEM(b, i))) {
221+
return 0;
222+
}
223+
}
224+
return 1;
225+
}
226+
227+
166228
int PANDAS_INLINE pyobject_cmp(PyObject* a, PyObject* b) {
229+
if (Py_TYPE(a) == Py_TYPE(b)) {
230+
// special handling for some built-in types which could have NaNs
231+
// as we would like to have them equivalent, but the usual
232+
// PyObject_RichCompareBool would return False
233+
if (PyFloat_CheckExact(a)) {
234+
return floatobject_cmp((PyFloatObject*)a, (PyFloatObject*)b);
235+
}
236+
if (PyComplex_CheckExact(a)) {
237+
return complexobject_cmp((PyComplexObject*)a, (PyComplexObject*)b);
238+
}
239+
if (PyTuple_CheckExact(a)) {
240+
return tupleobject_cmp((PyTupleObject*)a, (PyTupleObject*)b);
241+
}
242+
// frozenset isn't yet supported
243+
}
244+
167245
int result = PyObject_RichCompareBool(a, b, Py_EQ);
168246
if (result < 0) {
169247
PyErr_Clear();
170248
return 0;
171249
}
172-
if (result == 0) { // still could be two NaNs
173-
return PyFloat_CheckExact(a) &&
174-
PyFloat_CheckExact(b) &&
175-
Py_IS_NAN(PyFloat_AS_DOUBLE(a)) &&
176-
Py_IS_NAN(PyFloat_AS_DOUBLE(b));
177-
}
178250
return result;
179251
}
180252

pandas/tests/indexes/multi/test_isin.py

+1-13
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import numpy as np
22
import pytest
33

4-
from pandas.compat import PYPY
5-
64
from pandas import MultiIndex
75
import pandas._testing as tm
86

97

10-
@pytest.mark.skipif(not PYPY, reason="tuples cmp recursively on PyPy")
11-
def test_isin_nan_pypy():
8+
def test_isin_nan():
129
idx = MultiIndex.from_arrays([["foo", "bar"], [1.0, np.nan]])
1310
tm.assert_numpy_array_equal(idx.isin([("bar", np.nan)]), np.array([False, True]))
1411
tm.assert_numpy_array_equal(
@@ -31,15 +28,6 @@ def test_isin():
3128
assert result.dtype == np.bool_
3229

3330

34-
@pytest.mark.skipif(PYPY, reason="tuples cmp recursively on PyPy")
35-
def test_isin_nan_not_pypy():
36-
idx = MultiIndex.from_arrays([["foo", "bar"], [1.0, np.nan]])
37-
tm.assert_numpy_array_equal(idx.isin([("bar", np.nan)]), np.array([False, False]))
38-
tm.assert_numpy_array_equal(
39-
idx.isin([("bar", float("nan"))]), np.array([False, False])
40-
)
41-
42-
4331
def test_isin_level_kwarg():
4432
idx = MultiIndex.from_arrays([["qux", "baz", "foo", "bar"], np.arange(4)])
4533

pandas/tests/libs/test_hashtable.py

+71
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import pandas as pd
1010
import pandas._testing as tm
11+
from pandas.core.algorithms import isin
1112

1213

1314
@contextmanager
@@ -178,6 +179,67 @@ def test_no_reallocation(self, table_type, dtype):
178179
assert n_buckets_start == clean_table.get_state()["n_buckets"]
179180

180181

182+
class TestPyObjectHashTableWithNans:
    """Lookups in ``PyObjectHashTable`` with keys that contain NaN values."""

    def test_nan_float(self):
        """A key stored under one float NaN is found via a different NaN object."""
        stored = float("nan")
        probe = float("nan")
        assert stored is not probe
        table = ht.PyObjectHashTable()
        table.set_item(stored, 42)
        assert table.get_item(probe) == 42

    def test_nan_complex_both(self):
        """Complex keys with NaN in both components are found via another object."""
        stored = complex(float("nan"), float("nan"))
        probe = complex(float("nan"), float("nan"))
        assert stored is not probe
        table = ht.PyObjectHashTable()
        table.set_item(stored, 42)
        assert table.get_item(probe) == 42

    def test_nan_complex_real(self):
        """NaN real part matches only when the imaginary parts are equal."""
        stored = complex(float("nan"), 1)
        probe = complex(float("nan"), 1)
        mismatch = complex(float("nan"), 2)
        assert stored is not probe
        table = ht.PyObjectHashTable()
        table.set_item(stored, 42)
        assert table.get_item(probe) == 42
        # a different imaginary part must not be found
        with pytest.raises(KeyError, match=None) as error:
            table.get_item(mismatch)
        assert str(error.value) == str(mismatch)

    def test_nan_complex_imag(self):
        """NaN imaginary part matches only when the real parts are equal."""
        stored = complex(1, float("nan"))
        probe = complex(1, float("nan"))
        mismatch = complex(2, float("nan"))
        assert stored is not probe
        table = ht.PyObjectHashTable()
        table.set_item(stored, 42)
        assert table.get_item(probe) == 42
        # a different real part must not be found
        with pytest.raises(KeyError, match=None) as error:
            table.get_item(mismatch)
        assert str(error.value) == str(mismatch)

    def test_nan_in_tuple(self):
        """Tuples holding distinct NaN objects are found as the same key."""
        stored = (float("nan"),)
        probe = (float("nan"),)
        assert stored[0] is not probe[0]
        table = ht.PyObjectHashTable()
        table.set_item(stored, 42)
        assert table.get_item(probe) == 42

    def test_nan_in_nested_tuple(self):
        """NaN inside nested tuples is matched; a differently shaped tuple is not."""
        stored = (1, (2, (float("nan"),)))
        probe = (1, (2, (float("nan"),)))
        mismatch = (1, 2)
        table = ht.PyObjectHashTable()
        table.set_item(stored, 42)
        assert table.get_item(probe) == 42
        with pytest.raises(KeyError, match=None) as error:
            table.get_item(mismatch)
        assert str(error.value) == str(mismatch)
241+
242+
181243
def test_get_labels_groupby_for_Int64(writable):
182244
table = ht.Int64HashTable()
183245
vals = np.array([1, 2, -1, 2, 1, -1], dtype=np.int64)
@@ -426,3 +488,12 @@ def test_mode(self, dtype, type_suffix):
426488
values = np.array([42, np.nan, np.nan, np.nan], dtype=dtype)
427489
assert mode(values, True) == 42
428490
assert np.isnan(mode(values, False))
491+
492+
493+
def test_ismember_tuple_with_nans():
    """``isin`` matches tuples whose NaN members are distinct float objects."""
    # GH-41836
    # Build a 1-D object array explicitly instead of handing ``isin`` a bare
    # list of tuples: the list form relies on implicit list-to-array coercion,
    # which newer pandas releases deprecate for ``pandas.core.algorithms``.
    values = np.empty(2, dtype=object)
    values[:] = [("a", float("nan")), ("b", 1)]
    comps = [("a", float("nan"))]
    result = isin(values, comps)
    expected = np.array([True, False], dtype=np.bool_)
    tm.assert_numpy_array_equal(result, expected)

0 commit comments

Comments
 (0)