Skip to content

Commit 4b31df2

Browse files
committed
ENH: Add sort parameter to set operations for some Indexes and adjust tests
1 parent 449d665 commit 4b31df2

File tree

14 files changed

+400
-238
lines changed

14 files changed

+400
-238
lines changed

doc/source/whatsnew/v0.24.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ Other Enhancements
405405
- :func:`read_fwf` now accepts keyword ``infer_nrows`` (:issue:`15138`).
406406
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
407407
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`)
408-
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
408+
- :meth:`Index.difference`, :meth:`Index.intersection`, :meth:`Index.union`, and :meth:`Index.symmetric_difference` now have an optional ``sort`` parameter to control whether the results should be sorted if possible (:issue:`17839`, :issue:`24471`)
409409
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
410410
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
411411
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)

pandas/_libs/lib.pyx

+6-5
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def item_from_zerodim(val: object) -> object:
199199

200200
@cython.wraparound(False)
201201
@cython.boundscheck(False)
202-
def fast_unique_multiple(list arrays):
202+
def fast_unique_multiple(list arrays, sort: bool=True):
203203
cdef:
204204
ndarray[object] buf
205205
Py_ssize_t k = len(arrays)
@@ -216,10 +216,11 @@ def fast_unique_multiple(list arrays):
216216
if val not in table:
217217
table[val] = stub
218218
uniques.append(val)
219-
try:
220-
uniques.sort()
221-
except Exception:
222-
pass
219+
if sort:
220+
try:
221+
uniques.sort()
222+
except Exception:
223+
pass
223224

224225
return uniques
225226

pandas/core/indexes/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _get_combined_index(indexes, intersect=False, sort=False):
112112
elif intersect:
113113
index = indexes[0]
114114
for other in indexes[1:]:
115-
index = index.intersection(other)
115+
index = index.intersection(other, sort=sort)
116116
else:
117117
index = _union_indexes(indexes, sort=sort)
118118
index = ensure_index(index)

pandas/core/indexes/base.py

+53-32
Original file line numberDiff line numberDiff line change
@@ -2235,13 +2235,17 @@ def _get_reconciled_name_object(self, other):
22352235
return self._shallow_copy(name=name)
22362236
return self
22372237

2238-
def union(self, other):
2238+
def union(self, other, sort=True):
22392239
"""
2240-
Form the union of two Index objects and sorts if possible.
2240+
Form the union of two Index objects.
22412241
22422242
Parameters
22432243
----------
22442244
other : Index or array-like
2245+
sort : bool, default True
2246+
Sort the resulting index if possible
2247+
2248+
.. versionadded:: 0.24.0
22452249
22462250
Returns
22472251
-------
@@ -2271,7 +2275,7 @@ def union(self, other):
22712275
if not is_dtype_union_equal(self.dtype, other.dtype):
22722276
this = self.astype('O')
22732277
other = other.astype('O')
2274-
return this.union(other)
2278+
return this.union(other, sort=sort)
22752279

22762280
# TODO(EA): setops-refactor, clean all this up
22772281
if is_period_dtype(self) or is_datetime64tz_dtype(self):
@@ -2302,44 +2306,51 @@ def union(self, other):
23022306
allow_fill=False)
23032307
result = _concat._concat_compat((lvals, other_diff))
23042308

2305-
try:
2306-
lvals[0] < other_diff[0]
2307-
except TypeError as e:
2308-
warnings.warn("%s, sort order is undefined for "
2309-
"incomparable objects" % e, RuntimeWarning,
2310-
stacklevel=3)
2311-
else:
2312-
types = frozenset((self.inferred_type,
2313-
other.inferred_type))
2314-
if not types & _unsortable_types:
2315-
result.sort()
2309+
if sort:
2310+
try:
2311+
lvals[0] < other_diff[0]
2312+
except TypeError as e:
2313+
warnings.warn("%s, sort order is undefined for "
2314+
"incomparable objects" % e,
2315+
RuntimeWarning,
2316+
stacklevel=3)
2317+
else:
2318+
types = frozenset((self.inferred_type,
2319+
other.inferred_type))
2320+
if not types & _unsortable_types:
2321+
result.sort()
23162322

23172323
else:
23182324
result = lvals
23192325

2320-
try:
2321-
result = np.sort(result)
2322-
except TypeError as e:
2323-
warnings.warn("%s, sort order is undefined for "
2324-
"incomparable objects" % e, RuntimeWarning,
2325-
stacklevel=3)
2326+
if sort:
2327+
try:
2328+
result = np.sort(result)
2329+
except TypeError as e:
2330+
warnings.warn("%s, sort order is undefined for "
2331+
"incomparable objects" % e,
2332+
RuntimeWarning,
2333+
stacklevel=3)
23262334

23272335
# for subclasses
23282336
return self._wrap_setop_result(other, result)
23292337

23302338
def _wrap_setop_result(self, other, result):
23312339
return self._constructor(result, name=get_op_result_name(self, other))
23322340

2333-
def intersection(self, other):
2341+
def intersection(self, other, sort=True):
23342342
"""
23352343
Form the intersection of two Index objects.
23362344
2337-
This returns a new Index with elements common to the index and `other`,
2338-
preserving the order of the calling index.
2345+
This returns a new Index with elements common to the index and `other`.
23392346
23402347
Parameters
23412348
----------
23422349
other : Index or array-like
2350+
sort : bool, default True
2351+
Sort the resulting index if possible
2352+
2353+
.. versionadded:: 0.24.0
23432354
23442355
Returns
23452356
-------
@@ -2362,7 +2373,7 @@ def intersection(self, other):
23622373
if not is_dtype_equal(self.dtype, other.dtype):
23632374
this = self.astype('O')
23642375
other = other.astype('O')
2365-
return this.intersection(other)
2376+
return this.intersection(other, sort=sort)
23662377

23672378
# TODO(EA): setops-refactor, clean all this up
23682379
if is_period_dtype(self):
@@ -2393,6 +2404,13 @@ def intersection(self, other):
23932404
taken = other.take(indexer)
23942405
if self.name != other.name:
23952406
taken.name = None
2407+
if sort:
2408+
try:
2409+
taken = taken.sort_values()
2410+
except TypeError as e:
2411+
warnings.warn("%s, sort order is undefined for "
2412+
"incomparable objects" % e, RuntimeWarning,
2413+
stacklevel=3)
23962414
return taken
23972415

23982416
def difference(self, other, sort=True):
@@ -2448,16 +2466,18 @@ def difference(self, other, sort=True):
24482466

24492467
return this._shallow_copy(the_diff, name=result_name, freq=None)
24502468

2451-
def symmetric_difference(self, other, result_name=None):
2469+
def symmetric_difference(self, other, result_name=None, sort=True):
24522470
"""
24532471
Compute the symmetric difference of two Index objects.
24542472
2455-
It's sorted if sorting is possible.
2456-
24572473
Parameters
24582474
----------
24592475
other : Index or array-like
24602476
result_name : str
2477+
sort : bool, default True
2478+
Sort the resulting index if possible
2479+
2480+
.. versionadded:: 0.24.0
24612481
24622482
Returns
24632483
-------
@@ -2502,10 +2522,11 @@ def symmetric_difference(self, other, result_name=None):
25022522
right_diff = other.values.take(right_indexer)
25032523

25042524
the_diff = _concat._concat_compat([left_diff, right_diff])
2505-
try:
2506-
the_diff = sorting.safe_sort(the_diff)
2507-
except TypeError:
2508-
pass
2525+
if sort:
2526+
try:
2527+
the_diff = sorting.safe_sort(the_diff)
2528+
except TypeError:
2529+
pass
25092530

25102531
attribs = self._get_attributes_dict()
25112532
attribs['name'] = result_name
@@ -3233,7 +3254,7 @@ def join(self, other, how='left', level=None, return_indexers=False,
32333254
elif how == 'right':
32343255
join_index = other
32353256
elif how == 'inner':
3236-
join_index = self.intersection(other)
3257+
join_index = self.intersection(other, sort=sort)
32373258
elif how == 'outer':
32383259
join_index = self.union(other)
32393260

pandas/core/indexes/datetimes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def _wrap_setop_result(self, other, result):
603603
raise ValueError('Passed item and index have different timezone')
604604
return self._shallow_copy(result, name=name, freq=None, tz=self.tz)
605605

606-
def intersection(self, other):
606+
def intersection(self, other, sort=True):
607607
"""
608608
Specialized intersection for DatetimeIndex objects. May be much faster
609609
than Index.intersection
@@ -626,7 +626,7 @@ def intersection(self, other):
626626
other = DatetimeIndex(other)
627627
except (TypeError, ValueError):
628628
pass
629-
result = Index.intersection(self, other)
629+
result = Index.intersection(self, other, sort=sort)
630630
if isinstance(result, DatetimeIndex):
631631
if result.freq is None:
632632
result.freq = to_offset(result.inferred_freq)
@@ -636,7 +636,7 @@ def intersection(self, other):
636636
other.freq != self.freq or
637637
not other.freq.isAnchored() or
638638
(not self.is_monotonic or not other.is_monotonic)):
639-
result = Index.intersection(self, other)
639+
result = Index.intersection(self, other, sort=sort)
640640
result = self._shallow_copy(result._values, name=result.name,
641641
tz=result.tz, freq=None)
642642
if result.freq is None:

pandas/core/indexes/interval.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1102,11 +1102,8 @@ def func(self, other, sort=True):
11021102
'objects that have compatible dtypes')
11031103
raise TypeError(msg.format(op=op_name))
11041104

1105-
if op_name == 'difference':
1106-
result = getattr(self._multiindex, op_name)(other._multiindex,
1107-
sort)
1108-
else:
1109-
result = getattr(self._multiindex, op_name)(other._multiindex)
1105+
result = getattr(self._multiindex, op_name)(other._multiindex,
1106+
sort=sort)
11101107
result_name = get_op_result_name(self, other)
11111108

11121109
# GH 19101: ensure empty results have correct dtype

pandas/core/indexes/multi.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -2878,13 +2878,17 @@ def equal_levels(self, other):
28782878
return False
28792879
return True
28802880

2881-
def union(self, other):
2881+
def union(self, other, sort=True):
28822882
"""
2883-
Form the union of two MultiIndex objects, sorting if possible
2883+
Form the union of two MultiIndex objects
28842884
28852885
Parameters
28862886
----------
28872887
other : MultiIndex or array / Index of tuples
2888+
sort : bool, default True
2889+
Sort the resulting MultiIndex if possible
2890+
2891+
.. versionadded:: 0.24.0
28882892
28892893
Returns
28902894
-------
@@ -2899,17 +2903,23 @@ def union(self, other):
28992903
return self
29002904

29012905
uniq_tuples = lib.fast_unique_multiple([self._ndarray_values,
2902-
other._ndarray_values])
2906+
other._ndarray_values],
2907+
sort=sort)
2908+
29032909
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
29042910
names=result_names)
29052911

2906-
def intersection(self, other):
2912+
def intersection(self, other, sort=True):
29072913
"""
2908-
Form the intersection of two MultiIndex objects, sorting if possible
2914+
Form the intersection of two MultiIndex objects.
29092915
29102916
Parameters
29112917
----------
29122918
other : MultiIndex or array / Index of tuples
2919+
sort : bool, default True
2920+
Sort the resulting MultiIndex if possible
2921+
2922+
.. versionadded:: 0.24.0
29132923
29142924
Returns
29152925
-------
@@ -2923,7 +2933,11 @@ def intersection(self, other):
29232933

29242934
self_tuples = self._ndarray_values
29252935
other_tuples = other._ndarray_values
2926-
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
2936+
uniq_tuples = set(self_tuples) & set(other_tuples)
2937+
2938+
if sort:
2939+
uniq_tuples = sorted(uniq_tuples)
2940+
29272941
if len(uniq_tuples) == 0:
29282942
return MultiIndex(levels=self.levels,
29292943
codes=[[]] * self.nlevels,
@@ -2934,7 +2948,7 @@ def intersection(self, other):
29342948

29352949
def difference(self, other, sort=True):
29362950
"""
2937-
Compute sorted set difference of two MultiIndex objects
2951+
Compute set difference of two MultiIndex objects
29382952
29392953
Parameters
29402954
----------

pandas/core/indexes/range.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,17 @@ def equals(self, other):
343343

344344
return super(RangeIndex, self).equals(other)
345345

346-
def intersection(self, other):
346+
def intersection(self, other, sort=True):
347347
"""
348-
Form the intersection of two Index objects. Sortedness of the result is
349-
not guaranteed
348+
Form the intersection of two Index objects.
350349
351350
Parameters
352351
----------
353352
other : Index or array-like
353+
sort : bool, default True
354+
Sort the resulting index if possible
355+
356+
.. versionadded:: 0.24.0
354357
355358
Returns
356359
-------
@@ -361,7 +364,7 @@ def intersection(self, other):
361364
return self._get_reconciled_name_object(other)
362365

363366
if not isinstance(other, RangeIndex):
364-
return super(RangeIndex, self).intersection(other)
367+
return super(RangeIndex, self).intersection(other, sort=sort)
365368

366369
if not len(self) or not len(other):
367370
return RangeIndex._simple_new(None)
@@ -398,6 +401,8 @@ def intersection(self, other):
398401

399402
if (self._step < 0 and other._step < 0) is not (new_index._step < 0):
400403
new_index = new_index[::-1]
404+
if sort:
405+
new_index = new_index.sort_values()
401406
return new_index
402407

403408
def _min_fitting_element(self, lower_limit):

pandas/tests/indexes/datetimes/test_setops.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def test_intersection2(self):
138138

139139
@pytest.mark.parametrize("tz", [None, 'Asia/Tokyo', 'US/Eastern',
140140
'dateutil/US/Pacific'])
141-
def test_intersection(self, tz):
141+
@pytest.mark.parametrize("sort", [True, False])
142+
def test_intersection(self, tz, sort):
142143
# GH 4690 (with tz)
143144
base = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')
144145

@@ -185,7 +186,9 @@ def test_intersection(self, tz):
185186

186187
for (rng, expected) in [(rng2, expected2), (rng3, expected3),
187188
(rng4, expected4)]:
188-
result = base.intersection(rng)
189+
result = base.intersection(rng, sort=sort)
190+
if sort:
191+
expected = expected.sort_values()
189192
tm.assert_index_equal(result, expected)
190193
assert result.name == expected.name
191194
assert result.freq is None

0 commit comments

Comments
 (0)