Skip to content

ENH: Add sort parameter to set operations for some Indexes and adjust… #24521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v0.24.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ Other Enhancements
- :func:`read_fwf` now accepts keyword ``infer_nrows`` (:issue:`15138`).
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`, :issue:`24466`)
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
- :meth:`Index.difference`, :meth:`Index.intersection`, :meth:`Index.union`, and :meth:`Index.symmetric_difference` now have an optional ``sort`` parameter to control whether the results should be sorted if possible (:issue:`17839`, :issue:`24471`)
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
- :meth:`DataFrame.to_stata` and :class:`pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)
Expand Down
25 changes: 20 additions & 5 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,21 @@ def item_from_zerodim(val: object) -> object:

@cython.wraparound(False)
@cython.boundscheck(False)
def fast_unique_multiple(list arrays):
def fast_unique_multiple(list arrays, sort: bool=True):
"""
Generate a list of unique values from a list of arrays.

Parameters
----------
list : array-like
A list of array-like objects
sort : boolean
Whether or not to sort the resulting unique list

Returns
-------
unique_list : list of unique values
"""
cdef:
ndarray[object] buf
Py_ssize_t k = len(arrays)
Expand All @@ -217,10 +231,11 @@ def fast_unique_multiple(list arrays):
if val not in table:
table[val] = stub
uniques.append(val)
try:
uniques.sort()
except Exception:
pass
if sort:
try:
uniques.sort()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might see if using safe_sort makes sense here (you would have to import here, otherwise this would e circular)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried using safe_sort here but it was causing some problems. The issue seems to be that uniques here is a list of tuples which are then used to construct a MultiIndex in MultiIndex.union. However, when we use safe_sort it turns uniques into an np.array and doesn't sort it correctly so we get the wrong results. I can have a closer look at trying to resolve this if you want but it might involve changing safe_sort a bit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, let's followup up later then (new PR).

except Exception:
pass

return uniques

Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def _get_combined_index(indexes, intersect=False, sort=False):
elif intersect:
index = indexes[0]
for other in indexes[1:]:
index = index.intersection(other)
index = index.intersection(other, sort=sort)
else:
index = _union_indexes(indexes, sort=sort)
index = ensure_index(index)
Expand Down
67 changes: 46 additions & 21 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2241,13 +2241,17 @@ def _get_reconciled_name_object(self, other):
return self._shallow_copy(name=name)
return self

def union(self, other):
def union(self, other, sort=True):
"""
Form the union of two Index objects and sorts if possible.
Form the union of two Index objects.

Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible

.. versionadded:: 0.24.0

Returns
-------
Expand Down Expand Up @@ -2277,7 +2281,7 @@ def union(self, other):
if not is_dtype_union_equal(self.dtype, other.dtype):
this = self.astype('O')
other = other.astype('O')
return this.union(other)
return this.union(other, sort=sort)

# TODO(EA): setops-refactor, clean all this up
if is_period_dtype(self) or is_datetime64tz_dtype(self):
Expand Down Expand Up @@ -2311,29 +2315,33 @@ def union(self, other):
else:
result = lvals

try:
result = sorting.safe_sort(result)
except TypeError as e:
warnings.warn("%s, sort order is undefined for "
"incomparable objects" % e, RuntimeWarning,
stacklevel=3)
if sort:
try:
result = sorting.safe_sort(result)
except TypeError as e:
warnings.warn("{}, sort order is undefined for "
"incomparable objects".format(e),
RuntimeWarning, stacklevel=3)

# for subclasses
return self._wrap_setop_result(other, result)

def _wrap_setop_result(self, other, result):
return self._constructor(result, name=get_op_result_name(self, other))

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Form the intersection of two Index objects.

This returns a new Index with elements common to the index and `other`,
preserving the order of the calling index.
This returns a new Index with elements common to the index and `other`.

Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible

.. versionadded:: 0.24.0

Returns
-------
Expand All @@ -2356,7 +2364,7 @@ def intersection(self, other):
if not is_dtype_equal(self.dtype, other.dtype):
this = self.astype('O')
other = other.astype('O')
return this.intersection(other)
return this.intersection(other, sort=sort)

# TODO(EA): setops-refactor, clean all this up
if is_period_dtype(self):
Expand Down Expand Up @@ -2385,8 +2393,18 @@ def intersection(self, other):
indexer = indexer[indexer != -1]

taken = other.take(indexer)

if sort:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, I still think we can use _get_reconcile_name_object here (maybe need to wrap taken in Index), but can do in a followup.

taken = sorting.safe_sort(taken.values)
if self.name != other.name:
name = None
else:
name = self.name
return self._shallow_copy(taken, name=name)

if self.name != other.name:
taken.name = None

return taken

def difference(self, other, sort=True):
Expand Down Expand Up @@ -2442,16 +2460,18 @@ def difference(self, other, sort=True):

return this._shallow_copy(the_diff, name=result_name, freq=None)

def symmetric_difference(self, other, result_name=None):
def symmetric_difference(self, other, result_name=None, sort=True):
"""
Compute the symmetric difference of two Index objects.

