From 8f506d50d3e907d5a50f6fea3269f3d95fb6b57c Mon Sep 17 00:00:00 2001 From: Brock Date: Sat, 11 Feb 2023 17:20:14 -0800 Subject: [PATCH] ENH: Index set operations with sort=True --- doc/source/whatsnew/v2.0.0.rst | 1 + pandas/core/indexes/base.py | 48 ++++++++++++++----- pandas/core/indexes/multi.py | 4 +- pandas/core/indexes/range.py | 2 +- .../tests/indexes/base_class/test_setops.py | 12 ++--- pandas/tests/indexes/multi/test_setops.py | 28 +++++------ pandas/tests/indexes/numeric/test_setops.py | 3 -- pandas/tests/indexes/test_setops.py | 6 +-- 8 files changed, 61 insertions(+), 43 deletions(-) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 29f360e050548..a393e0dca4c79 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -312,6 +312,7 @@ Other enhancements - Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`) - Added new argument ``engine`` to :func:`read_json` to support parsing JSON with pyarrow by specifying ``engine="pyarrow"`` (:issue:`48893`) - Added support for SQLAlchemy 2.0 (:issue:`40686`) +- :class:`Index` set operations :meth:`Index.union`, :meth:`Index.intersection`, :meth:`Index.difference`, and :meth:`Index.symmetric_difference` now support ``sort=True``, which will always return a sorted result, unlike the default ``sort=None`` which does not sort in some cases (:issue:`25151`) - .. --------------------------------------------------------------------------- diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 363bfe76d40fb..c0b280c760db4 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -3023,10 +3023,10 @@ def _get_reconciled_name_object(self, other): @final def _validate_sort_keyword(self, sort): - if sort not in [None, False]: + if sort not in [None, False, True]: raise ValueError( "The 'sort' keyword only takes the values of " - f"None or False; {sort} was passed." 
+ f"None, True, or False; {sort} was passed." ) @final @@ -3070,6 +3070,7 @@ def union(self, other, sort=None): A RuntimeWarning is issued in this case. * False : do not sort the result. + * True : Sort the result (which may raise TypeError). Returns ------- @@ -3154,10 +3155,16 @@ def union(self, other, sort=None): elif not len(other) or self.equals(other): # NB: whether this (and the `if not len(self)` check below) come before # or after the is_dtype_equal check above affects the returned dtype - return self._get_reconciled_name_object(other) + result = self._get_reconciled_name_object(other) + if sort is True: + return result.sort_values() + return result elif not len(self): - return other._get_reconciled_name_object(self) + result = other._get_reconciled_name_object(self) + if sort is True: + return result.sort_values() + return result result = self._union(other, sort=sort) @@ -3258,12 +3265,13 @@ def intersection(self, other, sort: bool = False): Parameters ---------- other : Index or array-like - sort : False or None, default False + sort : True, False or None, default False Whether to sort the resulting index. - * False : do not sort the result. * None : sort the result, except when `self` and `other` are equal or when the values cannot be compared. + * False : do not sort the result. + * True : Sort the result (which may raise TypeError). 
Returns ------- @@ -3285,8 +3293,12 @@ def intersection(self, other, sort: bool = False): if self.equals(other): if self.has_duplicates: - return self.unique()._get_reconciled_name_object(other) - return self._get_reconciled_name_object(other) + result = self.unique()._get_reconciled_name_object(other) + else: + result = self._get_reconciled_name_object(other) + if sort is True: + result = result.sort_values() + return result if len(self) == 0 or len(other) == 0: # fastpath; we need to be careful about having commutativity @@ -3403,7 +3415,7 @@ def difference(self, other, sort=None): Parameters ---------- other : Index or array-like - sort : False or None, default None + sort : bool or None, default None Whether to sort the resulting index. By default, the values are attempted to be sorted, but any TypeError from incomparable elements is caught by pandas. @@ -3411,6 +3423,7 @@ def difference(self, other, sort=None): * None : Attempt to sort the result, but catch any TypeErrors from comparing incomparable elements. * False : Do not sort the result. + * True : Sort the result (which may raise TypeError). Returns ------- @@ -3439,11 +3452,17 @@ def difference(self, other, sort=None): if len(other) == 0: # Note: we do not (yet) sort even if sort=None GH#24959 - return self.rename(result_name) + result = self.rename(result_name) + if sort is True: + return result.sort_values() + return result if not self._should_compare(other): # Nothing matches -> difference is everything - return self.rename(result_name) + result = self.rename(result_name) + if sort is True: + return result.sort_values() + return result result = self._difference(other, sort=sort) return self._wrap_difference_result(other, result) @@ -3479,7 +3498,7 @@ def symmetric_difference(self, other, result_name=None, sort=None): ---------- other : Index or array-like result_name : str - sort : False or None, default None + sort : bool or None, default None Whether to sort the resulting index. 
By default, the values are attempted to be sorted, but any TypeError from incomparable elements is caught by pandas. @@ -3487,6 +3506,7 @@ def symmetric_difference(self, other, result_name=None, sort=None): * None : Attempt to sort the result, but catch any TypeErrors from comparing incomparable elements. * False : Do not sort the result. + * True : Sort the result (which may raise TypeError). Returns ------- @@ -7161,10 +7181,12 @@ def unpack_nested_dtype(other: _IndexT) -> _IndexT: def _maybe_try_sort(result, sort): - if sort is None: + if sort is not False: try: result = algos.safe_sort(result) except TypeError as err: + if sort is True: + raise warnings.warn( f"{err}, sort order is undefined for incomparable objects.", RuntimeWarning, diff --git a/pandas/core/indexes/multi.py b/pandas/core/indexes/multi.py index 95f35eabb342e..a4df2acf0ab45 100644 --- a/pandas/core/indexes/multi.py +++ b/pandas/core/indexes/multi.py @@ -3560,10 +3560,12 @@ def _union(self, other, sort) -> MultiIndex: else: result = self._get_reconciled_name_object(other) - if sort is None: + if sort is not False: try: result = result.sort_values() except TypeError: + if sort is True: + raise warnings.warn( "The values in the array are unorderable. 
" "Pass `sort=False` to suppress this warning.", diff --git a/pandas/core/indexes/range.py b/pandas/core/indexes/range.py index ca34fcfc7a625..670b97adf7c36 100644 --- a/pandas/core/indexes/range.py +++ b/pandas/core/indexes/range.py @@ -688,7 +688,7 @@ def _difference(self, other, sort=None): if not isinstance(other, RangeIndex): return super()._difference(other, sort=sort) - if sort is None and self.step < 0: + if sort is not False and self.step < 0: return self[::-1]._difference(other) res_name = ops.get_op_result_name(self, other) diff --git a/pandas/tests/indexes/base_class/test_setops.py b/pandas/tests/indexes/base_class/test_setops.py index 87ffe99896199..21d1630af9de2 100644 --- a/pandas/tests/indexes/base_class/test_setops.py +++ b/pandas/tests/indexes/base_class/test_setops.py @@ -16,12 +16,15 @@ class TestIndexSetOps: @pytest.mark.parametrize( "method", ["union", "intersection", "difference", "symmetric_difference"] ) - def test_setops_disallow_true(self, method): + def test_setops_sort_validation(self, method): idx1 = Index(["a", "b"]) idx2 = Index(["b", "c"]) with pytest.raises(ValueError, match="The 'sort' keyword only takes"): - getattr(idx1, method)(idx2, sort=True) + getattr(idx1, method)(idx2, sort=2) + + # sort=True is supported as of GH#?? 
+ getattr(idx1, method)(idx2, sort=True) def test_setops_preserve_object_dtype(self): idx = Index([1, 2, 3], dtype=object) @@ -88,17 +91,12 @@ def test_union_sort_other_incomparable(self): result = idx.union(idx[:1], sort=False) tm.assert_index_equal(result, idx) - @pytest.mark.xfail(reason="GH#25151 need to decide on True behavior") def test_union_sort_other_incomparable_true(self): - # TODO(GH#25151): decide on True behaviour - # sort=True idx = Index([1, pd.Timestamp("2000")]) with pytest.raises(TypeError, match=".*"): idx.union(idx[:1], sort=True) - @pytest.mark.xfail(reason="GH#25151 need to decide on True behavior") def test_intersection_equal_sort_true(self): - # TODO(GH#25151): decide on True behaviour idx = Index(["c", "a", "b"]) sorted_ = Index(["a", "b", "c"]) tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_) diff --git a/pandas/tests/indexes/multi/test_setops.py b/pandas/tests/indexes/multi/test_setops.py index 4979e461f4cf0..de4d0e014be29 100644 --- a/pandas/tests/indexes/multi/test_setops.py +++ b/pandas/tests/indexes/multi/test_setops.py @@ -204,7 +204,6 @@ def test_difference_sort_special(): tm.assert_index_equal(result, idx) -@pytest.mark.xfail(reason="Not implemented.") def test_difference_sort_special_true(): # TODO(GH#25151): decide on True behaviour idx = MultiIndex.from_product([[1, 0], ["a", "b"]]) @@ -233,8 +232,10 @@ def test_difference_sort_incomparable_true(): idx = MultiIndex.from_product([[1, pd.Timestamp("2000"), 2], ["a", "b"]]) other = MultiIndex.from_product([[3, pd.Timestamp("2000"), 4], ["c", "d"]]) - msg = "The 'sort' keyword only takes the values of None or False; True was passed." - with pytest.raises(ValueError, match=msg): + # TODO: this is raising in constructing a Categorical when calling + # algos.safe_sort. Should we catch and re-raise with a better message? 
+ msg = "'values' is not ordered, please explicitly specify the categories order " + with pytest.raises(TypeError, match=msg): idx.difference(other, sort=True) @@ -344,12 +345,11 @@ def test_intersect_equal_sort(): tm.assert_index_equal(idx.intersection(idx, sort=None), idx) -@pytest.mark.xfail(reason="Not implemented.") def test_intersect_equal_sort_true(): - # TODO(GH#25151): decide on True behaviour idx = MultiIndex.from_product([[1, 0], ["a", "b"]]) - sorted_ = MultiIndex.from_product([[0, 1], ["a", "b"]]) - tm.assert_index_equal(idx.intersection(idx, sort=True), sorted_) + expected = MultiIndex.from_product([[0, 1], ["a", "b"]]) + result = idx.intersection(idx, sort=True) + tm.assert_index_equal(result, expected) @pytest.mark.parametrize("slice_", [slice(None), slice(0)]) @@ -366,7 +366,6 @@ def test_union_sort_other_empty(slice_): tm.assert_index_equal(idx.union(other, sort=False), idx) -@pytest.mark.xfail(reason="Not implemented.") def test_union_sort_other_empty_sort(): # TODO(GH#25151): decide on True behaviour # # sort=True @@ -391,12 +390,10 @@ def test_union_sort_other_incomparable(): tm.assert_index_equal(result, idx) -@pytest.mark.xfail(reason="Not implemented.") def test_union_sort_other_incomparable_sort(): - # TODO(GH#25151): decide on True behaviour - # # sort=True idx = MultiIndex.from_product([[1, pd.Timestamp("2000")], ["a", "b"]]) - with pytest.raises(TypeError, match="Cannot compare"): + msg = "'<' not supported between instances of 'Timestamp' and 'int'" + with pytest.raises(TypeError, match=msg): idx.union(idx[:1], sort=True) @@ -435,12 +432,15 @@ def test_union_multiindex_empty_rangeindex(): @pytest.mark.parametrize( "method", ["union", "intersection", "difference", "symmetric_difference"] ) -def test_setops_disallow_true(method): +def test_setops_sort_validation(method): idx1 = MultiIndex.from_product([["a", "b"], [1, 2]]) idx2 = MultiIndex.from_product([["b", "c"], [1, 2]]) with pytest.raises(ValueError, match="The 'sort' keyword only 
takes"): - getattr(idx1, method)(idx2, sort=True) + getattr(idx1, method)(idx2, sort=2) + + # sort=True is supported as of GH#25151 + getattr(idx1, method)(idx2, sort=True) @pytest.mark.parametrize("val", [pd.NA, 100]) diff --git a/pandas/tests/indexes/numeric/test_setops.py b/pandas/tests/indexes/numeric/test_setops.py index 3e3de14960f4e..2276b10db1fe3 100644 --- a/pandas/tests/indexes/numeric/test_setops.py +++ b/pandas/tests/indexes/numeric/test_setops.py @@ -143,11 +143,8 @@ def test_union_sort_other_special(self, slice_): # sort=False tm.assert_index_equal(idx.union(other, sort=False), idx) - @pytest.mark.xfail(reason="Not implemented") @pytest.mark.parametrize("slice_", [slice(None), slice(0)]) def test_union_sort_special_true(self, slice_): - # TODO(GH#25151): decide on True behaviour - # sort=True idx = Index([1, 0, 2]) # default, sort=None other = idx[slice_] diff --git a/pandas/tests/indexes/test_setops.py b/pandas/tests/indexes/test_setops.py index 001efe07b5d2b..dd27470d82111 100644 --- a/pandas/tests/indexes/test_setops.py +++ b/pandas/tests/indexes/test_setops.py @@ -836,16 +836,14 @@ def test_difference_incomparable(self, opname): result = op(a) tm.assert_index_equal(result, expected) - @pytest.mark.xfail(reason="Not implemented") @pytest.mark.parametrize("opname", ["difference", "symmetric_difference"]) def test_difference_incomparable_true(self, opname): - # TODO(GH#25151): decide on True behaviour - # # sort=True, raises a = Index([3, Timestamp("2000"), 1]) b = Index([2, Timestamp("1999"), 1]) op = operator.methodcaller(opname, b, sort=True) - with pytest.raises(TypeError, match="Cannot compare"): + msg = "'<' not supported between instances of 'Timestamp' and 'int'" + with pytest.raises(TypeError, match=msg): op(a) def test_symmetric_difference_mi(self, sort):