Skip to content

Commit a5ad5fc

Browse files
reidy-pPingviinituutti
authored andcommitted
BUG: Maintain column order with groupby.nth (pandas-dev#22811)
1 parent f450394 commit a5ad5fc

File tree

14 files changed

+225
-84
lines changed

14 files changed

+225
-84
lines changed

doc/source/whatsnew/v0.24.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ Other Enhancements
288288
- Added :meth:`Interval.overlaps`, :meth:`IntervalArray.overlaps`, and :meth:`IntervalIndex.overlaps` for determining overlaps between interval-like objects (:issue:`21998`)
289289
- :func:`~DataFrame.to_parquet` now supports writing a ``DataFrame`` as a directory of parquet files partitioned by a subset of the columns when ``engine = 'pyarrow'`` (:issue:`23283`)
290290
- :meth:`Timestamp.tz_localize`, :meth:`DatetimeIndex.tz_localize`, and :meth:`Series.tz_localize` have gained the ``nonexistent`` argument for alternative handling of nonexistent times. See :ref:`timeseries.timezone_nonexistent` (:issue:`8917`)
291+
- :meth:`Index.difference` now has an optional ``sort`` parameter to specify whether the results should be sorted if possible (:issue:`17839`)
291292
- :meth:`read_excel()` now accepts ``usecols`` as a list of column names or callable (:issue:`18273`)
292293
- :meth:`MultiIndex.to_flat_index` has been added to flatten multiple levels into a single-level :class:`Index` object.
293294
- :meth:`DataFrame.to_stata` and :class:` pandas.io.stata.StataWriter117` can write mixed sting columns to Stata strl format (:issue:`23633`)
@@ -1417,6 +1418,7 @@ Groupby/Resample/Rolling
14171418
- Bug in :meth:`DataFrame.resample` and :meth:`Series.resample` when resampling by a weekly offset (``'W'``) across a DST transition (:issue:`9119`, :issue:`21459`)
14181419
- Bug in :meth:`DataFrame.expanding` in which the ``axis`` argument was not being respected during aggregations (:issue:`23372`)
14191420
- Bug in :meth:`pandas.core.groupby.DataFrameGroupBy.transform` which caused missing values when the input function can accept a :class:`DataFrame` but renames it (:issue:`23455`).
1421+
- Bug in :func:`pandas.core.groupby.GroupBy.nth` where column order was not always preserved (:issue:`20760`)
14201422

14211423
Reshaping
14221424
^^^^^^^^^

pandas/core/groupby/groupby.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -494,7 +494,8 @@ def _set_group_selection(self):
494494

495495
if len(groupers):
496496
# GH12839 clear selected obj cache when group selection changes
497-
self._group_selection = ax.difference(Index(groupers)).tolist()
497+
self._group_selection = ax.difference(Index(groupers),
498+
sort=False).tolist()
498499
self._reset_cache('_selected_obj')
499500

500501
def _set_result_index_ordered(self, result):

pandas/core/indexes/base.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -2944,17 +2944,20 @@ def intersection(self, other):
29442944
taken.name = None
29452945
return taken
29462946

2947-
def difference(self, other):
2947+
def difference(self, other, sort=True):
29482948
"""
29492949
Return a new Index with elements from the index that are not in
29502950
`other`.
29512951
29522952
This is the set difference of two Index objects.
2953-
It's sorted if sorting is possible.
29542953
29552954
Parameters
29562955
----------
29572956
other : Index or array-like
2957+
sort : bool, default True
2958+
Sort the resulting index if possible
2959+
2960+
.. versionadded:: 0.24.0
29582961
29592962
Returns
29602963
-------
@@ -2963,10 +2966,12 @@ def difference(self, other):
29632966
Examples
29642967
--------
29652968
2966-
>>> idx1 = pd.Index([1, 2, 3, 4])
2969+
>>> idx1 = pd.Index([2, 1, 3, 4])
29672970
>>> idx2 = pd.Index([3, 4, 5, 6])
29682971
>>> idx1.difference(idx2)
29692972
Int64Index([1, 2], dtype='int64')
2973+
>>> idx1.difference(idx2, sort=False)
2974+
Int64Index([2, 1], dtype='int64')
29702975
29712976
"""
29722977
self._assert_can_do_setop(other)
@@ -2985,10 +2990,11 @@ def difference(self, other):
29852990
label_diff = np.setdiff1d(np.arange(this.size), indexer,
29862991
assume_unique=True)
29872992
the_diff = this.values.take(label_diff)
2988-
try:
2989-
the_diff = sorting.safe_sort(the_diff)
2990-
except TypeError:
2991-
pass
2993+
if sort:
2994+
try:
2995+
the_diff = sorting.safe_sort(the_diff)
2996+
except TypeError:
2997+
pass
29922998

29932999
return this._shallow_copy(the_diff, name=result_name, freq=None)
29943000

pandas/core/indexes/interval.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -1037,7 +1037,7 @@ def overlaps(self, other):
10371037
return self._data.overlaps(other)
10381038

10391039
def _setop(op_name):
1040-
def func(self, other):
1040+
def func(self, other, sort=True):
10411041
other = self._as_like_interval_index(other)
10421042

10431043
# GH 19016: ensure set op will not return a prohibited dtype
@@ -1048,7 +1048,11 @@ def func(self, other):
10481048
'objects that have compatible dtypes')
10491049
raise TypeError(msg.format(op=op_name))
10501050

