Skip to content

Commit d7f25c8

Browse files
phofl and mroeschke committed
BUG: MultiIndex.get_indexer not matching nan values (pandas-dev#49442)
* BUG: MultiIndex.get_indexer not matching nan values
* Fix tests
* Use variable for shifts
* Fix mypy
* Add gh ref
* Update pandas/_libs/index.pyx
  Co-authored-by: Matthew Roeschke <[email protected]>
* Address review
  Co-authored-by: Matthew Roeschke <[email protected]>
1 parent 1a6987e commit d7f25c8

File tree

7 files changed

+55
-35
lines changed

7 files changed

+55
-35
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line number | Diff line number | Diff line change
@@ -548,6 +548,7 @@ Missing
548548

549549
MultiIndex
550550
^^^^^^^^^^
551+
- Bug in :meth:`MultiIndex.get_indexer` not matching ``NaN`` values (:issue:`37222`)
551552
- Bug in :meth:`MultiIndex.argsort` raising ``TypeError`` when index contains :attr:`NA` (:issue:`48495`)
552553
- Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48606`)
553554
- Bug in :class:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)

pandas/_libs/index.pyx

+18-7
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,9 @@ from pandas._libs.missing cimport (
3636
is_matching_na,
3737
)
3838

39+
# Defines shift of MultiIndex codes to avoid negative codes (missing values)
40+
multiindex_nulls_shift = 2
41+
3942

4043
cdef inline bint is_definitely_invalid_key(object val):
4144
try:
@@ -648,10 +651,13 @@ cdef class BaseMultiIndexCodesEngine:
648651
self.levels = levels
649652
self.offsets = offsets
650653

651-
# Transform labels in a single array, and add 1 so that we are working
652-
# with positive integers (-1 for NaN becomes 0):
653-
codes = (np.array(labels, dtype='int64').T + 1).astype('uint64',
654-
copy=False)
654+
# Transform labels in a single array, and add 2 so that we are working
655+
# with positive integers (-1 for NaN becomes 1). This enables us to
656+
# differentiate between values that are missing in other and matching
657+
# NaNs. We will set values that are not found to 0 later:
658+
labels_arr = np.array(labels, dtype='int64').T + multiindex_nulls_shift
659+
codes = labels_arr.astype('uint64', copy=False)
660+
self.level_has_nans = [-1 in lab for lab in labels]
655661

656662
# Map each codes combination in the index to an integer unambiguously
657663
# (no collisions possible), based on the "offsets", which describe the
@@ -680,8 +686,13 @@ cdef class BaseMultiIndexCodesEngine:
680686
Integers representing one combination each
681687
"""
682688
zt = [target._get_level_values(i) for i in range(target.nlevels)]
683-
level_codes = [lev.get_indexer_for(codes) + 1 for lev, codes
684-
in zip(self.levels, zt)]
689+
level_codes = []
690+
for i, (lev, codes) in enumerate(zip(self.levels, zt)):
691+
result = lev.get_indexer_for(codes) + 1
692+
result[result > 0] += 1
693+
if self.level_has_nans[i] and codes.hasnans:
694+
result[codes.isna()] += 1
695+
level_codes.append(result)
685696
return self._codes_to_ints(np.array(level_codes, dtype='uint64').T)
686697

687698
def get_indexer(self, target: np.ndarray) -> np.ndarray:
@@ -792,7 +803,7 @@ cdef class BaseMultiIndexCodesEngine:
792803
if not isinstance(key, tuple):
793804
raise KeyError(key)
794805
try:
795-
indices = [0 if checknull(v) else lev.get_loc(v) + 1
806+
indices = [1 if checknull(v) else lev.get_loc(v) + multiindex_nulls_shift
796807
for lev, v in zip(self.levels, key)]
797808
except KeyError:
798809
raise KeyError(key)

pandas/core/indexes/multi.py

+12-2
Original file line number | Diff line number | Diff line change
@@ -1077,8 +1077,18 @@ def set_codes(self, codes, *, level=None, verify_integrity: bool = True):
10771077
@cache_readonly
10781078
def _engine(self):
10791079
# Calculate the number of bits needed to represent labels in each
1080-
# level, as log2 of their sizes (including -1 for NaN):
1081-
sizes = np.ceil(np.log2([len(level) + 1 for level in self.levels]))
1080+
# level, as log2 of their sizes:
1081+
# NaN values are shifted to 1 and missing values in other while
1082+
# calculating the indexer are shifted to 0
1083+
sizes = np.ceil(
1084+
np.log2(
1085+
[
1086+
len(level)
1087+
+ libindex.multiindex_nulls_shift # type: ignore[attr-defined]
1088+
for level in self.levels
1089+
]
1090+
)
1091+
)
10821092

10831093
# Sum bit counts, starting from the _right_....
10841094
lev_bits = np.cumsum(sizes[::-1])[::-1]

pandas/tests/indexes/multi/test_drop.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -32,16 +32,16 @@ def test_drop(idx):
3232
tm.assert_index_equal(dropped, expected)
3333

3434
index = MultiIndex.from_tuples([("bar", "two")])
35-
with pytest.raises(KeyError, match=r"^10$"):
35+
with pytest.raises(KeyError, match=r"^15$"):
3636
idx.drop([("bar", "two")])
37-
with pytest.raises(KeyError, match=r"^10$"):
37+
with pytest.raises(KeyError, match=r"^15$"):
3838
idx.drop(index)
3939
with pytest.raises(KeyError, match=r"^'two'$"):
4040
idx.drop(["foo", "two"])
4141

4242
# partially correct argument
4343
mixed_index = MultiIndex.from_tuples([("qux", "one"), ("bar", "two")])
44-
with pytest.raises(KeyError, match=r"^10$"):
44+
with pytest.raises(KeyError, match=r"^15$"):
4545
idx.drop(mixed_index)
4646

4747
# error='ignore'

pandas/tests/indexes/multi/test_indexing.py

+11-1
Original file line number | Diff line number | Diff line change
@@ -471,6 +471,16 @@ def test_get_indexer_kwarg_validation(self):
471471
with pytest.raises(ValueError, match=msg):
472472
mi.get_indexer(mi[:-1], tolerance="piano")
473473

474+
def test_get_indexer_nan(self):
475+
# GH#37222
476+
idx1 = MultiIndex.from_product([["A"], [1.0, 2.0]], names=["id1", "id2"])
477+
idx2 = MultiIndex.from_product([["A"], [np.nan, 2.0]], names=["id1", "id2"])
478+
expected = np.array([-1, 1])
479+
result = idx2.get_indexer(idx1)
480+
tm.assert_numpy_array_equal(result, expected, check_dtype=False)
481+
result = idx1.get_indexer(idx2)
482+
tm.assert_numpy_array_equal(result, expected, check_dtype=False)
483+
474484

475485
def test_getitem(idx):
476486
# scalar
@@ -527,7 +537,7 @@ class TestGetLoc:
527537
def test_get_loc(self, idx):
528538
assert idx.get_loc(("foo", "two")) == 1
529539
assert idx.get_loc(("baz", "two")) == 3
530-
with pytest.raises(KeyError, match=r"^10$"):
540+
with pytest.raises(KeyError, match=r"^15$"):
531541
idx.get_loc(("bar", "two"))
532542
with pytest.raises(KeyError, match=r"^'quux'$"):
533543
idx.get_loc("quux")

pandas/tests/indexes/multi/test_setops.py

+1-2
Original file line number | Diff line number | Diff line change
@@ -659,9 +659,8 @@ def test_union_keep_ea_dtype_with_na(any_numeric_ea_dtype):
659659
midx = MultiIndex.from_arrays([arr1, [2, 1]], names=["a", None])
660660
midx2 = MultiIndex.from_arrays([arr2, [1, 2]])
661661
result = midx.union(midx2)
662-
# Expected is actually off and should contain (1, 1) too. See GH#37222
663662
expected = MultiIndex.from_arrays(
664-
[Series([4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [2, 1, 2]]
663+
[Series([1, 4, pd.NA, pd.NA], dtype=any_numeric_ea_dtype), [1, 2, 1, 2]]
665664
)
666665
tm.assert_index_equal(result, expected)
667666

pandas/tests/indexing/multiindex/test_multiindex.py

+9-20
Original file line number | Diff line number | Diff line change
@@ -162,10 +162,10 @@ def test_rename_multiindex_with_duplicates(self):
162162
[1, 2],
163163
],
164164
[
165-
[(81.0, np.nan), (np.nan, np.nan)],
166-
[(81.0, np.nan), (np.nan, np.nan)],
167-
[1, 2],
168-
[1, 1],
165+
[[81, 82.0, np.nan], Series([np.nan, np.nan, np.nan])],
166+
[[81, 82.0, np.nan], Series([np.nan, np.nan, np.nan])],
167+
[1, np.nan, 2],
168+
[np.nan, 2, 1],
169169
],
170170
),
171171
(
@@ -176,8 +176,8 @@ def test_rename_multiindex_with_duplicates(self):
176176
[1, 2],
177177
],
178178
[
179-
[(81.0, np.nan), (np.nan, np.nan)],
180-
[(81.0, np.nan), (np.nan, np.nan)],
179+
[[81.0, np.nan], Series([np.nan, np.nan])],
180+
[[81.0, np.nan], Series([np.nan, np.nan])],
181181
[1, 2],
182182
[2, 1],
183183
],
@@ -188,28 +188,17 @@ def test_subtracting_two_series_with_unordered_index_and_all_nan_index(
188188
self, data_result, data_expected
189189
):
190190
# GH 38439
191+
# TODO: Refactor. This is impossible to understand GH#49443
191192
a_index_result = MultiIndex.from_tuples(data_result[0])
192193
b_index_result = MultiIndex.from_tuples(data_result[1])
193194
a_series_result = Series(data_result[2], index=a_index_result)
194195
b_series_result = Series(data_result[3], index=b_index_result)
195196
result = a_series_result.align(b_series_result)
196197

197-
a_index_expected = MultiIndex.from_tuples(data_expected[0])
198-
b_index_expected = MultiIndex.from_tuples(data_expected[1])
198+
a_index_expected = MultiIndex.from_arrays(data_expected[0])
199+
b_index_expected = MultiIndex.from_arrays(data_expected[1])
199200
a_series_expected = Series(data_expected[2], index=a_index_expected)
200201
b_series_expected = Series(data_expected[3], index=b_index_expected)
201-
a_series_expected.index = a_series_expected.index.set_levels(
202-
[
203-
a_series_expected.index.levels[0].astype("float"),
204-
a_series_expected.index.levels[1].astype("float"),
205-
]
206-
)
207-
b_series_expected.index = b_series_expected.index.set_levels(
208-
[
209-
b_series_expected.index.levels[0].astype("float"),
210-
b_series_expected.index.levels[1].astype("float"),
211-
]
212-
)
213202

214203
tm.assert_series_equal(result[0], a_series_expected)
215204
tm.assert_series_equal(result[1], b_series_expected)

0 commit comments

Comments
 (0)