Skip to content

Commit 3cb83d3

Browse files
lukemanleynoatamir
authored andcommitted
BUG/PERF: MultiIndex setops with sort=None (pandas-dev#49010)
* perf: algos.safe_sort with multiindex * add sort to multiindex setop asv * fix asv * whatsnew * update test_union_nan_got_duplicated * add test for sort bug * parameterize dtype in test
1 parent 9ea7080 commit 3cb83d3

File tree

4 files changed

+27
-37
lines changed

4 files changed

+27
-37
lines changed

asv_bench/benchmarks/multiindex_object.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -239,10 +239,11 @@ class SetOperations:
239239
("monotonic", "non_monotonic"),
240240
("datetime", "int", "string", "ea_int"),
241241
("intersection", "union", "symmetric_difference"),
242+
(False, None),
242243
]
243-
param_names = ["index_structure", "dtype", "method"]
244+
param_names = ["index_structure", "dtype", "method", "sort"]
244245

245-
def setup(self, index_structure, dtype, method):
246+
def setup(self, index_structure, dtype, method, sort):
246247
N = 10**5
247248
level1 = range(1000)
248249

@@ -272,8 +273,8 @@ def setup(self, index_structure, dtype, method):
272273
self.left = data[dtype]["left"]
273274
self.right = data[dtype]["right"]
274275

275-
def time_operation(self, index_structure, dtype, method):
276-
getattr(self.left, method)(self.right)
276+
def time_operation(self, index_structure, dtype, method, sort):
277+
getattr(self.left, method)(self.right, sort=sort)
277278

278279

279280
class Difference:

doc/source/whatsnew/v1.6.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ Performance improvements
141141
- Performance improvement in :meth:`MultiIndex.size` (:issue:`48723`)
142142
- Performance improvement in :meth:`MultiIndex.union` without missing values and without duplicates (:issue:`48505`)
143143
- Performance improvement in :meth:`MultiIndex.difference` (:issue:`48606`)
144+
- Performance improvement in :class:`MultiIndex` set operations with sort=None (:issue:`49010`)
144145
- Performance improvement in :meth:`.DataFrameGroupBy.mean`, :meth:`.SeriesGroupBy.mean`, :meth:`.DataFrameGroupBy.var`, and :meth:`.SeriesGroupBy.var` for extension array dtypes (:issue:`37493`)
145146
- Performance improvement in :meth:`MultiIndex.isin` when ``level=None`` (:issue:`48622`)
146147
- Performance improvement in :meth:`Index.union` and :meth:`MultiIndex.union` when index contains duplicates (:issue:`48900`)
@@ -230,6 +231,7 @@ MultiIndex
230231
- Bug in :meth:`MultiIndex.unique` losing extension array dtype (:issue:`48335`)
231232
- Bug in :meth:`MultiIndex.intersection` losing extension array (:issue:`48604`)
232233
- Bug in :meth:`MultiIndex.union` losing extension array (:issue:`48498`, :issue:`48505`, :issue:`48900`)
234+
- Bug in :meth:`MultiIndex.union` not sorting when sort=None and index contains missing values (:issue:`49010`)
233235
- Bug in :meth:`MultiIndex.append` not checking names for equality (:issue:`48288`)
234236
- Bug in :meth:`MultiIndex.symmetric_difference` losing extension array (:issue:`48607`)
235237
-

pandas/core/algorithms.py

+6-27
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Sequence,
1515
cast,
1616
final,
17-
overload,
1817
)
1918
import warnings
2019