1051-
result = getattr(self._multiindex, op_name)(other._multiindex)
1051+
if op_name == 'difference':
1052+
result = getattr(self._multiindex, op_name)(other._multiindex,
1053+
sort)
1054+
else:
1055+
result = getattr(self._multiindex, op_name)(other._multiindex)
10521056
result_name = get_op_result_name(self, other)
10531057

10541058
# GH 19101: ensure empty results have correct dtype

pandas/core/indexes/multi.py

+19-3
Original file line numberDiff line numberDiff line change
@@ -2798,10 +2798,18 @@ def intersection(self, other):
27982798
return MultiIndex.from_arrays(lzip(*uniq_tuples), sortorder=0,
27992799
names=result_names)
28002800

2801-
def difference(self, other):
2801+
def difference(self, other, sort=True):
28022802
"""
28032803
Compute sorted set difference of two MultiIndex objects
28042804
2805+
Parameters
2806+
----------
2807+
other : MultiIndex
2808+
sort : bool, default True
2809+
Sort the resulting MultiIndex if possible
2810+
2811+
.. versionadded:: 0.24.0
2812+
28052813
Returns
28062814
-------
28072815
diff : MultiIndex
@@ -2817,8 +2825,16 @@ def difference(self, other):
28172825
labels=[[]] * self.nlevels,
28182826
names=result_names, verify_integrity=False)
28192827

2820-
difference = sorted(set(self._ndarray_values) -
2821-
set(other._ndarray_values))
2828+
this = self._get_unique_index()
2829+
2830+
indexer = this.get_indexer(other)
2831+
indexer = indexer.take((indexer != -1).nonzero()[0])
2832+
2833+
label_diff = np.setdiff1d(np.arange(this.size), indexer,
2834+
assume_unique=True)
2835+
difference = this.values.take(label_diff)
2836+
if sort:
2837+
difference = sorted(difference)
28222838

28232839
if len(difference) == 0:
28242840
return MultiIndex(levels=[[]] * self.nlevels,

pandas/tests/groupby/test_nth.py

+24
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,27 @@ def test_nth_empty():
390390
names=['a', 'b']),
391391
columns=['c'])
392392
assert_frame_equal(result, expected)
393+
394+
395+
def test_nth_column_order():
396+
# GH 20760
397+
# Check that nth preserves column order
398+
df = DataFrame([[1, 'b', 100],
399+
[1, 'a', 50],
400+
[1, 'a', np.nan],
401+
[2, 'c', 200],
402+
[2, 'd', 150]],
403+
columns=['A', 'C', 'B'])
404+
result = df.groupby('A').nth(0)
405+
expected = DataFrame([['b', 100.0],
406+
['c', 200.0]],
407+
columns=['C', 'B'],
408+
index=Index([1, 2], name='A'))
409+
assert_frame_equal(result, expected)
410+
411+
result = df.groupby('A').nth(-1, dropna='any')
412+
expected = DataFrame([['a', 50.0],
413+
['d', 150.0]],
414+
columns=['C', 'B'],
415+
index=Index([1, 2], name='A'))
416+
assert_frame_equal(result, expected)

pandas/tests/indexes/common.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -666,12 +666,13 @@ def test_union_base(self):
666666
with pytest.raises(TypeError, match=msg):
667667
first.union([1, 2, 3])
668668

669-
def test_difference_base(self):
669+
@pytest.mark.parametrize("sort", [True, False])
670+
def test_difference_base(self, sort):
670671
for name, idx in compat.iteritems(self.indices):
671672
first = idx[2:]
672673
second = idx[:4]
673674
answer = idx[4:]
674-
result = first.difference(second)
675+
result = first.difference(second, sort)
675676

676677
if isinstance(idx, CategoricalIndex):
677678
pass
@@ -685,21 +686,21 @@ def test_difference_base(self):
685686
if isinstance(idx, PeriodIndex):
686687
msg = "can only call with other PeriodIndex-ed objects"
687688
with pytest.raises(ValueError, match=msg):
688-
first.difference(case)
689+
first.difference(case, sort)
689690
elif isinstance(idx, CategoricalIndex):
690691
pass
691692
elif isinstance(idx, (DatetimeIndex, TimedeltaIndex)):
692693
assert result.__class__ == answer.__class__
693694
tm.assert_numpy_array_equal(result.sort_values().asi8,
694695
answer.sort_values().asi8)
695696
else:
696-
result = first.difference(case)
697+
result = first.difference(case, sort)
697698
assert tm.equalContents(result, answer)
698699

699700
if isinstance(idx, MultiIndex):
700701
msg = "other must be a MultiIndex or a list of tuples"
701702
with pytest.raises(TypeError, match=msg):
702-
first.difference([1, 2, 3])
703+
first.difference([1, 2, 3], sort)
703704

704705
def test_symmetric_difference(self):
705706
for name, idx in compat.iteritems(self.indices):

pandas/tests/indexes/datetimes/test_setops.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -209,47 +209,55 @@ def test_intersection_bug_1708(self):
209209
assert len(result) == 0
210210

211211
@pytest.mark.parametrize("tz", tz)
212-
def test_difference(self, tz):
213-
rng1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
212+
@pytest.mark.parametrize("sort", [True, False])
213+
def test_difference(self, tz, sort):
214+
rng_dates = ['1/2/2000', '1/3/2000', '1/1/2000', '1/4/2000',
215+
'1/5/2000']
216+
217+
rng1 = pd.DatetimeIndex(rng_dates, tz=tz)
214218
other1 = pd.date_range('1/6/2000', freq='D', periods=5, tz=tz)
215-
expected1 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
219+
expected1 = pd.DatetimeIndex(rng_dates, tz=tz)
216220

217-
rng2 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
221+
rng2 = pd.DatetimeIndex(rng_dates, tz=tz)
218222
other2 = pd.date_range('1/4/2000', freq='D', periods=5, tz=tz)
219-
expected2 = pd.date_range('1/1/2000', freq='D', periods=3, tz=tz)
223+
expected2 = pd.DatetimeIndex(rng_dates[:3], tz=tz)
220224

221-
rng3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
225+
rng3 = pd.DatetimeIndex(rng_dates, tz=tz)
222226
other3 = pd.DatetimeIndex([], tz=tz)
223-
expected3 = pd.date_range('1/1/2000', freq='D', periods=5, tz=tz)
227+
expected3 = pd.DatetimeIndex(rng_dates, tz=tz)
224228

225229
for rng, other, expected in [(rng1, other1, expected1),
226230
(rng2, other2, expected2),
227231
(rng3, other3, expected3)]:
228-
result_diff = rng.difference(other)
232+
result_diff = rng.difference(other, sort)
233+
if sort:
234+
expected = expected.sort_values()
229235
tm.assert_index_equal(result_diff, expected)
230236

231-
def test_difference_freq(self):
237+
@pytest.mark.parametrize("sort", [True, False])
238+
def test_difference_freq(self, sort):
232239
# GH14323: difference of DatetimeIndex should not preserve frequency
233240

234241
index = date_range("20160920", "20160925", freq="D")
235242
other = date_range("20160921", "20160924", freq="D")
236243
expected = DatetimeIndex(["20160920", "20160925"], freq=None)
237-
idx_diff = index.difference(other)
244+
idx_diff = index.difference(other, sort)
238245
tm.assert_index_equal(idx_diff, expected)
239246
tm.assert_attr_equal('freq', idx_diff, expected)
240247

241248
other = date_range("20160922", "20160925", freq="D")
242-
idx_diff = index.difference(other)
249+
idx_diff = index.difference(other, sort)
243250
expected = DatetimeIndex(["20160920", "20160921"], freq=None)
244251
tm.assert_index_equal(idx_diff, expected)
245252
tm.assert_attr_equal('freq', idx_diff, expected)
246253

247-
def test_datetimeindex_diff(self):
254+
@pytest.mark.parametrize("sort", [True, False])
255+
def test_datetimeindex_diff(self, sort):
248256
dti1 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
249257
periods=100)
250258
dti2 = DatetimeIndex(freq='Q-JAN', start=datetime(1997, 12, 31),
251259
periods=98)
252-
assert len(dti1.difference(dti2)) == 2
260+
assert len(dti1.difference(dti2, sort)) == 2
253261

254262
def test_datetimeindex_union_join_empty(self):
255263
dti = DatetimeIndex(start='1/1/2001', end='2/1/2001', freq='D')

pandas/tests/indexes/interval/test_interval.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -801,19 +801,26 @@ def test_intersection(self, closed):
801801
result = index.intersection(other)
802802
tm.assert_index_equal(result, expected)
803803

804-
def test_difference(self, closed):
805-
index = self.create_index(closed=closed)
806-
tm.assert_index_equal(index.difference(index[:1]), index[1:])
804+
@pytest.mark.parametrize("sort", [True, False])
805+
def test_difference(self, closed, sort):
806+
index = IntervalIndex.from_arrays([1, 0, 3, 2],
807+
[1, 2, 3, 4],
808+
closed=closed)
809+
result = index.difference(index[:1], sort)
810+
expected = index[1:]
811+
if sort:
812+
expected = expected.sort_values()
813+
tm.assert_index_equal(result, expected)
807814

808815
# GH 19101: empty result, same dtype
809-
result = index.difference(index)
816+
result = index.difference(index, sort)
810817
expected = IntervalIndex(np.array([], dtype='int64'), closed=closed)
811818
tm.assert_index_equal(result, expected)
812819

813820
# GH 19101: empty result, different dtypes
814821
other = IntervalIndex.from_arrays(index.left.astype('float64'),
815822
index.right, closed=closed)
816-
result = index.difference(other)
823+
result = index.difference(other, sort)
817824
tm.assert_index_equal(result, expected)
818825

819826
def test_symmetric_difference(self, closed):

0 commit comments

Comments
 (0)