Skip to content

Commit b0a3444

Browse files
committed
ENH: Add sort parameter to set operations for some Indexes and adjust tests
1 parent 33f91d8 commit b0a3444

File tree

14 files changed

+380
-224
lines changed

14 files changed

+380
-224
lines changed

doc/source/whatsnew/v0.24.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -410,7 +410,7 @@ Other Enhancements
410410
- :func:`read_fwf` now accepts keyword ``infer_nrows`` (:issue:`15138`).
411411
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
412412
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`, :issue:`24466`)
413-
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
413+
- :meth:`Index.difference`, :meth:`Index.intersection`, :meth:`Index.union`, and :meth:`Index.symmetric_difference` now have an optional ``sort`` parameter to control whether the results should be sorted if possible (:issue:`17839`, :issue:`24471`)
414414
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
415415
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
416416
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)

pandas/_libs/lib.pyx

+6-5
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ def item_from_zerodim(val: object) -> object:
200200

201201
@cython.wraparound(False)
202202
@cython.boundscheck(False)
203-
def fast_unique_multiple(list arrays):
203+
def fast_unique_multiple(list arrays, sort: bool=True):
204204
cdef:
205205
ndarray[object] buf
206206
Py_ssize_t k = len(arrays)
@@ -217,10 +217,11 @@ def fast_unique_multiple(list arrays):
217217
if val not in table:
218218
table[val] = stub
219219
uniques.append(val)
220-
try:
221-
uniques.sort()
222-
except Exception:
223-
pass
220+
if sort:
221+
try:
222+
uniques.sort()
223+
except Exception:
224+
pass
224225

225226
return uniques
226227

pandas/core/indexes/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _get_combined_index(indexes, intersect=False, sort=False):
112112
elif intersect:
113113
index = indexes[0]
114114
for other in indexes[1:]:
115-
index = index.intersection(other)
115+
index = index.intersection(other, sort=sort)
116116
else:
117117
index = _union_indexes(indexes, sort=sort)
118118
index = ensure_index(index)

pandas/core/indexes/base.py

+39-21
Original file line numberDiff line numberDiff line change
@@ -2241,13 +2241,17 @@ def _get_reconciled_name_object(self, other):
22412241
return self._shallow_copy(name=name)
22422242
return self
22432243

2244-
def union(self, other):
2244+
def union(self, other, sort=True):
22452245
"""
2246-
Form the union of two Index objects and sorts if possible.
2246+
Form the union of two Index objects.
22472247
22482248
Parameters
22492249
----------
22502250
other : Index or array-like
2251+
sort : bool, default True
2252+
Sort the resulting index if possible
2253+
2254+
.. versionadded:: 0.24.0
22512255
22522256
Returns
22532257
-------
@@ -2277,7 +2281,7 @@ def union(self, other):
22772281
if not is_dtype_union_equal(self.dtype, other.dtype):
22782282
this = self.astype('O')
22792283
other = other.astype('O')
2280-
return this.union(other)
2284+
return this.union(other, sort=sort)
22812285

22822286
# TODO(EA): setops-refactor, clean all this up
22832287
if is_period_dtype(self) or is_datetime64tz_dtype(self):
@@ -2311,29 +2315,33 @@ def union(self, other):
23112315
else:
23122316
result = lvals
23132317

2314-
try:
2315-
result = sorting.safe_sort(result)
2316-
except TypeError as e:
2317-
warnings.warn("%s, sort order is undefined for "
2318-
"incomparable objects" % e, RuntimeWarning,
2319-
stacklevel=3)
2318+
if sort:
2319+
try:
2320+
result = sorting.safe_sort(result)
2321+
except TypeError as e:
2322+
warnings.warn("%s, sort order is undefined for "
2323+
"incomparable objects" % e, RuntimeWarning,
2324+
stacklevel=3)
23202325

23212326
# for subclasses
23222327
return self._wrap_setop_result(other, result)
23232328

23242329
def _wrap_setop_result(self, other, result):
23252330
return self._constructor(result, name=get_op_result_name(self, other))
23262331

2327-
def intersection(self, other):
2332+
def intersection(self, other, sort=True):
23282333
"""
23292334
Form the intersection of two Index objects.
23302335
2331-
This returns a new Index with elements common to the index and `other`,
2332-
preserving the order of the calling index.
2336+
This returns a new Index with elements common to the index and `other`.
23332337
23342338
Parameters
23352339
----------
23362340
other : Index or array-like
2341+
sort : bool, default True
2342+
Sort the resulting index if possible
2343+
2344+
.. versionadded:: 0.24.0
23372345
23382346
Returns
23392347
-------
@@ -2356,7 +2364,7 @@ def intersection(self, other):
23562364
if not is_dtype_equal(self.dtype, other.dtype):
23572365
this = self.astype('O')
23582366
other = other.astype('O')
2359-
return this.intersection(other)
2367+
return this.intersection(other, sort=sort)
23602368

23612369
# TODO(EA): setops-refactor, clean all this up
23622370
if is_period_dtype(self):
@@ -2387,6 +2395,13 @@ def intersection(self, other):
23872395
taken = other.take(indexer)
23882396
if self.name != other.name:
23892397
taken.name = None
2398+
if sort:
2399+
try:
2400+
taken = taken.sort_values()
2401+
except TypeError as e:
2402+
warnings.warn("%s, sort order is undefined for "
2403+
"incomparable objects" % e, RuntimeWarning,
2404+
stacklevel=3)
23902405
return taken
23912406

23922407
def difference(self, other, sort=True):
@@ -2442,16 +2457,18 @@ def difference(self, other, sort=True):
24422457

24432458
return this._shallow_copy(the_diff, name=result_name, freq=None)
24442459

2445-
def symmetric_difference(self, other, result_name=None):
2460+
def symmetric_difference(self, other, result_name=None, sort=True):
24462461
"""
24472462
Compute the symmetric difference of two Index objects.
24482463
2449-
It's sorted if sorting is possible.
2450-
24512464
Parameters
24522465
----------
24532466
other : Index or array-like
24542467
result_name : str
2468+
sort : bool, default True
2469+
Sort the resulting index if possible
2470+
2471+
.. versionadded:: 0.24.0
24552472
24562473
Returns
24572474
-------
@@ -2496,10 +2513,11 @@ def symmetric_difference(self, other, result_name=None):
24962513
right_diff = other.values.take(right_indexer)
24972514

24982515
the_diff = _concat._concat_compat([left_diff, right_diff])
2499-
try:
2500-
the_diff = sorting.safe_sort(the_diff)
2501-
except TypeError:
2502-
pass
2516+
if sort:
2517+
try:
2518+
the_diff = sorting.safe_sort(the_diff)
2519+
except TypeError:
2520+
pass
25032521

25042522
attribs = self._get_attributes_dict()
25052523
attribs['name'] = result_name
@@ -3226,7 +3244,7 @@ def join(self, other, how='left', level=None, return_indexers=False,
32263244
elif how == 'right':
32273245
join_index = other
32283246
elif how == 'inner':
3229-
join_index = self.intersection(other)
3247+
join_index = self.intersection(other, sort=sort)
32303248
elif how == 'outer':
32313249
join_index = self.union(other)
32323250

pandas/core/indexes/datetimes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def _wrap_setop_result(self, other, result):
594594
name = get_op_result_name(self, other)
595595
return self._shallow_copy(result, name=name, freq=None, tz=self.tz)
596596

597-
def intersection(self, other):
597+
def intersection(self, other, sort=True):
598598
"""
599599
Specialized intersection for DatetimeIndex objects. May be much faster
600600
than Index.intersection
@@ -617,7 +617,7 @@ def intersection(self, other):
617617
other = DatetimeIndex(other)
618618
except (TypeError, ValueError):
619619
pass
620-
result = Index.intersection(self, other)
620+
result = Index.intersection(self, other, sort=sort)
621621
if isinstance(result, DatetimeIndex):
622622
if result.freq is None:
623623
result.freq = to_offset(result.inferred_freq)
@@ -627,7 +627,7 @@ def intersection(self, other):
627627
other.freq != self.freq or
628628
not other.freq.isAnchored() or
629629
(not self.is_monotonic or not other.is_monotonic)):
630-
result = Index.intersection(self, other)
630+
result = Index.intersection(self, other, sort=sort)
631631
# Invalidate the freq of `result`, which may not be correct at
632632
# this point, depending on the values.
633633
result.freq = None

pandas/core/indexes/interval.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1102,11 +1102,8 @@ def func(self, other, sort=True):
11021102
'objects that have compatible dtypes')
11031103
raise TypeError(msg.format(op=op_name))
11041104

1105-
if op_name == 'difference':
1106-
result = getattr(self._multiindex, op_name)(other._multiindex,
1107-
sort)
1108-
else:
1109-
result = getattr(self._multiindex, op_name)(other._multiindex)
1105+
result = getattr(self._multiindex, op_name)(other._multiindex,
1106+
sort=sort)
11101107
result_name = get_op_result_name(self, other)
11111108

11121109
# GH 19101: ensure empty results have correct dtype

pandas/core/indexes/multi.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -2879,13 +2879,17 @@ def equal_levels(self, other):
28792879
return False
28802880
return True
28812881

2882-
def union(self, other):
2882+
def union(self, other, sort=True):
28832883
"""
2884-
Form the union of two MultiIndex objects, sorting if possible
2884+
Form the union of two MultiIndex objects
28852885
28862886
Parameters
28872887
----------
28882888
other : MultiIndex or array / Index of tuples
2889+
sort : bool, default True
2890+
Sort the resulting MultiIndex if possible
2891+
2892+
.. versionadded:: 0.24.0
28892893
28902894
Returns
28912895
-------
@@ -2900,17 +2904,23 @@ def union(self, other):
29002904
return self
29012905

29022906
uniq_tuples = lib.fast_unique_multiple([self._ndarray_values,
2903-
other._ndarray_values])
2907+
other._ndarray_values],
2908+
sort=sort)
2909+
29042910
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
29052911
names=result_names)
29062912

2907-
def intersection(self, other):
2913+
def intersection(self, other, sort=True):
29082914
"""
2909-
Form the intersection of two MultiIndex objects, sorting if possible
2915+
Form the intersection of two MultiIndex objects.
29102916
29112917
Parameters
29122918
----------
29132919
other : MultiIndex or array / Index of tuples
2920+
sort : bool, default True
2921+
Sort the resulting MultiIndex if possible
2922+
2923+
.. versionadded:: 0.24.0
29142924
29152925
Returns
29162926
-------
@@ -2924,7 +2934,11 @@ def intersection(self, other):
29242934

29252935
self_tuples = self._ndarray_values
29262936
other_tuples = other._ndarray_values
2927-
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
2937+
uniq_tuples = set(self_tuples) & set(other_tuples)
2938+
2939+
if sort:
2940+
uniq_tuples = sorted(uniq_tuples)
2941+
29282942
if len(uniq_tuples) == 0:
29292943
return MultiIndex(levels=self.levels,
29302944
codes=[[]] * self.nlevels,
@@ -2935,7 +2949,7 @@ def intersection(self, other):
29352949

29362950
def difference(self, other, sort=True):
29372951
"""
2938-
Compute sorted set difference of two MultiIndex objects
2952+
Compute set difference of two MultiIndex objects
29392953
29402954
Parameters
29412955
----------

pandas/core/indexes/range.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,17 @@ def equals(self, other):
343343

344344
return super(RangeIndex, self).equals(other)
345345

346-
def intersection(self, other):
346+
def intersection(self, other, sort=True):
347347
"""
348-
Form the intersection of two Index objects. Sortedness of the result is
349-
not guaranteed
348+
Form the intersection of two Index objects.
350349
351350
Parameters
352351
----------
353352
other : Index or array-like
353+
sort : bool, default True
354+
Sort the resulting index if possible
355+
356+
.. versionadded:: 0.24.0
354357
355358
Returns
356359
-------
@@ -361,7 +364,7 @@ def intersection(self, other):
361364
return self._get_reconciled_name_object(other)
362365

363366
if not isinstance(other, RangeIndex):
364-
return super(RangeIndex, self).intersection(other)
367+
return super(RangeIndex, self).intersection(other, sort=sort)
365368

366369
if not len(self) or not len(other):
367370
return RangeIndex._simple_new(None)
@@ -398,6 +401,8 @@ def intersection(self, other):
398401

399402
if (self._step < 0 and other._step < 0) is not (new_index._step < 0):
400403
new_index = new_index[::-1]
404+
if sort:
405+
new_index = new_index.sort_values()
401406
return new_index
402407

403408
def _min_fitting_element(self, lower_limit):

pandas/tests/indexes/datetimes/test_setops.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def test_intersection2(self):
138138

139139
@pytest.mark.parametrize("tz", [None, 'Asia/Tokyo', 'US/Eastern',
140140
'dateutil/US/Pacific'])
141-
def test_intersection(self, tz):
141+
@pytest.mark.parametrize("sort", [True, False])
142+
def test_intersection(self, tz, sort):
142143
# GH 4690 (with tz)
143144
base = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')
144145

@@ -185,7 +186,9 @@ def test_intersection(self, tz):
185186

186187
for (rng, expected) in [(rng2, expected2), (rng3, expected3),
187188
(rng4, expected4)]:
188-
result = base.intersection(rng)
189+
result = base.intersection(rng, sort=sort)
190+
if sort:
191+
expected = expected.sort_values()
189192
tm.assert_index_equal(result, expected)
190193
assert result.name == expected.name
191194
assert result.freq is None

0 commit comments

Comments
 (0)