Skip to content

Commit 1c54354

Browse files
Sort Index.difference & union results for early exit scenarios (#14681)
This PR sorts results in `Index.difference` & `union` in the early exit scenarios similar to: pandas-dev/pandas#51346 On `pandas_2.0_feature_branch`: ``` = 110 failed, 101331 passed, 2091 skipped, 952 xfailed, 312 xpassed in 1064.30s (0:17:44) = ``` This PR: ``` = 87 failed, 101354 passed, 2091 skipped, 952 xfailed, 312 xpassed in 1004.34s (0:16:44) = ```
1 parent cb09a39 commit 1c54354

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

python/cudf/cudf/core/_base_index.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -629,12 +629,18 @@ def union(self, other, sort=None):
629629
common_dtype = cudf.utils.dtypes.find_common_type(
630630
[self.dtype, other.dtype]
631631
)
632-
return self._get_reconciled_name_object(other).astype(common_dtype)
632+
res = self._get_reconciled_name_object(other).astype(common_dtype)
633+
if sort:
634+
return res.sort_values()
635+
return res
633636
elif not len(self):
634637
common_dtype = cudf.utils.dtypes.find_common_type(
635638
[self.dtype, other.dtype]
636639
)
637-
return other._get_reconciled_name_object(self).astype(common_dtype)
640+
res = other._get_reconciled_name_object(self).astype(common_dtype)
641+
if sort:
642+
return res.sort_values()
643+
return res
638644

639645
result = self._union(other, sort=sort)
640646
result.name = _get_result_name(self.name, other.name)
@@ -1091,9 +1097,15 @@ def difference(self, other, sort=None):
10911097
other = cudf.Index(other, name=getattr(other, "name", self.name))
10921098

10931099
if not len(other):
1094-
return self._get_reconciled_name_object(other)
1100+
res = self._get_reconciled_name_object(other)
1101+
if sort:
1102+
return res.sort_values()
1103+
return res
10951104
elif self.equals(other):
1096-
return self[:0]._get_reconciled_name_object(other)
1105+
res = self[:0]._get_reconciled_name_object(other)
1106+
if sort:
1107+
return res.sort_values()
1108+
return res
10971109

10981110
res_name = _get_result_name(self.name, other.name)
10991111

0 commit comments

Comments
 (0)