Skip to content

Commit c95b9e2

Browse files
authored
CLN: simplify groupby wrapping (#51259)
* CLN: simplify groupby wrapping
* mypy fixup
1 parent 9bbb1c0 commit c95b9e2

File tree

2 files changed

+67
-65
lines changed

2 files changed

+67
-65
lines changed

pandas/core/groupby/generic.py

+19-22
Original file line numberDiff line numberDiff line change
@@ -219,16 +219,9 @@ def apply(self, func, *args, **kwargs) -> Series:
219219
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
220220

221221
if maybe_use_numba(engine):
222-
data = self._obj_with_exclusions
223-
result = self._aggregate_with_numba(
224-
data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
222+
return self._aggregate_with_numba(
223+
func, *args, engine_kwargs=engine_kwargs, **kwargs
225224
)
226-
index = self.grouper.result_index
227-
result = self.obj._constructor(result.ravel(), index=index, name=data.name)
228-
if not self.as_index:
229-
result = self._insert_inaxis_grouper(result)
230-
result.index = default_index(len(result))
231-
return result
232225

233226
relabeling = func is None
234227
columns = None
@@ -1261,16 +1254,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
12611254
def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
12621255

12631256
if maybe_use_numba(engine):
1264-
data = self._obj_with_exclusions
1265-
result = self._aggregate_with_numba(
1266-
data, func, *args, engine_kwargs=engine_kwargs, **kwargs
1257+
return self._aggregate_with_numba(
1258+
func, *args, engine_kwargs=engine_kwargs, **kwargs
12671259
)
1268-
index = self.grouper.result_index
1269-
result = self.obj._constructor(result, index=index, columns=data.columns)
1270-
if not self.as_index:
1271-
result = self._insert_inaxis_grouper(result)
1272-
result.index = default_index(len(result))
1273-
return result
12741260

12751261
relabeling, func, columns, order = reconstruct_func(func, **kwargs)
12761262
func = maybe_mangle_lambdas(func)
@@ -1283,7 +1269,12 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
12831269
# this should be the only (non-raising) case with relabeling
12841270
# used reordered index of columns
12851271
result = result.iloc[:, order]
1286-
result.columns = columns
1272+
result = cast(DataFrame, result)
1273+
# error: Incompatible types in assignment (expression has type
1274+
# "Optional[List[str]]", variable has type
1275+
# "Union[Union[Union[ExtensionArray, ndarray[Any, Any]],
1276+
# Index, Series], Sequence[Any]]")
1277+
result.columns = columns # type: ignore[assignment]
12871278

12881279
if result is None:
12891280

@@ -1312,11 +1303,18 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
13121303
except ValueError as err:
13131304
if "No objects to concatenate" not in str(err):
13141305
raise
1306+
# _aggregate_frame can fail with e.g. func=Series.mode,
1307+
# where it expects 1D values but would be getting 2D values
1308+
# In other tests, using aggregate_frame instead of GroupByApply
1309+
# would give correct values but incorrect dtypes
1310+
# object vs float64 in test_cython_agg_empty_buckets
1311+
# float64 vs int64 in test_category_order_apply
13151312
result = self._aggregate_frame(func)
13161313

13171314
else:
13181315
# GH#32040, GH#35246
13191316
# e.g. test_groupby_as_index_select_column_sum_empty_df
1317+
result = cast(DataFrame, result)
13201318
result.columns = self._obj_with_exclusions.columns.copy()
13211319

13221320
if not self.as_index:
@@ -1502,8 +1500,7 @@ def arr_func(bvalues: ArrayLike) -> ArrayLike:
15021500
res_mgr.set_axis(1, mgr.axes[1])
15031501

15041502
res_df = self.obj._constructor(res_mgr)
1505-
if self.axis == 1:
1506-
res_df = res_df.T
1503+
res_df = self._maybe_transpose_result(res_df)
15071504
return res_df
15081505

15091506
def _transform_general(self, func, *args, **kwargs):
@@ -1830,7 +1827,7 @@ def _iterate_column_groupbys(self, obj: DataFrame | Series):
18301827
observed=self.observed,
18311828
)
18321829

1833-
def _apply_to_column_groupbys(self, func, obj: DataFrame | Series) -> DataFrame:
1830+
def _apply_to_column_groupbys(self, func, obj: DataFrame) -> DataFrame:
18341831
from pandas.core.reshape.concat import concat
18351832

18361833
columns = obj.columns

pandas/core/groupby/groupby.py

+48-43
Original file line numberDiff line numberDiff line change
@@ -1008,17 +1008,8 @@ def _concat_objects(
10081008
):
10091009
from pandas.core.reshape.concat import concat
10101010

1011-
def reset_identity(values):
1012-
# reset the identities of the components
1013-
# of the values to prevent aliasing
1014-
for v in com.not_none(*values):
1015-
ax = v._get_axis(self.axis)
1016-
ax._reset_identity()
1017-
return values
1018-
10191011
if self.group_keys and not is_transform:
10201012

1021-
values = reset_identity(values)
10221013
if self.as_index:
10231014

10241015
# possible MI return case
@@ -1063,7 +1054,6 @@ def reset_identity(values):
10631054
result = result.reindex(ax, axis=self.axis, copy=False)
10641055

10651056
else:
1066-
values = reset_identity(values)
10671057
result = concat(values, axis=self.axis)
10681058

10691059
name = self.obj.name if self.obj.ndim == 1 else self._selection
@@ -1123,6 +1113,17 @@ def _indexed_output_to_ndframe(
11231113
) -> Series | DataFrame:
11241114
raise AbstractMethodError(self)
11251115

1116+
@final
1117+
def _maybe_transpose_result(self, result: NDFrameT) -> NDFrameT:
1118+
if self.axis == 1:
1119+
# Only relevant for DataFrameGroupBy, no-op for SeriesGroupBy
1120+
result = result.T
1121+
if result.index.equals(self.obj.index):
1122+
# Retain e.g. DatetimeIndex/TimedeltaIndex freq
1123+
# e.g. test_groupby_crash_on_nunique
1124+
result.index = self.obj.index.copy()
1125+
return result
1126+
11261127
@final
11271128
def _wrap_aggregated_output(
11281129
self,
@@ -1160,15 +1161,10 @@ def _wrap_aggregated_output(
11601161

11611162
result.index = index
11621163

1163-
if self.axis == 1:
1164-
# Only relevant for DataFrameGroupBy, no-op for SeriesGroupBy
1165-
result = result.T
1166-
if result.index.equals(self.obj.index):
1167-
# Retain e.g. DatetimeIndex/TimedeltaIndex freq
1168-
result.index = self.obj.index.copy()
1169-
# TODO: Do this more systematically
1170-
1171-
return self._reindex_output(result, qs=qs)
1164+
# error: Argument 1 to "_maybe_transpose_result" of "GroupBy" has
1165+
# incompatible type "Union[Series, DataFrame]"; expected "NDFrameT"
1166+
res = self._maybe_transpose_result(result) # type: ignore[arg-type]
1167+
return self._reindex_output(res, qs=qs)
11721168

11731169
def _wrap_applied_output(
11741170
self,
@@ -1242,17 +1238,18 @@ def _numba_agg_general(
12421238
return data._constructor(result, index=index, **result_kwargs)
12431239

12441240
@final
1245-
def _transform_with_numba(
1246-
self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
1247-
):
1241+
def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
12481242
"""
12491243
Perform groupby transform routine with the numba engine.
12501244
12511245
This routine mimics the data splitting routine of the DataSplitter class
12521246
to generate the indices of each group in the sorted data and then passes the
12531247
data and indices into a Numba jitted function.
12541248
"""
1255-
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
1249+
data = self._obj_with_exclusions
1250+
df = data if data.ndim == 2 else data.to_frame()
1251+
1252+
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
12561253
numba_.validate_udf(func)
12571254
numba_transform_func = numba_.generate_numba_transform_func(
12581255
func, **get_jit_arguments(engine_kwargs, kwargs)
@@ -1262,25 +1259,33 @@ def _transform_with_numba(
12621259
sorted_index,
12631260
starts,
12641261
ends,
1265-
len(data.columns),
1262+
len(df.columns),
12661263
*args,
12671264
)
12681265
# result values needs to be resorted to their original positions since we
12691266
# evaluated the data sorted by group
1270-
return result.take(np.argsort(sorted_index), axis=0)
1267+
result = result.take(np.argsort(sorted_index), axis=0)
1268+
index = data.index
1269+
if data.ndim == 1:
1270+
result_kwargs = {"name": data.name}
1271+
result = result.ravel()
1272+
else:
1273+
result_kwargs = {"columns": data.columns}
1274+
return data._constructor(result, index=index, **result_kwargs)
12711275

12721276
@final
1273-
def _aggregate_with_numba(
1274-
self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
1275-
):
1277+
def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
12761278
"""
12771279
Perform groupby aggregation routine with the numba engine.
12781280
12791281
This routine mimics the data splitting routine of the DataSplitter class
12801282
to generate the indices of each group in the sorted data and then passes the
12811283
data and indices into a Numba jitted function.
12821284
"""
1283-
starts, ends, sorted_index, sorted_data = self._numba_prep(data)
1285+
data = self._obj_with_exclusions
1286+
df = data if data.ndim == 2 else data.to_frame()
1287+
1288+
starts, ends, sorted_index, sorted_data = self._numba_prep(df)
12841289
numba_.validate_udf(func)
12851290
numba_agg_func = numba_.generate_numba_agg_func(
12861291
func, **get_jit_arguments(engine_kwargs, kwargs)
@@ -1290,10 +1295,20 @@ def _aggregate_with_numba(
12901295
sorted_index,
12911296
starts,
12921297
ends,
1293-
len(data.columns),
1298+
len(df.columns),
12941299
*args,
12951300
)
1296-
return result
1301+
index = self.grouper.result_index
1302+
if data.ndim == 1:
1303+
result_kwargs = {"name": data.name}
1304+
result = result.ravel()
1305+
else:
1306+
result_kwargs = {"columns": data.columns}
1307+
res = data._constructor(result, index=index, **result_kwargs)
1308+
if not self.as_index:
1309+
res = self._insert_inaxis_grouper(res)
1310+
res.index = default_index(len(res))
1311+
return res
12971312

12981313
# -----------------------------------------------------------------
12991314
# apply/agg/transform
@@ -1536,19 +1551,9 @@ def _cython_transform(
15361551
def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
15371552

15381553
if maybe_use_numba(engine):
1539-
data = self._obj_with_exclusions
1540-
df = data if data.ndim == 2 else data.to_frame()
1541-
result = self._transform_with_numba(
1542-
df, func, *args, engine_kwargs=engine_kwargs, **kwargs
1554+
return self._transform_with_numba(
1555+
func, *args, engine_kwargs=engine_kwargs, **kwargs
15431556
)
1544-
if self.obj.ndim == 2:
1545-
return cast(DataFrame, self.obj)._constructor(
1546-
result, index=data.index, columns=data.columns
1547-
)
1548-
else:
1549-
return cast(Series, self.obj)._constructor(
1550-
result.ravel(), index=data.index, name=data.name
1551-
)
15521557

15531558
# optimized transforms
15541559
func = com.get_cython_func(func) or func

0 commit comments

Comments (0)