Skip to content

ENH: Index set operations with sort=True #51346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ Other enhancements
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
- Added new argument ``engine`` to :func:`read_json` to support parsing JSON with pyarrow by specifying ``engine="pyarrow"`` (:issue:`48893`)
- Added support for SQLAlchemy 2.0 (:issue:`40686`)
- :class:`Index` set operations :meth:`Index.union`, :meth:`Index.intersection`, :meth:`Index.difference`, and :meth:`Index.symmetric_difference` now support ``sort=True``, which will always return a sorted result, unlike the default ``sort=None`` which does not sort in some cases (:issue:`25151`)
-

.. ---------------------------------------------------------------------------
Expand Down
48 changes: 35 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3023,10 +3023,10 @@ def _get_reconciled_name_object(self, other):

@final
def _validate_sort_keyword(self, sort):
if sort not in [None, False]:
if sort not in [None, False, True]:
raise ValueError(
"The 'sort' keyword only takes the values of "
f"None or False; {sort} was passed."
f"None, True, or False; {sort} was passed."
)

@final
Expand Down Expand Up @@ -3070,6 +3070,7 @@ def union(self, other, sort=None):
A RuntimeWarning is issued in this case.

* False : do not sort the result.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand Down Expand Up @@ -3154,10 +3155,16 @@ def union(self, other, sort=None):
elif not len(other) or self.equals(other):
# NB: whether this (and the `if not len(self)` check below) come before
# or after the is_dtype_equal check above affects the returned dtype
return self._get_reconciled_name_object(other)
result = self._get_reconciled_name_object(other)
if sort is True:
return result.sort_values()
return result

elif not len(self):
return other._get_reconciled_name_object(self)
result = other._get_reconciled_name_object(self)
if sort is True:
return result.sort_values()
return result

result = self._union(other, sort=sort)

Expand Down Expand Up @@ -3258,12 +3265,13 @@ def intersection(self, other, sort: bool = False):
Parameters
----------
other : Index or array-like
sort : False or None, default False
sort : True, False or None, default False
Whether to sort the resulting index.

* False : do not sort the result.
* None : sort the result, except when `self` and `other` are equal
or when the values cannot be compared.
* False : do not sort the result.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand All @@ -3285,8 +3293,12 @@ def intersection(self, other, sort: bool = False):

if self.equals(other):
if self.has_duplicates:
return self.unique()._get_reconciled_name_object(other)
return self._get_reconciled_name_object(other)
result = self.unique()._get_reconciled_name_object(other)
else:
result = self._get_reconciled_name_object(other)
if sort is True:
result = result.sort_values()
return result

if len(self) == 0 or len(other) == 0:
# fastpath; we need to be careful about having commutativity
Expand Down Expand Up @@ -3403,14 +3415,15 @@ def difference(self, other, sort=None):
Parameters
----------
other : Index or array-like
sort : False or None, default None
sort : bool or None, default None
Whether to sort the resulting index. By default, the
values are attempted to be sorted, but any TypeError from
incomparable elements is caught by pandas.

* None : Attempt to sort the result, but catch any TypeErrors
from comparing incomparable elements.
* False : Do not sort the result.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand Down Expand Up @@ -3439,11 +3452,17 @@ def difference(self, other, sort=None):

if len(other) == 0:
# Note: we do not (yet) sort even if sort=None GH#24959
return self.rename(result_name)
result = self.rename(result_name)
if sort is True:
return result.sort_values()
return result

if not self._should_compare(other):
# Nothing matches -> difference is everything
return self.rename(result_name)
result = self.rename(result_name)
if sort is True:
return result.sort_values()
return result

result = self._difference(other, sort=sort)
return self._wrap_difference_result(other, result)
Expand Down Expand Up @@ -3479,14 +3498,15 @@ def symmetric_difference(self, other, result_name=None, sort=None):
----------
other : Index or array-like
result_name : str
sort : False or None, default None
sort : bool or None, default None
Whether to sort the resulting index. By default, the
values are attempted to be sorted, but any TypeError from
incomparable elements is caught by pandas.

