Skip to content

Commit 438b957

Browse files
authored
ENH: MultiIndex.intersection now keeping EA dtypes (#48604)
* ENH: MultiIndex.intersection now keeping EA dtypes * Improve performance * Improve performance * Improve performance * Add test and whatsnew * Add gh ref * Fix typing * Fix typing
1 parent 3c0215d commit 438b957

File tree

6 files changed

+57
-25
lines changed

6 files changed

+57
-25
lines changed

asv_bench/benchmarks/index_object.py

+18-8
Original file line numberDiff line numberDiff line change
@@ -19,28 +19,38 @@
1919
class SetOperations:
2020

2121
params = (
22-
["datetime", "date_string", "int", "strings"],
22+
["monotonic", "non_monotonic"],
23+
["datetime", "date_string", "int", "strings", "ea_int"],
2324
["intersection", "union", "symmetric_difference"],
2425
)
25-
param_names = ["dtype", "method"]
26+
param_names = ["index_structure", "dtype", "method"]
2627

27-
def setup(self, dtype, method):
28+
def setup(self, index_structure, dtype, method):
2829
N = 10**5
2930
dates_left = date_range("1/1/2000", periods=N, freq="T")
3031
fmt = "%Y-%m-%d %H:%M:%S"
3132
date_str_left = Index(dates_left.strftime(fmt))
3233
int_left = Index(np.arange(N))
34+
ea_int_left = Index(np.arange(N), dtype="Int64")
3335
str_left = tm.makeStringIndex(N)
36+
3437
data = {
35-
"datetime": {"left": dates_left, "right": dates_left[:-1]},
36-
"date_string": {"left": date_str_left, "right": date_str_left[:-1]},
37-
"int": {"left": int_left, "right": int_left[:-1]},
38-
"strings": {"left": str_left, "right": str_left[:-1]},
38+
"datetime": dates_left,
39+
"date_string": date_str_left,
40+
"int": int_left,
41+
"strings": str_left,
42+
"ea_int": ea_int_left,
3943
}
44+
45+
if index_structure == "non_monotonic":
46+
data = {k: mi[::-1] for k, mi in data.items()}
47+
48+
data = {k: {"left": idx, "right": idx[:-1]} for k, idx in data.items()}
49+
4050
self.left = data[dtype]["left"]
4151
self.right = data[dtype]["right"]
4252

43-
def time_operation(self, dtype, method):
53+
def time_operation(self, index_structure, dtype, method):
4454
getattr(self.left, method)(self.right)
4555

4656

asv_bench/benchmarks/multiindex_object.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ class SetOperations:
237237

238238
params = [
239239
("monotonic", "non_monotonic"),
240-
("datetime", "int", "string"),
240+
("datetime", "int", "string", "ea_int"),
241241
("intersection", "union", "symmetric_difference"),
242242
]
243243
param_names = ["index_structure", "dtype", "method"]
@@ -255,10 +255,14 @@ def setup(self, index_structure, dtype, method):
255255
level2 = tm.makeStringIndex(N // 1000).values
256256
str_left = MultiIndex.from_product([level1, level2])
257257

258+
level2 = range(N // 1000)
259+
ea_int_left = MultiIndex.from_product([level1, Series(level2, dtype="Int64")])
260+
258261
data = {
259262
"datetime": dates_left,
260263
"int": int_left,
261264
"string": str_left,
265+
"ea_int": ea_int_left,
262266
}
263267

264268
if index_structure == "non_monotonic":

doc/source/whatsnew/v1.6.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ Performance improvements
112112
- Performance improvement for :class:`Series` constructor passing integer numpy array with nullable dtype (:issue:`48338`)
113113
- Performance improvement in :meth:`DataFrame.loc` and :meth:`Series.loc` for tuple-based indexing of a :class:`MultiIndex` (:issue:`48384`)
114114
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
115+
- Performance improvement for :meth:`MultiIndex.intersection` (:issue:`48604`)
115116
- Performance improvement in ``var`` for nullable dtypes (:issue:`48379`).
116117
- Performance improvement to :func:`read_sas` with ``blank_missing=True`` (:issue:`48502`)
117118
-
@@ -179,6 +180,7 @@ MultiIndex
179180
^^^^^^^^^^
180181
- Bug in :class:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)
181182
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
183+
- Bug in :meth:`MultiIndex.intersection` losing extension array dtype (:issue:`48604`)
182184
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`)
183185
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
184186
- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`)

pandas/core/indexes/base.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -3559,13 +3559,19 @@ def _intersection(self, other: Index, sort: bool = False):
35593559
and self._can_use_libjoin
35603560
):
35613561
try:
3562-
result = self._inner_indexer(other)[0]
3562+
res_indexer, indexer, _ = self._inner_indexer(other)
35633563
except TypeError:
35643564
# non-comparable; should only be for object dtype
35653565
pass
35663566
else:
35673567
# TODO: algos.unique1d should preserve DTA/TDA
3568-
res = algos.unique1d(result)
3568+
if self.is_numeric():
3569+
# This is faster, because Index.unique() checks for uniqueness
3570+
# before calculating the unique values.
3571+
res = algos.unique1d(res_indexer)
3572+
else:
3573+
result = self.take(indexer)
3574+
res = result.drop_duplicates()
35693575
return ensure_wrapped_if_datetimelike(res)
35703576

