Skip to content

Commit 739e6be

Browse files
WillAydjreback
authored andcommitted
Fixed Issue Preventing Agg on RollingGroupBy Objects (#21323)
1 parent a393675 commit 739e6be

File tree

5 files changed

+67
-5
lines changed

5 files changed

+67
-5
lines changed

doc/source/whatsnew/v0.24.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -785,6 +785,7 @@ Groupby/Resample/Rolling
785785
- Bug in :meth:`Series.resample` when passing ``numpy.timedelta64`` to ``loffset`` kwarg (:issue:`7687`).
786786
- Bug in :meth:`Resampler.asfreq` when frequency of ``TimedeltaIndex`` is a subperiod of a new frequency (:issue:`13022`).
787787
- Bug in :meth:`SeriesGroupBy.mean` when values were integral but could not fit inside of int64, overflowing instead. (:issue:`22487`)
788+
- :func:`RollingGroupby.agg` and :func:`ExpandingGroupby.agg` now support multiple aggregation functions as parameters (:issue:`15072`)
788789

789790
Sparse
790791
^^^^^^

pandas/core/base.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -245,8 +245,8 @@ def _obj_with_exclusions(self):
245245

246246
def __getitem__(self, key):
247247
if self._selection is not None:
248-
raise Exception('Column(s) {selection} already selected'
249-
.format(selection=self._selection))
248+
raise IndexError('Column(s) {selection} already selected'
249+
.format(selection=self._selection))
250250

251251
if isinstance(key, (list, tuple, ABCSeries, ABCIndexClass,
252252
np.ndarray)):

pandas/core/groupby/base.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,15 @@ def _gotitem(self, key, ndim, subset=None):
4444
# we need to make a shallow copy of ourselves
4545
# with the same groupby
4646
kwargs = {attr: getattr(self, attr) for attr in self._attributes}
47+
48+
# Try to select from a DataFrame, falling back to a Series
49+
try:
50+
groupby = self._groupby[key]
51+
except IndexError:
52+
groupby = self._groupby
53+
4754
self = self.__class__(subset,
48-
groupby=self._groupby[key],
55+
groupby=groupby,
4956
parent=self,
5057
**kwargs)
5158
self._reset_cache()

pandas/tests/groupby/test_groupby.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -623,8 +623,14 @@ def test_as_index_series_return_frame(df):
623623
assert isinstance(result2, DataFrame)
624624
assert_frame_equal(result2, expected2)
625625

626-
# corner case
627-
pytest.raises(Exception, grouped['C'].__getitem__, 'D')
626+
627+
def test_as_index_series_column_slice_raises(df):
628+
# GH15072
629+
grouped = df.groupby('A', as_index=False)
630+
msg = r"Column\(s\) C already selected"
631+
632+
with tm.assert_raises_regex(IndexError, msg):
633+
grouped['C'].__getitem__('D')
628634

629635

630636
def test_groupby_as_index_cython(df):

pandas/tests/test_window.py

+48
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from collections import OrderedDict
12
from itertools import product
23
import pytest
34
import warnings
@@ -314,6 +315,53 @@ def test_preserve_metadata(self):
314315
assert s2.name == 'foo'
315316
assert s3.name == 'foo'
316317

318+
@pytest.mark.parametrize("func,window_size,expected_vals", [
319+
('rolling', 2, [[np.nan, np.nan, np.nan, np.nan],
320+
[15., 20., 25., 20.],
321+
[25., 30., 35., 30.],
322+
[np.nan, np.nan, np.nan, np.nan],
323+
[20., 30., 35., 30.],
324+
[35., 40., 60., 40.],
325+
[60., 80., 85., 80]]),
326+
('expanding', None, [[10., 10., 20., 20.],
327+
[15., 20., 25., 20.],
328+
[20., 30., 30., 20.],
329+
[10., 10., 30., 30.],
330+
[20., 30., 35., 30.],
331+
[26.666667, 40., 50., 30.],
332+
[40., 80., 60., 30.]])])
333+
def test_multiple_agg_funcs(self, func, window_size, expected_vals):
334+
# GH 15072
335+
df = pd.DataFrame([
336+
['A', 10, 20],
337+
['A', 20, 30],
338+
['A', 30, 40],
339+
['B', 10, 30],
340+
['B', 30, 40],
341+
['B', 40, 80],
342+
['B', 80, 90]], columns=['stock', 'low', 'high'])
343+
344+
f = getattr(df.groupby('stock'), func)
345+
if window_size:
346+
window = f(window_size)
347+
else:
348+
window = f()
349+
350+
index = pd.MultiIndex.from_tuples([
351+
('A', 0), ('A', 1), ('A', 2),
352+
('B', 3), ('B', 4), ('B', 5), ('B', 6)], names=['stock', None])
353+
columns = pd.MultiIndex.from_tuples([
354+
('low', 'mean'), ('low', 'max'), ('high', 'mean'),
355+
('high', 'min')])
356+
expected = pd.DataFrame(expected_vals, index=index, columns=columns)
357+
358+
result = window.agg(OrderedDict((
359+
('low', ['mean', 'max']),
360+
('high', ['mean', 'min']),
361+
)))
362+
363+
tm.assert_frame_equal(result, expected)
364+
317365

318366
@pytest.mark.filterwarnings("ignore:can't resolve package:ImportWarning")
319367
class TestWindow(Base):

0 commit comments

Comments
 (0)