It's sorted if sorting is possible.

Parameters
----------
other : Index or array-like
result_name : str
sort : bool, default True
Sort the resulting index if possible

.. versionadded:: 0.24.0

Returns
-------
Expand Down Expand Up @@ -2496,10 +2516,11 @@ def symmetric_difference(self, other, result_name=None):
right_diff = other.values.take(right_indexer)

the_diff = _concat._concat_compat([left_diff, right_diff])
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass
if sort:
try:
the_diff = sorting.safe_sort(the_diff)
except TypeError:
pass

attribs = self._get_attributes_dict()
attribs['name'] = result_name
Expand Down Expand Up @@ -3226,8 +3247,12 @@ def join(self, other, how='left', level=None, return_indexers=False,
elif how == 'right':
join_index = other
elif how == 'inner':
join_index = self.intersection(other)
# TODO: sort=False here for backwards compat. It may
# be better to use the sort parameter passed into join
join_index = self.intersection(other, sort=False)
elif how == 'outer':
# TODO: sort=True here for backwards compat. It may
# be better to use the sort parameter passed into join
join_index = self.union(other)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On Master union does generally sort by default so I have left the default sort=True here for now for compatibility reasons so the current behaviour of join(..., how='outer') does not change for now and all of the tests expecting the results of this type of join to be sorted don't break.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a TODO


if sort:
Expand Down
6 changes: 3 additions & 3 deletions pandas/core/indexes/datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def _wrap_setop_result(self, other, result):
name = get_op_result_name(self, other)
return self._shallow_copy(result, name=name, freq=None, tz=self.tz)

def intersection(self, other):
def intersection(self, other, sort=True):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sort parameter is not fully implemented for intersection on DatetimeIndex yet. I just put this in for now to fix some failing tests

"""
Specialized intersection for DatetimeIndex objects. May be much faster
than Index.intersection
Expand All @@ -617,7 +617,7 @@ def intersection(self, other):
other = DatetimeIndex(other)
except (TypeError, ValueError):
pass
result = Index.intersection(self, other)
result = Index.intersection(self, other, sort=sort)
if isinstance(result, DatetimeIndex):
if result.freq is None:
result.freq = to_offset(result.inferred_freq)
Expand All @@ -627,7 +627,7 @@ def intersection(self, other):
other.freq != self.freq or
not other.freq.isAnchored() or
(not self.is_monotonic or not other.is_monotonic)):
result = Index.intersection(self, other)
result = Index.intersection(self, other, sort=sort)
# Invalidate the freq of `result`, which may not be correct at
# this point, depending on the values.
result.freq = None
Expand Down
7 changes: 2 additions & 5 deletions pandas/core/indexes/interval.py
Original file line number Diff line number Diff line change
Expand Up @@ -1104,11 +1104,8 @@ def func(self, other, sort=True):
'objects that have compatible dtypes')
raise TypeError(msg.format(op=op_name))

if op_name == 'difference':
result = getattr(self._multiindex, op_name)(other._multiindex,
sort)
else:
result = getattr(self._multiindex, op_name)(other._multiindex)
result = getattr(self._multiindex, op_name)(other._multiindex,
sort=sort)
result_name = get_op_result_name(self, other)

# GH 19101: ensure empty results have correct dtype
Expand Down
28 changes: 21 additions & 7 deletions pandas/core/indexes/multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,13 +2879,17 @@ def equal_levels(self, other):
return False
return True

def union(self, other):
def union(self, other, sort=True):
"""
Form the union of two MultiIndex objects, sorting if possible
Form the union of two MultiIndex objects

Parameters
----------
other : MultiIndex or array / Index of tuples
sort : bool, default True
Sort the resulting MultiIndex if possible

.. versionadded:: 0.24.0

Returns
-------
Expand All @@ -2900,17 +2904,23 @@ def union(self, other):
return self

uniq_tuples = lib.fast_unique_multiple([self._ndarray_values,
other._ndarray_values])
other._ndarray_values],
sort=sort)

