Skip to content

Commit 35da4eb

Browse files
reidy-pPingviinituutti
authored andcommitted
ENH: Add sort parameter to set operations for some Indexes and adjust… (pandas-dev#24521)
1 parent e5b643a commit 35da4eb

File tree

15 files changed

+389
-226
lines changed

15 files changed

+389
-226
lines changed

doc/source/whatsnew/v0.24.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ Other Enhancements
413413
- :func:`read_fwf` now accepts keyword ``infer_nrows`` (:issue:`15138`).
414414
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
415415
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`, :issue:`24466`)
416-
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
416+
- :meth:`Index.difference`, :meth:`Index.intersection`, :meth:`Index.union`, and :meth:`Index.symmetric_difference` now have an optional ``sort`` parameter to control whether the results should be sorted if possible (:issue:`17839`, :issue:`24471`)
417417
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
418418
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
419419
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)

pandas/_libs/lib.pyx

+20-5
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,21 @@ def item_from_zerodim(val: object) -> object:
200200

201201
@cython.wraparound(False)
202202
@cython.boundscheck(False)
203-
def fast_unique_multiple(list arrays):
203+
def fast_unique_multiple(list arrays, sort: bool=True):
204+
"""
205+
Generate a list of unique values from a list of arrays.
206+
207+
Parameters
208+
----------
209+
list : array-like
210+
A list of array-like objects
211+
sort : boolean
212+
Whether or not to sort the resulting unique list
213+
214+
Returns
215+
-------
216+
unique_list : list of unique values
217+
"""
204218
cdef:
205219
ndarray[object] buf
206220
Py_ssize_t k = len(arrays)
@@ -217,10 +231,11 @@ def fast_unique_multiple(list arrays):
217231
if val not in table:
218232
table[val] = stub
219233
uniques.append(val)
220-
try:
221-
uniques.sort()
222-
except Exception:
223-
pass
234+
if sort:
235+
try:
236+
uniques.sort()
237+
except Exception:
238+
pass
224239

225240
return uniques
226241

pandas/core/indexes/api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def _get_combined_index(indexes, intersect=False, sort=False):
112112
elif intersect:
113113
index = indexes[0]
114114
for other in indexes[1:]:
115-
index = index.intersection(other)
115+
index = index.intersection(other, sort=sort)
116116
else:
117117
index = _union_indexes(indexes, sort=sort)
118118
index = ensure_index(index)

pandas/core/indexes/base.py

+46-21
Original file line numberDiff line numberDiff line change
@@ -2241,13 +2241,17 @@ def _get_reconciled_name_object(self, other):
22412241
return self._shallow_copy(name=name)
22422242
return self
22432243

2244-
def union(self, other):
2244+
def union(self, other, sort=True):
22452245
"""
2246-
Form the union of two Index objects and sorts if possible.
2246+
Form the union of two Index objects.
22472247
22482248
Parameters
22492249
----------
22502250
other : Index or array-like
2251+
sort : bool, default True
2252+
Sort the resulting index if possible
2253+
2254+
.. versionadded:: 0.24.0
22512255
22522256
Returns
22532257
-------
@@ -2277,7 +2281,7 @@ def union(self, other):
22772281
if not is_dtype_union_equal(self.dtype, other.dtype):
22782282
this = self.astype('O')
22792283
other = other.astype('O')
2280-
return this.union(other)
2284+
return this.union(other, sort=sort)
22812285

22822286
# TODO(EA): setops-refactor, clean all this up
22832287
if is_period_dtype(self) or is_datetime64tz_dtype(self):
@@ -2311,29 +2315,33 @@ def union(self, other):
23112315
else:
23122316
result = lvals
23132317

2314-
try:
2315-
result = sorting.safe_sort(result)
2316-
except TypeError as e:
2317-
warnings.warn("%s, sort order is undefined for "
2318-
"incomparable objects" % e, RuntimeWarning,
2319-
stacklevel=3)
2318+
if sort:
2319+
try:
2320+
result = sorting.safe_sort(result)
2321+
except TypeError as e:
2322+
warnings.warn("{}, sort order is undefined for "
2323+
"incomparable objects".format(e),
2324+
RuntimeWarning, stacklevel=3)
23202325

23212326
# for subclasses
23222327
return self._wrap_setop_result(other, result)
23232328

23242329
def _wrap_setop_result(self, other, result):
23252330
return self._constructor(result, name=get_op_result_name(self, other))
23262331

2327-
def intersection(self, other):
2332+
def intersection(self, other, sort=True):
23282333
"""
23292334
Form the intersection of two Index objects.
23302335
2331-
This returns a new Index with elements common to the index and `other`,
2332-
preserving the order of the calling index.
2336+
This returns a new Index with elements common to the index and `other`.
23332337
23342338
Parameters
23352339
----------
23362340
other : Index or array-like
2341+
sort : bool, default True
2342+
Sort the resulting index if possible
2343+
2344+
.. versionadded:: 0.24.0
23372345
23382346
Returns
23392347
-------
@@ -2356,7 +2364,7 @@ def intersection(self, other):
23562364
if not is_dtype_equal(self.dtype, other.dtype):
23572365
this = self.astype('O')
23582366
other = other.astype('O')
2359-
return this.intersection(other)
2367+
return this.intersection(other, sort=sort)
23602368

23612369
# TODO(EA): setops-refactor, clean all this up
23622370
if is_period_dtype(self):
@@ -2385,8 +2393,18 @@ def intersection(self, other):
23852393
indexer = indexer[indexer != -1]
23862394

23872395
taken = other.take(indexer)
2396+
2397+
if sort:
2398+
taken = sorting.safe_sort(taken.values)
2399+
if self.name != other.name:
2400+
name = None
2401+
else:
2402+
name = self.name
2403+
return self._shallow_copy(taken, name=name)
2404+
23882405
if self.name != other.name:
23892406
taken.name = None
2407+
23902408
return taken
23912409

23922410
def difference(self, other, sort=True):
@@ -2442,16 +2460,18 @@ def difference(self, other, sort=True):
24422460

24432461
return this._shallow_copy(the_diff, name=result_name, freq=None)
24442462

2445-
def symmetric_difference(self, other, result_name=None):
2463+
def symmetric_difference(self, other, result_name=None, sort=True):
24462464
"""
24472465
Compute the symmetric difference of two Index objects.
24482466
2449-
It's sorted if sorting is possible.
2450-
24512467
Parameters
24522468
----------
24532469
other : Index or array-like
24542470
result_name : str
2471+
sort : bool, default True
2472+
Sort the resulting index if possible
2473+
2474+
.. versionadded:: 0.24.0
24552475
24562476
Returns
24572477
-------
@@ -2496,10 +2516,11 @@ def symmetric_difference(self, other, result_name=None):
24962516
right_diff = other.values.take(right_indexer)
24972517

24982518
the_diff = _concat._concat_compat([left_diff, right_diff])
2499-
try:
2500-
the_diff = sorting.safe_sort(the_diff)
2501-
except TypeError:
2502-
pass
2519+
if sort:
2520+
try:
2521+
the_diff = sorting.safe_sort(the_diff)
2522+
except TypeError:
2523+
pass
25032524

25042525
attribs = self._get_attributes_dict()
25052526
attribs['name'] = result_name
@@ -3226,8 +3247,12 @@ def join(self, other, how='left', level=None, return_indexers=False,
32263247
elif how == 'right':
32273248
join_index = other
32283249
elif how == 'inner':
3229-
join_index = self.intersection(other)
3250+
# TODO: sort=False here for backwards compat. It may
3251+
# be better to use the sort parameter passed into join
3252+
join_index = self.intersection(other, sort=False)
32303253
elif how == 'outer':
3254+
# TODO: sort=True here for backwards compat. It may
3255+
# be better to use the sort parameter passed into join
32313256
join_index = self.union(other)
32323257

32333258
if sort:

pandas/core/indexes/datetimes.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,7 @@ def _wrap_setop_result(self, other, result):
594594
name = get_op_result_name(self, other)
595595
return self._shallow_copy(result, name=name, freq=None, tz=self.tz)
596596

597-
def intersection(self, other):
597+
def intersection(self, other, sort=True):
598598
"""
599599
Specialized intersection for DatetimeIndex objects. May be much faster
600600
than Index.intersection
@@ -617,7 +617,7 @@ def intersection(self, other):
617617
other = DatetimeIndex(other)
618618
except (TypeError, ValueError):
619619
pass
620-
result = Index.intersection(self, other)
620+
result = Index.intersection(self, other, sort=sort)
621621
if isinstance(result, DatetimeIndex):
622622
if result.freq is None:
623623
result.freq = to_offset(result.inferred_freq)
@@ -627,7 +627,7 @@ def intersection(self, other):
627627
other.freq != self.freq or
628628
not other.freq.isAnchored() or
629629
(not self.is_monotonic or not other.is_monotonic)):
630-
result = Index.intersection(self, other)
630+
result = Index.intersection(self, other, sort=sort)
631631
# Invalidate the freq of `result`, which may not be correct at
632632
# this point, depending on the values.
633633
result.freq = None

pandas/core/indexes/interval.py

+2-5
Original file line numberDiff line numberDiff line change
@@ -1104,11 +1104,8 @@ def func(self, other, sort=True):
11041104
'objects that have compatible dtypes')
11051105
raise TypeError(msg.format(op=op_name))
11061106

1107-
if op_name == 'difference':
1108-
result = getattr(self._multiindex, op_name)(other._multiindex,
1109-
sort)
1110-
else:
1111-
result = getattr(self._multiindex, op_name)(other._multiindex)
1107+
result = getattr(self._multiindex, op_name)(other._multiindex,
1108+
sort=sort)
11121109
result_name = get_op_result_name(self, other)
11131110

11141111
# GH 19101: ensure empty results have correct dtype

pandas/core/indexes/multi.py

+21-7
Original file line numberDiff line numberDiff line change
@@ -2879,13 +2879,17 @@ def equal_levels(self, other):
28792879
return False
28802880
return True
28812881

2882-
def union(self, other):
2882+
def union(self, other, sort=True):
28832883
"""
2884-
Form the union of two MultiIndex objects, sorting if possible
2884+
Form the union of two MultiIndex objects
28852885
28862886
Parameters
28872887
----------
28882888
other : MultiIndex or array / Index of tuples
2889+
sort : bool, default True
2890+
Sort the resulting MultiIndex if possible
2891+
2892+
.. versionadded:: 0.24.0
28892893
28902894
Returns
28912895
-------
@@ -2900,17 +2904,23 @@ def union(self, other):
29002904
return self
29012905

29022906
uniq_tuples = lib.fast_unique_multiple([self._ndarray_values,
2903-
other._ndarray_values])
2907+
other._ndarray_values],
2908+
sort=sort)
2909+
29042910
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
29052911
names=result_names)
29062912

2907-
def intersection(self, other):
2913+
def intersection(self, other, sort=True):
29082914
"""
2909-
Form the intersection of two MultiIndex objects, sorting if possible
2915+
Form the intersection of two MultiIndex objects.
29102916
29112917
Parameters
29122918
----------
29132919
other : MultiIndex or array / Index of tuples
2920+
sort : bool, default True
2921+
Sort the resulting MultiIndex if possible
2922+
2923+
.. versionadded:: 0.24.0
29142924
29152925
Returns
29162926
-------
@@ -2924,7 +2934,11 @@ def intersection(self, other):
29242934

29252935
self_tuples = self._ndarray_values
29262936
other_tuples = other._ndarray_values
2927-
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
2937+
uniq_tuples = set(self_tuples) & set(other_tuples)
2938+
2939+
if sort:
2940+
uniq_tuples = sorted(uniq_tuples)
2941+
29282942
if len(uniq_tuples) == 0:
29292943
return MultiIndex(levels=self.levels,
29302944
codes=[[]] * self.nlevels,
@@ -2935,7 +2949,7 @@ def intersection(self, other):
29352949

29362950
def difference(self, other, sort=True):
29372951
"""
2938-
Compute sorted set difference of two MultiIndex objects
2952+
Compute set difference of two MultiIndex objects
29392953
29402954
Parameters
29412955
----------

pandas/core/indexes/range.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -343,14 +343,17 @@ def equals(self, other):
343343

344344
return super(RangeIndex, self).equals(other)
345345

346-
def intersection(self, other):
346+
def intersection(self, other, sort=True):
347347
"""
348-
Form the intersection of two Index objects. Sortedness of the result is
349-
not guaranteed
348+
Form the intersection of two Index objects.
350349
351350
Parameters
352351
----------
353352
other : Index or array-like
353+
sort : bool, default True
354+
Sort the resulting index if possible
355+
356+
.. versionadded:: 0.24.0
354357
355358
Returns
356359
-------
@@ -361,7 +364,7 @@ def intersection(self, other):
361364
return self._get_reconciled_name_object(other)
362365

363366
if not isinstance(other, RangeIndex):
364-
return super(RangeIndex, self).intersection(other)
367+
return super(RangeIndex, self).intersection(other, sort=sort)
365368

366369
if not len(self) or not len(other):
367370
return RangeIndex._simple_new(None)
@@ -398,6 +401,8 @@ def intersection(self, other):
398401

399402
if (self._step < 0 and other._step < 0) is not (new_index._step < 0):
400403
new_index = new_index[::-1]
404+
if sort:
405+
new_index = new_index.sort_values()
401406
return new_index
402407

403408
def _min_fitting_element(self, lower_limit):

pandas/io/pytables.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4473,7 +4473,7 @@ def _reindex_axis(obj, axis, labels, other=None):
44734473

44744474
labels = ensure_index(labels.unique())
44754475
if other is not None:
4476-
labels = ensure_index(other.unique()) & labels
4476+
labels = ensure_index(other.unique()).intersection(labels, sort=False)
44774477
if not labels.equals(ax):
44784478
slicer = [slice(None, None)] * obj.ndim
44794479
slicer[axis] = labels

pandas/tests/indexes/datetimes/test_setops.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def test_intersection2(self):
138138

139139
@pytest.mark.parametrize("tz", [None, 'Asia/Tokyo', 'US/Eastern',
140140
'dateutil/US/Pacific'])
141-
def test_intersection(self, tz):
141+
@pytest.mark.parametrize("sort", [True, False])
142+
def test_intersection(self, tz, sort):
142143
# GH 4690 (with tz)
143144
base = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')
144145

@@ -185,7 +186,9 @@ def test_intersection(self, tz):
185186

186187
for (rng, expected) in [(rng2, expected2), (rng3, expected3),
187188
(rng4, expected4)]:
188-
result = base.intersection(rng)
189+
result = base.intersection(rng, sort=sort)
190+
if sort:
191+
expected = expected.sort_values()
189192
tm.assert_index_equal(result, expected)
190193
assert result.name == expected.name
191194
assert result.freq is None

0 commit comments

Comments
 (0)