@@ -1816,10 +1815,8 @@ def safe_sort(
18161815
raise TypeError(
18171816
"Only list-like objects are allowed to be passed to safe_sort as values"
18181817
)
1819-
original_values = values
1820-
is_mi = isinstance(original_values, ABCMultiIndex)
18211818

1822-
if not isinstance(values, (np.ndarray, ABCExtensionArray)):
1819+
if not isinstance(values, (np.ndarray, ABCExtensionArray, ABCMultiIndex)):
18231820
# don't convert to string types
18241821
dtype, _ = infer_dtype_from_array(values)
18251822
# error: Argument "dtype" to "asarray" has incompatible type "Union[dtype[Any],
@@ -1839,17 +1836,13 @@ def safe_sort(
18391836
else:
18401837
try:
18411838
sorter = values.argsort()
1842-
if is_mi:
1843-
# Operate on original object instead of casted array (MultiIndex)
1844-
ordered = original_values.take(sorter)
1845-
else:
1846-
ordered = values.take(sorter)
1839+
ordered = values.take(sorter)
18471840
except TypeError:
18481841
# Previous sorters failed or were not applicable, try `_sort_mixed`
18491842
# which would work, but which fails for special case of 1d arrays
18501843
# with tuples.
18511844
if values.size and isinstance(values[0], tuple):
1852-
ordered = _sort_tuples(values, original_values)
1845+
ordered = _sort_tuples(values)
18531846
else:
18541847
ordered = _sort_mixed(values)
18551848

@@ -1912,33 +1905,19 @@ def _sort_mixed(values) -> np.ndarray:
19121905
)
19131906

19141907

1915-
@overload
1916-
def _sort_tuples(values: np.ndarray, original_values: np.ndarray) -> np.ndarray:
1917-
...
1918-
1919-
1920-
@overload
1921-
def _sort_tuples(values: np.ndarray, original_values: MultiIndex) -> MultiIndex:
1922-
...
1923-
1924-
1925-
def _sort_tuples(
1926-
values: np.ndarray, original_values: np.ndarray | MultiIndex
1927-
) -> np.ndarray | MultiIndex:
1908+
def _sort_tuples(values: np.ndarray) -> np.ndarray:
19281909
"""
1929-
Convert array of tuples (1d) to array or array (2d).
1910+
Convert array of tuples (1d) to array of arrays (2d).
19301911
We need to keep the columns separately as they contain different types and
19311912
nans (can't use `np.sort` as it may fail when str and nan are mixed in a
19321913
column as types cannot be compared).
1933-
We have to apply the indexer to the original values to keep the dtypes in
1934-
case of MultiIndexes
19351914
"""
19361915
from pandas.core.internals.construction import to_arrays
19371916
from pandas.core.sorting import lexsort_indexer
19381917

19391918
arrays, _ = to_arrays(values, None)
19401919
indexer = lexsort_indexer(arrays, orders=True)
1941-
return original_values[indexer]
1920+
return values[indexer]
19421921

19431922

19441923
def union_with_duplicates(

pandas/tests/indexes/multi/test_setops.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -550,12 +550,20 @@ def test_intersection_with_missing_values_on_both_sides(nulls_fixture):
550550
tm.assert_index_equal(result, expected)
551551

552552

553-
def test_union_nan_got_duplicated():
554-
# GH#38977
555-
mi1 = MultiIndex.from_arrays([[1.0, np.nan], [2, 3]])
556-
mi2 = MultiIndex.from_arrays([[1.0, np.nan, 3.0], [2, 3, 4]])
557-
result = mi1.union(mi2)
558-
tm.assert_index_equal(result, mi2)
553+
@pytest.mark.parametrize("dtype", ["float64", "Float64"])
554+
@pytest.mark.parametrize("sort", [None, False])
555+
def test_union_nan_got_duplicated(dtype, sort):
556+
# GH#38977, GH#49010
557+
mi1 = MultiIndex.from_arrays([pd.array([1.0, np.nan], dtype=dtype), [2, 3]])
558+
mi2 = MultiIndex.from_arrays([pd.array([1.0, np.nan, 3.0], dtype=dtype), [2, 3, 4]])
559+
result = mi1.union(mi2, sort=sort)
560+
if sort is None:
561+
expected = MultiIndex.from_arrays(
562+
[pd.array([1.0, 3.0, np.nan], dtype=dtype), [2, 4, 3]]
563+
)
564+
else:
565+
expected = mi2
566+
tm.assert_index_equal(result, expected)
559567

560568

561569
@pytest.mark.parametrize("val", [4, 1])

0 commit comments

Comments
 (0)