return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
names=result_names)

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Form the intersection of two MultiIndex objects, sorting if possible
Form the intersection of two MultiIndex objects.

Parameters
----------
other : MultiIndex or array / Index of tuples
sort : bool, default True
Sort the resulting MultiIndex if possible

.. versionadded:: 0.24.0

Returns
-------
Expand All @@ -2924,7 +2934,11 @@ def intersection(self, other):

self_tuples = self._ndarray_values
other_tuples = other._ndarray_values
uniq_tuples = sorted(set(self_tuples) & set(other_tuples))
uniq_tuples = set(self_tuples) & set(other_tuples)

if sort:
uniq_tuples = sorted(uniq_tuples)

if len(uniq_tuples) == 0:
return MultiIndex(levels=self.levels,
codes=[[]] * self.nlevels,
Expand All @@ -2935,7 +2949,7 @@ def intersection(self, other):

def difference(self, other, sort=True):
"""
Compute sorted set difference of two MultiIndex objects
Compute set difference of two MultiIndex objects

Parameters
----------
Expand Down
13 changes: 9 additions & 4 deletions pandas/core/indexes/range.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,17 @@ def equals(self, other):

return super(RangeIndex, self).equals(other)

def intersection(self, other):
def intersection(self, other, sort=True):
"""
Form the intersection of two Index objects. Sortedness of the result is
not guaranteed
Form the intersection of two Index objects.

Parameters
----------
other : Index or array-like
sort : bool, default True
Sort the resulting index if possible

.. versionadded:: 0.24.0

Returns
-------
Expand All @@ -361,7 +364,7 @@ def intersection(self, other):
return self._get_reconciled_name_object(other)

if not isinstance(other, RangeIndex):
return super(RangeIndex, self).intersection(other)
return super(RangeIndex, self).intersection(other, sort=sort)

if not len(self) or not len(other):
return RangeIndex._simple_new(None)
Expand Down Expand Up @@ -398,6 +401,8 @@ def intersection(self, other):

if (self._step < 0 and other._step < 0) is not (new_index._step < 0):
new_index = new_index[::-1]
if sort:
new_index = new_index.sort_values()
return new_index

def _min_fitting_element(self, lower_limit):
Expand Down
2 changes: 1 addition & 1 deletion pandas/io/pytables.py
Original file line number Diff line number Diff line change
Expand Up @@ -4473,7 +4473,7 @@ def _reindex_axis(obj, axis, labels, other=None):

labels = ensure_index(labels.unique())
if other is not None:
labels = ensure_index(other.unique()) & labels
labels = ensure_index(other.unique()).intersection(labels, sort=False)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, on Master intersection does not generally sort by default so I have added sort=False here for compatibility

if not labels.equals(ax):
slicer = [slice(None, None)] * obj.ndim
slicer[axis] = labels
Expand Down
7 changes: 5 additions & 2 deletions pandas/tests/indexes/datetimes/test_setops.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ def test_intersection2(self):

@pytest.mark.parametrize("tz", [None, 'Asia/Tokyo', 'US/Eastern',
'dateutil/US/Pacific'])
def test_intersection(self, tz):
@pytest.mark.parametrize("sort", [True, False])
def test_intersection(self, tz, sort):
# GH 4690 (with tz)
base = date_range('6/1/2000', '6/30/2000', freq='D', name='idx')

Expand Down Expand Up @@ -185,7 +186,9 @@ def test_intersection(self, tz):

for (rng, expected) in [(rng2, expected2), (rng3, expected3),
(rng4, expected4)]:
result = base.intersection(rng)
result = base.intersection(rng, sort=sort)
if sort:
expected = expected.sort_values()
tm.assert_index_equal(result, expected)
assert result.name == expected.name
assert result.freq is None
Expand Down
Loading