Skip to content

Commit 0e9844a

Browse files
authored
ENH: Index set operations with sort=True (#51346)
1 parent 128fc9a commit 0e9844a

File tree

8 files changed

+61
-43
lines changed

8 files changed

+61
-43
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ Other enhancements
313313
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
314314
- Added new argument ``engine`` to :func:`read_json` to support parsing JSON with pyarrow by specifying ``engine="pyarrow"`` (:issue:`48893`)
315315
- Added support for SQLAlchemy 2.0 (:issue:`40686`)
316+
- :class:`Index` set operations :meth:`Index.union`, :meth:`Index.intersection`, :meth:`Index.difference`, and :meth:`Index.symmetric_difference` now support ``sort=True``, which will always return a sorted result, unlike the default ``sort=None`` which does not sort in some cases (:issue:`25151`)
316317
-
317318

318319
.. ---------------------------------------------------------------------------

pandas/core/indexes/base.py

+35-13
Original file line numberDiff line numberDiff line change
@@ -3023,10 +3023,10 @@ def _get_reconciled_name_object(self, other):
30233023

30243024
@final
30253025
def _validate_sort_keyword(self, sort):
3026-
if sort not in [None, False]:
3026+
if sort not in [None, False, True]:
30273027
raise ValueError(
30283028
"The 'sort' keyword only takes the values of "
3029-
f"None or False; {sort} was passed."
3029+
f"None, True, or False; {sort} was passed."
30303030
)
30313031

30323032
@final
@@ -3070,6 +3070,7 @@ def union(self, other, sort=None):
30703070
A RuntimeWarning is issued in this case.
30713071
30723072
* False : do not sort the result.
3073+
* True : Sort the result (which may raise TypeError).
30733074
30743075
Returns
30753076
-------
@@ -3154,10 +3155,16 @@ def union(self, other, sort=None):
31543155
elif not len(other) or self.equals(other):
31553156
# NB: whether this (and the `if not len(self)` check below) come before
31563157
# or after the is_dtype_equal check above affects the returned dtype
3157-
return self._get_reconciled_name_object(other)
3158+
result = self._get_reconciled_name_object(other)
3159+
if sort is True:
3160+
return result.sort_values()
3161+
return result
31583162

31593163
elif not len(self):
3160-
return other._get_reconciled_name_object(self)
3164+
result = other._get_reconciled_name_object(self)
3165+
if sort is True:
3166+
return result.sort_values()
3167+
return result
31613168

31623169
result = self._union(other, sort=sort)
31633170

@@ -3258,12 +3265,13 @@ def intersection(self, other, sort: bool = False):
32583265
Parameters
32593266
----------
32603267
other : Index or array-like
3261-
sort : False or None, default False
3268+
sort : True, False or None, default False
32623269
Whether to sort the resulting index.
32633270
3264-
* False : do not sort the result.
32653271
* None : sort the result, except when `self` and `other` are equal
32663272
or when the values cannot be compared.
3273+
* False : do not sort the result.
3274+
* True : Sort the result (which may raise TypeError).
32673275
32683276
Returns
32693277
-------
@@ -3285,8 +3293,12 @@ def intersection(self, other, sort: bool = False):
32853293

32863294
if self.equals(other):
32873295
if self.has_duplicates:
3288-
return self.unique()._get_reconciled_name_object(other)
3289-
return self._get_reconciled_name_object(other)
3296+
result = self.unique()._get_reconciled_name_object(other)
3297+
else:
3298+
result = self._get_reconciled_name_object(other)
3299+
if sort is True:
3300+
result = result.sort_values()
3301+
return result
32903302

32913303
if len(self) == 0 or len(other) == 0:
32923304
# fastpath; we need to be careful about having commutativity
@@ -3403,14 +3415,15 @@ def difference(self, other, sort=None):
34033415
Parameters
34043416
----------
34053417
other : Index or array-like
3406-
sort : False or None, default None
3418+
sort : bool or None, default None
34073419
Whether to sort the resulting index. By default, the
34083420
values are attempted to be sorted, but any TypeError from
34093421
incomparable elements is caught by pandas.
34103422
34113423
* None : Attempt to sort the result, but catch any TypeErrors
34123424
from comparing incomparable elements.
34133425
* False : Do not sort the result.
3426+
* True : Sort the result (which may raise TypeError).
34143427
34153428
Returns
34163429
-------
@@ -3439,11 +3452,17 @@ def difference(self, other, sort=None):
34393452

34403453
if len(other) == 0:
34413454
# Note: we do not (yet) sort even if sort=None GH#24959
3442-
return self.rename(result_name)
3455+
result = self.rename(result_name)
3456+
if sort is True:
3457+
return result.sort_values()
3458+
return result
34433459

34443460
if not self._should_compare(other):
34453461
# Nothing matches -> difference is everything
3446-
return self.rename(result_name)
3462+
result = self.rename(result_name)
3463+
if sort is True:
3464+
return result.sort_values()
3465+
return result
34473466

34483467
result = self._difference(other, sort=sort)
34493468
return self._wrap_difference_result(other, result)
@@ -3479,14 +3498,15 @@ def symmetric_difference(self, other, result_name=None, sort=None):
34793498
----------
34803499
other : Index or array-like
34813500
result_name : str
3482-
sort : False or None, default None
3501+
sort : bool or None, default None
34833502
Whether to sort the resulting index. By default, the
34843503
values are attempted to be sorted, but any TypeError from
34853504
incomparable elements is caught by pandas.
34863505
34873506
* None : Attempt to sort the result, but catch any TypeErrors
34883507
from comparing incomparable elements.
34893508
* False : Do not sort the result.
3509+
* True : Sort the result (which may raise TypeError).
34903510
34913511
Returns
34923512
-------
@@ -7162,10 +7182,12 @@ def unpack_nested_dtype(other: _IndexT) -> _IndexT:
71627182

71637183

71647184
def _maybe_try_sort(result, sort):
7165-
if sort is None:
7185+
if sort is not False:
71667186
try:
71677187
result = algos.safe_sort(result)
71687188
except TypeError as err:
7189+
if sort is True:
7190+
raise
71697191
warnings.warn(
71707192
f"{err}, sort order is undefined for incomparable objects.",
71717193
RuntimeWarning,

pandas/core/indexes/multi.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3560,10 +3560,12 @@ def _union(self, other, sort) -> MultiIndex:
35603560
else:
35613561
result = self._get_reconciled_name_object(other)
35623562

3563-
if sort is None:
3563+
if sort is not False:
35643564
try:
35653565
result = result.sort_values()
35663566
except TypeError:
3567+
if sort is True:
3568+
raise
35673569
warnings.warn(
35683570
"The values in the array are unorderable. "
35693571
"Pass `sort=False` to suppress this warning.",

pandas/core/indexes/range.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -688,7 +688,7 @@ def _difference(self, other, sort=None):
688688
if not isinstance(other, RangeIndex):
689689
return super()._difference(other, sort=sort)
690690

691-
if sort is None and self.step < 0:
691+
if sort is not False and self.step < 0:
692692
return self[::-1]._difference(other)
693693

694694
res_name = ops.get_op_result_name(self, other)

pandas/tests/indexes/base_class/test_setops.py

+5-7
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,15 @@ class TestIndexSetOps:
1616
@pytest.mark.parametrize(
1717
"method", ["union", "intersection", "difference", "symmetric_difference"]
1818
)
19-
def test_setops_disallow_true(self, method):
19+
def test_setops_sort_validation(self, method):
2020
idx1 = Index(["a", "b"])
2121
idx2 = Index(["b", "c"])
2222

2323
with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
24-
getattr(idx1, method)(idx2, sort=True)
24+
getattr(idx1, method)(idx2, sort=2)
25+
26+
# sort=True is supported as of GH#??
27+
getattr(idx1, method)(idx2, sort=True)
2528

2629
def test_setops_preserve_object_dtype(self):
2730
idx = Index([1, 2, 3], dtype=object)
@@ -88,17 +91,12 @@ def test_union_sort_other_incomparable(self):
8891
result = idx.union(idx[:1], sort=False)
8992
tm.assert_index_equal(result, idx)
9093

91-
@pytest.mark.xfail(reason="GH#25151 need to decide on True behavior")
9294
def test_union_sort_other_incomparable_true(self):
93-
# TODO(GH#25151): decide on True behaviour
94-
# sort=True
9595
idx = Index([1, pd.Timestamp("2000")])
9696
with pytest.raises(TypeError, match=".*"):
9797
idx.union(idx[:1], sort=True)
9898

99-
@pytest.mark.xfail(reason="GH#25151 need to decide on True behavior")
10099
def test_intersection_equal_sort_true(self):
101-
# TODO(GH#25151): decide on True behaviour
102100
idx = Index(["c", "a", "b"])
103101
sorted_ = Index(["a", "b", "c"])
104102
tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_)

pandas/tests/indexes/multi/test_setops.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,6 @@ def test_difference_sort_special():
204204
tm.assert_index_equal(result, idx)
205205

206206

207-
@pytest.mark.xfail(reason="Not implemented.")
208207
def test_difference_sort_special_true():
209208
# TODO(GH#25151): decide on True behaviour
210209
idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
@@ -233,8 +232,10 @@ def test_difference_sort_incomparable_true():
233232
idx = MultiIndex.from_product([[1, pd.Timestamp("2000"), 2], ["a", "b"]])
234233
other = MultiIndex.from_product([[3, pd.Timestamp("2000"), 4], ["c", "d"]])
235234

236-
msg = "The 'sort' keyword only takes the values of None or False; True was passed."
237-
with pytest.raises(ValueError, match=msg):
235+
# TODO: this is raising in constructing a Categorical when calling
236+
# algos.safe_sort. Should we catch and re-raise with a better message?
237+
msg = "'values' is not ordered, please explicitly specify the categories order "
238+
with pytest.raises(TypeError, match=msg):
238239
idx.difference(other, sort=True)
239240

240241

@@ -344,12 +345,11 @@ def test_intersect_equal_sort():
344345
tm.assert_index_equal(idx.intersection(idx, sort=None), idx)
345346

346347

347-
@pytest.mark.xfail(reason="Not implemented.")
348348
def test_intersect_equal_sort_true():
349-
# TODO(GH#25151): decide on True behaviour
350349
idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
351-
sorted_ = MultiIndex.from_product([[0, 1], ["a", "b"]])
352-
tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_)
350+
expected = MultiIndex.from_product([[0, 1], ["a", "b"]])
351+
result = idx.intersection(idx, sort=True)
352+
tm.assert_index_equal(result, expected)
353353

354354

355355
@pytest.mark.parametrize("slice_", [slice(None), slice(0)])
@@ -366,7 +366,6 @@ def test_union_sort_other_empty(slice_):
366366
tm.assert_index_equal(idx.union(other, sort=False), idx)
367367

368368

369-
@pytest.mark.xfail(reason="Not implemented.")
370369
def test_union_sort_other_empty_sort():
371370
# TODO(GH#25151): decide on True behaviour
372371
# # sort=True
@@ -391,12 +390,10 @@ def test_union_sort_other_incomparable():
391390
tm.assert_index_equal(result, idx)
392391

393392

394-
@pytest.mark.xfail(reason="Not implemented.")
395393
def test_union_sort_other_incomparable_sort():
396-
# TODO(GH#25151): decide on True behaviour
397-
# # sort=True
398394
idx = MultiIndex.from_product([[1, pd.Timestamp("2000")], ["a", "b"]])
399-
with pytest.raises(TypeError, match="Cannot compare"):
395+
msg = "'<' not supported between instances of 'Timestamp' and 'int'"
396+
with pytest.raises(TypeError, match=msg):
400397
idx.union(idx[:1], sort=True)
401398

402399

@@ -435,12 +432,15 @@ def test_union_multiindex_empty_rangeindex():
435432
@pytest.mark.parametrize(
436433
"method", ["union", "intersection", "difference", "symmetric_difference"]
437434
)
438-
def test_setops_disallow_true(method):
435+
def test_setops_sort_validation(method):
439436
idx1 = MultiIndex.from_product([["a", "b"], [1, 2]])
440437
idx2 = MultiIndex.from_product([["b", "c"], [1, 2]])
441438

442439
with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
443-
getattr(idx1, method)(idx2, sort=True)
440+
getattr(idx1, method)(idx2, sort=2)
441+
442+
# sort=True is supported as of GH#?
443+
getattr(idx1, method)(idx2, sort=True)
444444

445445

446446
@pytest.mark.parametrize("val", [pd.NA, 100])

pandas/tests/indexes/numeric/test_setops.py

-3
Original file line numberDiff line numberDiff line change
@@ -143,11 +143,8 @@ def test_union_sort_other_special(self, slice_):
143143
# sort=False
144144
tm.assert_index_equal(idx.union(other, sort=False), idx)
145145

146-
@pytest.mark.xfail(reason="Not implemented")
147146
@pytest.mark.parametrize("slice_", [slice(None), slice(0)])
148147
def test_union_sort_special_true(self, slice_):
149-
# TODO(GH#25151): decide on True behaviour
150-
# sort=True
151148
idx = Index([1, 0, 2])
152149
# default, sort=None
153150
other = idx[slice_]

pandas/tests/indexes/test_setops.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -836,16 +836,14 @@ def test_difference_incomparable(self, opname):
836836
result = op(a)
837837
tm.assert_index_equal(result, expected)
838838

839-
@pytest.mark.xfail(reason="Not implemented")
840839
@pytest.mark.parametrize("opname", ["difference", "symmetric_difference"])
841840
def test_difference_incomparable_true(self, opname):
842-
# TODO(GH#25151): decide on True behaviour
843-
# # sort=True, raises
844841
a = Index([3, Timestamp("2000"), 1])
845842
b = Index([2, Timestamp("1999"), 1])
846843
op = operator.methodcaller(opname, b, sort=True)
847844

848-
with pytest.raises(TypeError, match="Cannot compare"):
845+
msg = "'<' not supported between instances of 'Timestamp' and 'int'"
846+
with pytest.raises(TypeError, match=msg):
849847
op(a)
850848

851849
def test_symmetric_difference_mi(self, sort):

0 commit comments

Comments
 (0)