
CLN: simplify groupby wrapping #51259

Merged 2 commits on Feb 9, 2023
41 changes: 19 additions & 22 deletions pandas/core/groupby/generic.py
@@ -219,16 +219,9 @@ def apply(self, func, *args, **kwargs) -> Series:
     def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
 
         if maybe_use_numba(engine):
-            data = self._obj_with_exclusions
-            result = self._aggregate_with_numba(
-                data.to_frame(), func, *args, engine_kwargs=engine_kwargs, **kwargs
+            return self._aggregate_with_numba(
+                func, *args, engine_kwargs=engine_kwargs, **kwargs
             )
-            index = self.grouper.result_index
-            result = self.obj._constructor(result.ravel(), index=index, name=data.name)
-            if not self.as_index:
-                result = self._insert_inaxis_grouper(result)
-                result.index = default_index(len(result))
-            return result
 
         relabeling = func is None
         columns = None
@@ -1261,16 +1254,9 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
     def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs):
 
         if maybe_use_numba(engine):
-            data = self._obj_with_exclusions
-            result = self._aggregate_with_numba(
-                data, func, *args, engine_kwargs=engine_kwargs, **kwargs
+            return self._aggregate_with_numba(
+                func, *args, engine_kwargs=engine_kwargs, **kwargs
             )
-            index = self.grouper.result_index
-            result = self.obj._constructor(result, index=index, columns=data.columns)
-            if not self.as_index:
-                result = self._insert_inaxis_grouper(result)
-                result.index = default_index(len(result))
-            return result
 
         relabeling, func, columns, order = reconstruct_func(func, **kwargs)
         func = maybe_mangle_lambdas(func)
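
For context, this is the user-facing path these two hunks simplify: with engine="numba", both SeriesGroupBy.aggregate and DataFrameGroupBy.aggregate now just delegate and return, since _aggregate_with_numba (see the groupby.py hunks below) does all the result wrapping itself. A minimal usage sketch, assuming numba is installed (numba-engine UDFs take (values, index)):

import numpy as np
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "x": [1.0, 2.0, 3.0]})

def total(values, index):
    # numba-engine UDFs receive each group's values and index as arrays
    return np.sum(values)

# Both calls hit the simplified branch above and get back fully wrapped results
ser_res = df.groupby("key")["x"].agg(total, engine="numba")   # Series
frame_res = df.groupby("key").agg(total, engine="numba")      # DataFrame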
@@ -1283,7 +1269,12 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
                 # this should be the only (non-raising) case with relabeling
                 # used reordered index of columns
                 result = result.iloc[:, order]
-                result.columns = columns
+                result = cast(DataFrame, result)
+                # error: Incompatible types in assignment (expression has type
+                # "Optional[List[str]]", variable has type
+                # "Union[Union[Union[ExtensionArray, ndarray[Any, Any]],
+                # Index, Series], Sequence[Any]]")
+                result.columns = columns  # type: ignore[assignment]
 
         if result is None:
@@ -1312,11 +1303,18 @@ def aggregate(self, func=None, *args, engine=None, engine_kwargs=None, **kwargs)
             except ValueError as err:
                 if "No objects to concatenate" not in str(err):
                     raise
+                # _aggregate_frame can fail with e.g. func=Series.mode,
+                # where it expects 1D values but would be getting 2D values
+                # In other tests, using aggregate_frame instead of GroupByApply
+                # would give correct values but incorrect dtypes
+                #  object vs float64 in test_cython_agg_empty_buckets
+                #  float64 vs int64 in test_category_order_apply
                 result = self._aggregate_frame(func)
 
             else:
+                # GH#32040, GH#35246
+                # e.g. test_groupby_as_index_select_column_sum_empty_df
                 result = cast(DataFrame, result)
                 result.columns = self._obj_with_exclusions.columns.copy()
 
         if not self.as_index:
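
The new comments document the fallback subtleties; the else branch covers the empty-concat case (GH#32040, GH#35246). A minimal sketch of the behavior being preserved, assuming an empty frame:

import pandas as pd

# With no rows there is nothing to concatenate per column, but the
# aggregated (empty) result should still carry the non-grouper columns.
df = pd.DataFrame(columns=["key", "x"])
out = df.groupby("key").sum()
print(list(out.columns))  # expected: ["x"]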
@@ -1502,8 +1500,7 @@ def arr_func(bvalues: ArrayLike) -> ArrayLike:
         res_mgr.set_axis(1, mgr.axes[1])
 
         res_df = self.obj._constructor(res_mgr)
-        if self.axis == 1:
-            res_df = res_df.T
+        res_df = self._maybe_transpose_result(res_df)
         return res_df
 
     def _transform_general(self, func, *args, **kwargs):
@@ -1830,7 +1827,7 @@ def _iterate_column_groupbys(self, obj: DataFrame | Series):
                 observed=self.observed,
             )
 
-    def _apply_to_column_groupbys(self, func, obj: DataFrame | Series) -> DataFrame:
+    def _apply_to_column_groupbys(self, func, obj: DataFrame) -> DataFrame:
         from pandas.core.reshape.concat import concat
 
         columns = obj.columns
91 changes: 48 additions & 43 deletions pandas/core/groupby/groupby.py
@@ -1008,17 +1008,8 @@ def _concat_objects(
     ):
         from pandas.core.reshape.concat import concat
 
-        def reset_identity(values):
-            # reset the identities of the components
-            # of the values to prevent aliasing
-            for v in com.not_none(*values):
-                ax = v._get_axis(self.axis)
-                ax._reset_identity()
-            return values
-
         if self.group_keys and not is_transform:
 
-            values = reset_identity(values)
             if self.as_index:
 
                 # possible MI return case
@@ -1063,7 +1054,6 @@ def reset_identity(values):
             result = result.reindex(ax, axis=self.axis, copy=False)
 
         else:
-            values = reset_identity(values)
             result = concat(values, axis=self.axis)
 
         name = self.obj.name if self.obj.ndim == 1 else self._selection
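
reset_identity (which reset index identity state to prevent aliasing between the concatenated pieces) is dropped at both call sites; the concatenation behavior around it is unchanged. A small sketch of the group_keys path through _concat_objects, for reference:

import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "x": [1, 2, 3]})

# group_keys=True with a non-transform apply: per-group results are
# concatenated with the group labels as the outer index level
out = df.groupby("key", group_keys=True)["x"].apply(lambda s: s.head(1))
# key
# a    0    1
# b    2    3
# Name: x, dtype: int64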
@@ -1123,6 +1113,17 @@ def _indexed_output_to_ndframe(
     ) -> Series | DataFrame:
         raise AbstractMethodError(self)
 
+    @final
+    def _maybe_transpose_result(self, result: NDFrameT) -> NDFrameT:
+        if self.axis == 1:
+            # Only relevant for DataFrameGroupBy, no-op for SeriesGroupBy
+            result = result.T
+            if result.index.equals(self.obj.index):
+                # Retain e.g. DatetimeIndex/TimedeltaIndex freq
+                # e.g. test_groupby_crash_on_nunique
+                result.index = self.obj.index.copy()
+        return result
+
     @final
     def _wrap_aggregated_output(
         self,
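
The new helper consolidates the axis=1 transpose previously inlined in _wrap_aggregated_output (and open-coded in generic.py's arr_func path, see above). A hedged sketch of the case it serves:

import pandas as pd

# Grouping over columns (axis=1): internally the result is built
# transposed and flipped back; when the flipped index round-trips to
# the original, it is replaced by a copy of self.obj.index so that
# e.g. a DatetimeIndex keeps its freq.
idx = pd.date_range("2023-01-01", periods=3, freq="D")
df = pd.DataFrame([[1, 2], [3, 4], [5, 6]], index=idx, columns=["a", "b"])
out = df.groupby([0, 0], axis=1).sum()  # groups both columns together
# out.index is the original DatetimeIndex, freq preserved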
@@ -1160,15 +1161,10 @@ def _wrap_aggregated_output(
 
             result.index = index
 
-        if self.axis == 1:
-            # Only relevant for DataFrameGroupBy, no-op for SeriesGroupBy
-            result = result.T
-            if result.index.equals(self.obj.index):
-                # Retain e.g. DatetimeIndex/TimedeltaIndex freq
-                result.index = self.obj.index.copy()
-        # TODO: Do this more systematically
-
-        return self._reindex_output(result, qs=qs)
+        # error: Argument 1 to "_maybe_transpose_result" of "GroupBy" has
+        # incompatible type "Union[Series, DataFrame]"; expected "NDFrameT"
+        res = self._maybe_transpose_result(result)  # type: ignore[arg-type]
+        return self._reindex_output(res, qs=qs)
 
     def _wrap_applied_output(
         self,
@@ -1242,17 +1238,18 @@ def _numba_agg_general(
         return data._constructor(result, index=index, **result_kwargs)
 
     @final
-    def _transform_with_numba(
-        self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
-    ):
+    def _transform_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
         """
         Perform groupby transform routine with the numba engine.
 
         This routine mimics the data splitting routine of the DataSplitter class
         to generate the indices of each group in the sorted data and then passes the
         data and indices into a Numba jitted function.
         """
-        starts, ends, sorted_index, sorted_data = self._numba_prep(data)
+        data = self._obj_with_exclusions
+        df = data if data.ndim == 2 else data.to_frame()
+
+        starts, ends, sorted_index, sorted_data = self._numba_prep(df)
         numba_.validate_udf(func)
         numba_transform_func = numba_.generate_numba_transform_func(
             func, **get_jit_arguments(engine_kwargs, kwargs)
@@ -1262,25 +1259,33 @@ def _transform_with_numba(
             sorted_index,
             starts,
             ends,
-            len(data.columns),
+            len(df.columns),
             *args,
         )
         # result values needs to be resorted to their original positions since we
         # evaluated the data sorted by group
-        return result.take(np.argsort(sorted_index), axis=0)
+        result = result.take(np.argsort(sorted_index), axis=0)
+        index = data.index
+        if data.ndim == 1:
+            result_kwargs = {"name": data.name}
+            result = result.ravel()
+        else:
+            result_kwargs = {"columns": data.columns}
+        return data._constructor(result, index=index, **result_kwargs)
 
     @final
-    def _aggregate_with_numba(
-        self, data: DataFrame, func, *args, engine_kwargs=None, **kwargs
-    ):
+    def _aggregate_with_numba(self, func, *args, engine_kwargs=None, **kwargs):
         """
         Perform groupby aggregation routine with the numba engine.
 
         This routine mimics the data splitting routine of the DataSplitter class
         to generate the indices of each group in the sorted data and then passes the
         data and indices into a Numba jitted function.
         """
-        starts, ends, sorted_index, sorted_data = self._numba_prep(data)
+        data = self._obj_with_exclusions
+        df = data if data.ndim == 2 else data.to_frame()
+
+        starts, ends, sorted_index, sorted_data = self._numba_prep(df)
         numba_.validate_udf(func)
         numba_agg_func = numba_.generate_numba_agg_func(
             func, **get_jit_arguments(engine_kwargs, kwargs)
Expand All @@ -1290,10 +1295,20 @@ def _aggregate_with_numba(
sorted_index,
starts,
ends,
len(data.columns),
len(df.columns),
*args,
)
return result
index = self.grouper.result_index
if data.ndim == 1:
result_kwargs = {"name": data.name}
result = result.ravel()
else:
result_kwargs = {"columns": data.columns}
res = data._constructor(result, index=index, **result_kwargs)
if not self.as_index:
res = self._insert_inaxis_grouper(res)
res.index = default_index(len(res))
return res

# -----------------------------------------------------------------
# apply/agg/transform
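
With the construction moved here, _transform_with_numba and _aggregate_with_numba take no data argument, pull self._obj_with_exclusions themselves, and branch on ndim, so callers get back a fully built Series or DataFrame instead of re-wrapping raw ndarrays. A usage sketch of both paths, assuming numba is installed:

import numpy as np
import pandas as pd

df = pd.DataFrame({"key": ["a", "a", "b"], "x": [1.0, 2.0, 3.0]})

def group_mean(values, index):
    return np.mean(values)

# Aggregation: the index comes from grouper.result_index; with
# as_index=False the grouper is re-inserted as a column and the index
# reset to a default RangeIndex -- all inside _aggregate_with_numba now.
print(df.groupby("key", as_index=False).agg(group_mean, engine="numba"))

def demean(values, index):
    return values - np.mean(values)

# Transform: the result is re-sorted to the original row order and
# wrapped with the caller's index/columns inside _transform_with_numba.
print(df.groupby("key").transform(demean, engine="numba"))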
@@ -1536,19 +1551,9 @@ def _cython_transform(
     def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
 
         if maybe_use_numba(engine):
-            data = self._obj_with_exclusions
-            df = data if data.ndim == 2 else data.to_frame()
-            result = self._transform_with_numba(
-                df, func, *args, engine_kwargs=engine_kwargs, **kwargs
+            return self._transform_with_numba(
+                func, *args, engine_kwargs=engine_kwargs, **kwargs
             )
-            if self.obj.ndim == 2:
-                return cast(DataFrame, self.obj)._constructor(
-                    result, index=data.index, columns=data.columns
-                )
-            else:
-                return cast(Series, self.obj)._constructor(
-                    result.ravel(), index=data.index, name=data.name
-                )
 
         # optimized transforms
         func = com.get_cython_func(func) or func