Skip to content

Commit 9764f3d

Browse files
authored
BUG: Fix agg ingore arg/kwargs when given list like func (#50863)
1 parent c4caed6 commit 9764f3d

File tree

7 files changed

+130
-6
lines changed

7 files changed

+130
-6
lines changed

doc/source/whatsnew/v2.0.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,7 @@ Groupby/resample/rolling
13451345
- Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.SeriesGroupBy.transform` would raise incorrectly when grouper had ``axis=1`` for ``"ngroup"`` argument (:issue:`45986`)
13461346
- Bug in :meth:`.DataFrameGroupBy.describe` produced incorrect results when data had duplicate columns (:issue:`50806`)
13471347
- Bug in :meth:`.DataFrameGroupBy.agg` with ``engine="numba"`` failing to respect ``as_index=False`` (:issue:`51228`)
1348+
- Bug in :meth:`DataFrameGroupBy.agg`, :meth:`SeriesGroupBy.agg`, and :meth:`Resampler.agg` would ignore arguments when passed a list of functions (:issue:`50863`)
13481349
-
13491350

13501351
Reshaping
@@ -1358,6 +1359,7 @@ Reshaping
13581359
- Clarified error message in :func:`merge` when passing invalid ``validate`` option (:issue:`49417`)
13591360
- Bug in :meth:`DataFrame.explode` raising ``ValueError`` on multiple columns with ``NaN`` values or empty lists (:issue:`46084`)
13601361
- Bug in :meth:`DataFrame.transpose` with ``IntervalDtype`` column with ``timedelta64[ns]`` endpoints (:issue:`44917`)
1362+
- Bug in :meth:`DataFrame.agg` and :meth:`Series.agg` would ignore arguments when passed a list of functions (:issue:`50863`)
13611363
-
13621364

13631365
Sparse

pandas/core/apply.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -332,19 +332,28 @@ def agg_list_like(self) -> DataFrame | Series:
332332

333333
for a in arg:
334334
colg = obj._gotitem(selected_obj.name, ndim=1, subset=selected_obj)
335-
new_res = colg.aggregate(a)
335+
if isinstance(colg, (ABCSeries, ABCDataFrame)):
336+
new_res = colg.aggregate(
337+
a, self.axis, *self.args, **self.kwargs
338+
)
339+
else:
340+
new_res = colg.aggregate(a, *self.args, **self.kwargs)
336341
results.append(new_res)
337342

338343
# make sure we find a good name
339344
name = com.get_callable_name(a) or a
340345
keys.append(name)
341346

342-
# multiples
343347
else:
344348
indices = []
345349
for index, col in enumerate(selected_obj):
346350
colg = obj._gotitem(col, ndim=1, subset=selected_obj.iloc[:, index])
347-
new_res = colg.aggregate(arg)
351+
if isinstance(colg, (ABCSeries, ABCDataFrame)):
352+
new_res = colg.aggregate(
353+
arg, self.axis, *self.args, **self.kwargs
354+
)
355+
else:
356+
new_res = colg.aggregate(arg, *self.args, **self.kwargs)
348357
results.append(new_res)
349358
indices.append(index)
350359
keys = selected_obj.columns.take(indices)

pandas/core/groupby/generic.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
236236
# Catch instances of lists / tuples
237237
# but not the class list / tuple itself.
238238
func = maybe_mangle_lambdas(func)
239-
ret = self._aggregate_multiple_funcs(func)
239+
ret = self._aggregate_multiple_funcs(func, *args, **kwargs)
240240
if relabeling:
241241
# columns is not narrowed by mypy from relabeling flag
242242
assert columns is not None # for mypy
@@ -268,7 +268,7 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
268268

269269
agg = aggregate
270270

271-
def _aggregate_multiple_funcs(self, arg) -> DataFrame:
271+
def _aggregate_multiple_funcs(self, arg, *args, **kwargs) -> DataFrame:
272272
if isinstance(arg, dict):
273273
if self.as_index:
274274
# GH 15931
@@ -293,7 +293,7 @@ def _aggregate_multiple_funcs(self, arg) -> DataFrame:
293293
for idx, (name, func) in enumerate(arg):
294294

295295
key = base.OutputKey(label=name, position=idx)
296-
results[key] = self.aggregate(func)
296+
results[key] = self.aggregate(func, *args, **kwargs)
297297

298298
if any(isinstance(x, DataFrame) for x in results.values()):
299299
from pandas import concat

pandas/tests/apply/test_frame_apply.py

+22
Original file line numberDiff line numberDiff line change
@@ -1623,3 +1623,25 @@ def test_any_apply_keyword_non_zero_axis_regression():
16231623

16241624
result = df.apply("any", 1)
16251625
tm.assert_series_equal(result, expected)
1626+
1627+
1628+
def test_agg_list_like_func_with_args():
1629+
# GH 50624
1630+
df = DataFrame({"x": [1, 2, 3]})
1631+
1632+
def foo1(x, a=1, c=0):
1633+
return x + a + c
1634+
1635+
def foo2(x, b=2, c=0):
1636+
return x + b + c
1637+
1638+
msg = r"foo1\(\) got an unexpected keyword argument 'b'"
1639+
with pytest.raises(TypeError, match=msg):
1640+
df.agg([foo1, foo2], 0, 3, b=3, c=4)
1641+
1642+
result = df.agg([foo1, foo2], 0, 3, c=4)
1643+
expected = DataFrame(
1644+
[[8, 8], [9, 9], [10, 10]],
1645+
columns=MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]),
1646+
)
1647+
tm.assert_frame_equal(result, expected)

pandas/tests/apply/test_series_apply.py

+20
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,26 @@ def f(x, a=0, b=0, c=0):
107107
tm.assert_series_equal(result, expected)
108108

109109

110+
def test_agg_list_like_func_with_args():
111+
# GH 50624
112+
113+
s = Series([1, 2, 3])
114+
115+
def foo1(x, a=1, c=0):
116+
return x + a + c
117+
118+
def foo2(x, b=2, c=0):
119+
return x + b + c
120+
121+
msg = r"foo1\(\) got an unexpected keyword argument 'b'"
122+
with pytest.raises(TypeError, match=msg):
123+
s.agg([foo1, foo2], 0, 3, b=3, c=4)
124+
125+
result = s.agg([foo1, foo2], 0, 3, c=4)
126+
expected = DataFrame({"foo1": [8, 9, 10], "foo2": [8, 9, 10]})
127+
tm.assert_frame_equal(result, expected)
128+
129+
110130
def test_series_map_box_timestamps():
111131
# GH#2689, GH#2627
112132
ser = Series(pd.date_range("1/1/2000", periods=10))

pandas/tests/groupby/aggregate/test_aggregate.py

+46
Original file line numberDiff line numberDiff line change
@@ -1470,3 +1470,49 @@ def test_agg_of_mode_list(test, constant):
14701470
expected = expected.set_index(0)
14711471

14721472
tm.assert_frame_equal(result, expected)
1473+
1474+
1475+
def test__dataframe_groupy_agg_list_like_func_with_args():
1476+
# GH 50624
1477+
df = DataFrame({"x": [1, 2, 3], "y": ["a", "b", "c"]})
1478+
gb = df.groupby("y")
1479+
1480+
def foo1(x, a=1, c=0):
1481+
return x.sum() + a + c
1482+
1483+
def foo2(x, b=2, c=0):
1484+
return x.sum() + b + c
1485+
1486+
msg = r"foo1\(\) got an unexpected keyword argument 'b'"
1487+
with pytest.raises(TypeError, match=msg):
1488+
gb.agg([foo1, foo2], 3, b=3, c=4)
1489+
1490+
result = gb.agg([foo1, foo2], 3, c=4)
1491+
expected = DataFrame(
1492+
[[8, 8], [9, 9], [10, 10]],
1493+
index=Index(["a", "b", "c"], name="y"),
1494+
columns=MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]),
1495+
)
1496+
tm.assert_frame_equal(result, expected)
1497+
1498+
1499+
def test__series_groupy_agg_list_like_func_with_args():
1500+
# GH 50624
1501+
s = Series([1, 2, 3])
1502+
sgb = s.groupby(s)
1503+
1504+
def foo1(x, a=1, c=0):
1505+
return x.sum() + a + c
1506+
1507+
def foo2(x, b=2, c=0):
1508+
return x.sum() + b + c
1509+
1510+
msg = r"foo1\(\) got an unexpected keyword argument 'b'"
1511+
with pytest.raises(TypeError, match=msg):
1512+
sgb.agg([foo1, foo2], 3, b=3, c=4)
1513+
1514+
result = sgb.agg([foo1, foo2], 3, c=4)
1515+
expected = DataFrame(
1516+
[[8, 8], [9, 9], [10, 10]], index=Index([1, 2, 3]), columns=["foo1", "foo2"]
1517+
)
1518+
tm.assert_frame_equal(result, expected)

pandas/tests/resample/test_resample_api.py

+25
Original file line numberDiff line numberDiff line change
@@ -633,6 +633,31 @@ def test_try_aggregate_non_existing_column():
633633
df.resample("30T").agg({"x": ["mean"], "y": ["median"], "z": ["sum"]})
634634

635635

636+
def test_agg_list_like_func_with_args():
637+
# 50624
638+
df = DataFrame(
639+
{"x": [1, 2, 3]}, index=date_range("2020-01-01", periods=3, freq="D")
640+
)
641+
642+
def foo1(x, a=1, c=0):
643+
return x + a + c
644+
645+
def foo2(x, b=2, c=0):
646+
return x + b + c
647+
648+
msg = r"foo1\(\) got an unexpected keyword argument 'b'"
649+
with pytest.raises(TypeError, match=msg):
650+
df.resample("D").agg([foo1, foo2], 3, b=3, c=4)
651+
652+
result = df.resample("D").agg([foo1, foo2], 3, c=4)
653+
expected = DataFrame(
654+
[[8, 8], [9, 9], [10, 10]],
655+
index=date_range("2020-01-01", periods=3, freq="D"),
656+
columns=pd.MultiIndex.from_tuples([("x", "foo1"), ("x", "foo2")]),
657+
)
658+
tm.assert_frame_equal(result, expected)
659+
660+
636661
def test_selection_api_validation():
637662
# GH 13500
638663
index = date_range(datetime(2005, 1, 1), datetime(2005, 1, 10), freq="D")

0 commit comments

Comments
 (0)