35713577
res_values = self._intersection_via_get_indexer(other, sort=sort)
@@ -3577,7 +3583,9 @@ def _wrap_intersection_result(self, other, result):
35773583
return self._wrap_setop_result(other, result)
35783584

35793585
@final
3580-
def _intersection_via_get_indexer(self, other: Index, sort) -> ArrayLike:
3586+
def _intersection_via_get_indexer(
3587+
self, other: Index | MultiIndex, sort
3588+
) -> ArrayLike | MultiIndex:
35813589
"""
35823590
Find the intersection of two Indexes using get_indexer.
35833591
@@ -3600,7 +3608,10 @@ def _intersection_via_get_indexer(self, other: Index, sort) -> ArrayLike:
36003608
# unnecessary in the case with sort=None bc we will sort later
36013609
taker = np.sort(taker)
36023610

3603-
result = left_unique.take(taker)._values
3611+
if isinstance(left_unique, ABCMultiIndex):
3612+
result = left_unique.take(taker)
3613+
else:
3614+
result = left_unique.take(taker)._values
36043615
return result
36053616

36063617
@final

pandas/core/indexes/multi.py

+1-10
Original file line numberDiff line numberDiff line change
@@ -3725,16 +3725,7 @@ def _maybe_match_names(self, other):
37253725

37263726
def _wrap_intersection_result(self, other, result) -> MultiIndex:
37273727
_, result_names = self._convert_can_do_setop(other)
3728-
3729-
if len(result) == 0:
3730-
return MultiIndex(
3731-
levels=self.levels,
3732-
codes=[[]] * self.nlevels,
3733-
names=result_names,
3734-
verify_integrity=False,
3735-
)
3736-
else:
3737-
return MultiIndex.from_arrays(zip(*result), sortorder=0, names=result_names)
3728+
return result.set_names(result_names)
37383729

37393730
def _wrap_difference_result(self, other, result) -> MultiIndex:
37403731
_, result_names = self._convert_can_do_setop(other)

pandas/tests/indexes/multi/test_setops.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,7 @@ def test_intersection_with_missing_values_on_both_sides(nulls_fixture):
525525
mi1 = MultiIndex.from_arrays([[3, nulls_fixture, 4, nulls_fixture], [1, 2, 4, 2]])
526526
mi2 = MultiIndex.from_arrays([[3, nulls_fixture, 3], [1, 2, 4]])
527527
result = mi1.intersection(mi2)
528-
expected = MultiIndex.from_arrays([[3.0, nulls_fixture], [1, 2]])
528+
expected = MultiIndex.from_arrays([[3, nulls_fixture], [1, 2]])
529529
tm.assert_index_equal(result, expected)
530530

531531

@@ -631,4 +631,18 @@ def test_intersection_lexsort_depth(levels1, levels2, codes1, codes2, names):
631631
mi_int = mi1.intersection(mi2)
632632

633633
with tm.assert_produces_warning(FutureWarning, match="MultiIndex.lexsort_depth"):
634-
assert mi_int.lexsort_depth == 0
634+
assert mi_int.lexsort_depth == 2
635+
636+
637+
@pytest.mark.parametrize("val", [pd.NA, 100])
638+
def test_intersection_keep_ea_dtypes(val, any_numeric_ea_dtype):
639+
# GH#48604
640+
midx = MultiIndex.from_arrays(
641+
[Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
642+
)
643+
midx2 = MultiIndex.from_arrays(
644+
[Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
645+
)
646+
result = midx.intersection(midx2)
647+
expected = MultiIndex.from_arrays([Series([2], dtype=any_numeric_ea_dtype), [1]])
648+
tm.assert_index_equal(result, expected)

0 commit comments

Comments
 (0)