* None : Attempt to sort the result, but catch any TypeErrors
from comparing incomparable elements.
* False : Do not sort the result.
* True : Sort the result (which may raise TypeError).

Returns
-------
Expand Down Expand Up @@ -7161,10 +7181,12 @@ def unpack_nested_dtype(other: _IndexT) -> _IndexT:


def _maybe_try_sort(result, sort):
if sort is None:
if sort is not False:
try:
result = algos.safe_sort(result)
except TypeError as err:
if sort is True:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have tests for these "stricter" raising behavior?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, they’re xfailed in main

raise
warnings.warn(
f"{err}, sort order is undefined for incomparable objects.",
RuntimeWarning,
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -3560,10 +3560,12 @@ def _union(self, other, sort) -> MultiIndex:
else:
result = self._get_reconciled_name_object(other)

if sort is None:
if sort is not False:
try:
result = result.sort_values()
except TypeError:
if sort is True:
raise
warnings.warn(
"The values in the array are unorderable. "
"Pass `sort=False` to suppress this warning.",
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def _difference(self, other, sort=None):
if not isinstance(other, RangeIndex):
return super()._difference(other, sort=sort)

if sort is None and self.step < 0:
if sort is not False and self.step < 0:
return self[::-1]._difference(other)

res_name = ops.get_op_result_name(self, other)
Expand Down
12 changes: 5 additions & 7 deletions pandas/tests/indexes/base_class/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ class TestIndexSetOps:
@pytest.mark.parametrize(
"method", ["union", "intersection", "difference", "symmetric_difference"]
)
def test_setops_disallow_true(self, method):
def test_setops_sort_validation(self, method):
idx1 = Index(["a", "b"])
idx2 = Index(["b", "c"])

with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
getattr(idx1, method)(idx2, sort=True)
getattr(idx1, method)(idx2, sort=2)

# sort=True is supported as of GH#??
getattr(idx1, method)(idx2, sort=True)

def test_setops_preserve_object_dtype(self):
idx = Index([1, 2, 3], dtype=object)
Expand Down Expand Up @@ -88,17 +91,12 @@ def test_union_sort_other_incomparable(self):
result = idx.union(idx[:1], sort=False)
tm.assert_index_equal(result, idx)

@pytest.mark.xfail(reason="GH#25151 need to decide on True behavior")
def test_union_sort_other_incomparable_true(self):
# TODO(GH#25151): decide on True behaviour
# sort=True
idx = Index([1, pd.Timestamp("2000")])
with pytest.raises(TypeError, match=".*"):
idx.union(idx[:1], sort=True)

@pytest.mark.xfail(reason="GH#25151 need to decide on True behavior")
def test_intersection_equal_sort_true(self):
# TODO(GH#25151): decide on True behaviour
idx = Index(["c", "a", "b"])
sorted_ = Index(["a", "b", "c"])
tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_)
Expand Down
28 changes: 14 additions & 14 deletions pandas/tests/indexes/multi/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,6 @@ def test_difference_sort_special():
tm.assert_index_equal(result, idx)


@pytest.mark.xfail(reason="Not implemented.")
def test_difference_sort_special_true():
# TODO(GH#25151): decide on True behaviour
idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
Expand Down Expand Up @@ -233,8 +232,10 @@ def test_difference_sort_incomparable_true():
idx = MultiIndex.from_product([[1, pd.Timestamp("2000"), 2], ["a", "b"]])
other = MultiIndex.from_product([[3, pd.Timestamp("2000"), 4], ["c", "d"]])

msg = "The 'sort' keyword only takes the values of None or False; True was passed."
with pytest.raises(ValueError, match=msg):
# TODO: this is raising in constructing a Categorical when calling
# algos.safe_sort. Should we catch and re-raise with a better message?
msg = "'values' is not ordered, please explicitly specify the categories order "
with pytest.raises(TypeError, match=msg):
idx.difference(other, sort=True)


Expand Down Expand Up @@ -344,12 +345,11 @@ def test_intersect_equal_sort():
tm.assert_index_equal(idx.intersection(idx, sort=None), idx)


@pytest.mark.xfail(reason="Not implemented.")
def test_intersect_equal_sort_true():
# TODO(GH#25151): decide on True behaviour
idx = MultiIndex.from_product([[1, 0], ["a", "b"]])
sorted_ = MultiIndex.from_product([[0, 1], ["a", "b"]])
tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_)
expected = MultiIndex.from_product([[0, 1], ["a", "b"]])
result = idx.intersection(idx, sort=True)
tm.assert_index_equal(result, expected)


