Skip to content

Commit f9c29c8

Browse files
authored
feat: support dict param for dataframe.agg() (#1772)
* feat: support dict param for dataframe.agg()
* fix lint
* add more tests
* fix lint
1 parent e5fe143 commit f9c29c8

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

bigframes/dataframe.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2925,9 +2925,23 @@ def nunique(self) -> bigframes.series.Series:
29252925
return bigframes.series.Series(block)
29262926

29272927
def agg(
2928-
self, func: str | typing.Sequence[str]
2928+
self,
2929+
func: str
2930+
| typing.Sequence[str]
2931+
| typing.Mapping[blocks.Label, typing.Sequence[str] | str],
29292932
) -> DataFrame | bigframes.series.Series:
2930-
if utils.is_list_like(func):
2933+
if utils.is_dict_like(func):
2934+
# Must check dict-like first because dictionaries are list-like
2935+
# according to Pandas.
2936+
agg_cols = []
2937+
for col_label, agg_func in func.items():
2938+
agg_cols.append(self[col_label].agg(agg_func))
2939+
2940+
from bigframes.core.reshape import api as reshape
2941+
2942+
return reshape.concat(agg_cols, axis=1)
2943+
2944+
elif utils.is_list_like(func):
29312945
aggregations = [agg_ops.lookup_agg_func(f) for f in func]
29322946

29332947
for dtype, agg in itertools.product(self.dtypes, aggregations):
@@ -2941,6 +2955,7 @@ def agg(
29412955
aggregations,
29422956
)
29432957
)
2958+
29442959
else:
29452960
return bigframes.series.Series(
29462961
self._block.aggregate_all_and_stack(

tests/system/small/test_dataframe.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -5652,3 +5652,29 @@ def test_astype_invalid_type_fail(scalars_dfs):
56525652

56535653
with pytest.raises(TypeError, match=r".*Share your usecase with.*"):
56545654
bf_df.astype(123)
5655+
5656+
5657+
def test_agg_with_dict(scalars_dfs):
5658+
bf_df, pd_df = scalars_dfs
5659+
agg_funcs = {
5660+
"int64_too": ["min", "max"],
5661+
"int64_col": ["min", "count"],
5662+
}
5663+
5664+
bf_result = bf_df.agg(agg_funcs).to_pandas()
5665+
pd_result = pd_df.agg(agg_funcs)
5666+
5667+
pd.testing.assert_frame_equal(
5668+
bf_result, pd_result, check_dtype=False, check_index_type=False
5669+
)
5670+
5671+
5672+
def test_agg_with_dict_containing_non_existing_col_raise_key_error(scalars_dfs):
5673+
bf_df, _ = scalars_dfs
5674+
agg_funcs = {
5675+
"int64_too": ["min", "max"],
5676+
"nonexisting_col": ["count"],
5677+
}
5678+
5679+
with pytest.raises(KeyError):
5680+
bf_df.agg(agg_funcs)

0 commit comments

Comments
 (0)