Skip to content

Commit 44a4f16

Browse files
authored
BUG: MultiIndex.difference not keeping ea dtype (#48606)
* BUG: MultiIndex.difference not keeping ea dtype * Add asv * Add whatsnew * Reduce asv * Add ea asv * Fix mypy * Add whatsnew * Fix mypy
1 parent 1209160 commit 44a4f16

File tree

5 files changed

+71
-9
lines changed

5 files changed

+71
-9
lines changed

asv_bench/benchmarks/multiindex_object.py

+39
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,45 @@ def time_operation(self, index_structure, dtype, method):
276276
getattr(self.left, method)(self.right)
277277

278278

279+
class Difference:
280+
281+
params = [
282+
("datetime", "int", "string", "ea_int"),
283+
]
284+
param_names = ["dtype"]
285+
286+
def setup(self, dtype):
287+
N = 10**4 * 2
288+
level1 = range(1000)
289+
290+
level2 = date_range(start="1/1/2000", periods=N // 1000)
291+
dates_left = MultiIndex.from_product([level1, level2])
292+
293+
level2 = range(N // 1000)
294+
int_left = MultiIndex.from_product([level1, level2])
295+
296+
level2 = Series(range(N // 1000), dtype="Int64")
297+
level2[0] = NA
298+
ea_int_left = MultiIndex.from_product([level1, level2])
299+
300+
level2 = tm.makeStringIndex(N // 1000).values
301+
str_left = MultiIndex.from_product([level1, level2])
302+
303+
data = {
304+
"datetime": dates_left,
305+
"int": int_left,
306+
"ea_int": ea_int_left,
307+
"string": str_left,
308+
}
309+
310+
data = {k: {"left": mi, "right": mi[:5]} for k, mi in data.items()}
311+
self.left = data[dtype]["left"]
312+
self.right = data[dtype]["right"]
313+
314+
def time_difference(self, dtype):
315+
self.left.difference(self.right)
316+
317+
279318
class Unique:
280319
params = [
281320
(("Int64", NA), ("int64", 0)),

doc/source/whatsnew/v1.6.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ Performance improvements
135135
- Performance improvement in :meth:`MultiIndex.argsort` and :meth:`MultiIndex.sort_values` (:issue:`48406`)
136136
- Performance improvement in :meth:`MultiIndex.size` (:issue:`48723`)
137137
- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`)
138+
- Performance improvement in :meth:`MultiIndex.difference` (:issue:`48606`)
138139
- Performance improvement in :meth:`.DataFrameGroupBy.mean`, :meth:`.SeriesGroupBy.mean`, :meth:`.DataFrameGroupBy.var`, and :meth:`.SeriesGroupBy.var` for extension array dtypes (:issue:`37493`)
139140
- Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`)
140141
- Performance improvement for :meth:`Series.value_counts` with nullable dtype (:issue:`48338`)
@@ -210,6 +211,7 @@ Missing
210211

211212
MultiIndex
212213
^^^^^^^^^^
214+
- Bug in :meth:`MultiIndex.difference` losing extension array dtype (:issue:`48606`)
213215
- Bug in :meth:`MultiIndex.set_levels` raising ``IndexError`` when setting empty level (:issue:`48636`)
214216
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
215217
- Bug in :meth:`MultiIndex.intersection` losing extension array dtype (:issue:`48604`)

pandas/core/indexes/base.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -3684,7 +3684,12 @@ def _difference(self, other, sort):
36843684
indexer = indexer.take((indexer != -1).nonzero()[0])
36853685

36863686
label_diff = np.setdiff1d(np.arange(this.size), indexer, assume_unique=True)
3687-
the_diff = this._values.take(label_diff)
3687+
3688+
the_diff: MultiIndex | ArrayLike
3689+
if isinstance(this, ABCMultiIndex):
3690+
the_diff = this.take(label_diff)
3691+
else:
3692+
the_diff = this._values.take(label_diff)
36883693
the_diff = _maybe_try_sort(the_diff, sort)
36893694

36903695
return the_diff

pandas/core/indexes/multi.py

+3-8
Original file line numberDiff line numberDiff line change
@@ -3746,18 +3746,13 @@ def _wrap_intersection_result(self, other, result) -> MultiIndex:
37463746
_, result_names = self._convert_can_do_setop(other)
37473747
return result.set_names(result_names)
37483748

3749-
def _wrap_difference_result(self, other, result) -> MultiIndex:
3749+
def _wrap_difference_result(self, other, result: MultiIndex) -> MultiIndex:
37503750
_, result_names = self._convert_can_do_setop(other)
37513751

37523752
if len(result) == 0:
3753-
return MultiIndex(
3754-
levels=[[]] * self.nlevels,
3755-
codes=[[]] * self.nlevels,
3756-
names=result_names,
3757-
verify_integrity=False,
3758-
)
3753+
return result.remove_unused_levels().set_names(result_names)
37593754
else:
3760-
return MultiIndex.from_tuples(result, sortorder=0, names=result_names)
3755+
return result.set_names(result_names)
37613756

37623757
def _convert_can_do_setop(self, other):
37633758
result_names = self.names

pandas/tests/indexes/multi/test_setops.py

+21
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,27 @@ def test_setops_disallow_true(method):
440440
getattr(idx1, method)(idx2, sort=True)
441441

442442

443+
@pytest.mark.parametrize("val", [pd.NA, 100])
444+
def test_difference_keep_ea_dtypes(any_numeric_ea_dtype, val):
445+
# GH#48606
446+
midx = MultiIndex.from_arrays(
447+
[Series([1, 2], dtype=any_numeric_ea_dtype), [2, 1]], names=["a", None]
448+
)
449+
midx2 = MultiIndex.from_arrays(
450+
[Series([1, 2, val], dtype=any_numeric_ea_dtype), [1, 1, 3]]
451+
)
452+
result = midx.difference(midx2)
453+
expected = MultiIndex.from_arrays([Series([1], dtype=any_numeric_ea_dtype), [2]])
454+
tm.assert_index_equal(result, expected)
455+
456+
result = midx.difference(midx.sort_values(ascending=False))
457+
expected = MultiIndex.from_arrays(
458+
[Series([], dtype=any_numeric_ea_dtype), Series([], dtype=int)],
459+
names=["a", None],
460+
)
461+
tm.assert_index_equal(result, expected)
462+
463+
443464
@pytest.mark.parametrize("val", [pd.NA, 5])
444465
def test_symmetric_difference_keeping_ea_dtype(any_numeric_ea_dtype, val):
445466
# GH#48607

0 commit comments

Comments
 (0)