@pytest.mark.parametrize("slice_", [slice(None), slice(0)])
Expand All @@ -366,7 +366,6 @@ def test_union_sort_other_empty(slice_):
tm.assert_index_equal(idx.union(other, sort=False), idx)


@pytest.mark.xfail(reason="Not implemented.")
def test_union_sort_other_empty_sort():
# TODO(GH#25151): decide on True behaviour
# # sort=True
Expand All @@ -391,12 +390,10 @@ def test_union_sort_other_incomparable():
tm.assert_index_equal(result, idx)


@pytest.mark.xfail(reason="Not implemented.")
def test_union_sort_other_incomparable_sort():
# TODO(GH#25151): decide on True behaviour
# # sort=True
idx = MultiIndex.from_product([[1, pd.Timestamp("2000")], ["a", "b"]])
with pytest.raises(TypeError, match="Cannot compare"):
msg = "'<' not supported between instances of 'Timestamp' and 'int'"
with pytest.raises(TypeError, match=msg):
idx.union(idx[:1], sort=True)


Expand Down Expand Up @@ -435,12 +432,15 @@ def test_union_multiindex_empty_rangeindex():
@pytest.mark.parametrize(
"method", ["union", "intersection", "difference", "symmetric_difference"]
)
def test_setops_disallow_true(method):
def test_setops_sort_validation(method):
idx1 = MultiIndex.from_product([["a", "b"], [1, 2]])
idx2 = MultiIndex.from_product([["b", "c"], [1, 2]])

with pytest.raises(ValueError, match="The 'sort' keyword only takes"):
getattr(idx1, method)(idx2, sort=True)
getattr(idx1, method)(idx2, sort=2)

# sort=True is supported as of GH#?
getattr(idx1, method)(idx2, sort=True)


@pytest.mark.parametrize("val", [pd.NA, 100])
Expand Down
3 changes: 0 additions & 3 deletions pandas/tests/indexes/numeric/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,8 @@ def test_union_sort_other_special(self, slice_):
# sort=False
tm.assert_index_equal(idx.union(other, sort=False), idx)

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("slice_", [slice(None), slice(0)])
def test_union_sort_special_true(self, slice_):
# TODO(GH#25151): decide on True behaviour
# sort=True
idx = Index([1, 0, 2])
# default, sort=None
other = idx[slice_]
Expand Down
6 changes: 2 additions & 4 deletions pandas/tests/indexes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,16 +836,14 @@ def test_difference_incomparable(self, opname):
result = op(a)
tm.assert_index_equal(result, expected)

@pytest.mark.xfail(reason="Not implemented")
@pytest.mark.parametrize("opname", ["difference", "symmetric_difference"])
def test_difference_incomparable_true(self, opname):
# TODO(GH#25151): decide on True behaviour
# # sort=True, raises
a = Index([3, Timestamp("2000"), 1])
b = Index([2, Timestamp("1999"), 1])
op = operator.methodcaller(opname, b, sort=True)

with pytest.raises(TypeError, match="Cannot compare"):
msg = "'<' not supported between instances of 'Timestamp' and 'int'"
with pytest.raises(TypeError, match=msg):
op(a)

def test_symmetric_difference_mi(self, sort):
Expand Down