Skip to content

Commit b7929b8

Browse files
authored
refactor core/groupby (#37583)
1 parent 7759d0f commit b7929b8

File tree

5 files changed

+35
-42
lines changed

5 files changed

+35
-42
lines changed

pandas/core/groupby/base.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,8 @@ def _gotitem(self, key, ndim, subset=None):
6363

6464
self = type(self)(subset, groupby=groupby, parent=self, **kwargs)
6565
self._reset_cache()
66-
if subset.ndim == 2:
67-
if is_scalar(key) and key in subset or is_list_like(key):
68-
self._selection = key
66+
if subset.ndim == 2 and (is_scalar(key) and key in subset or is_list_like(key)):
67+
self._selection = key
6968
return self
7069

7170

pandas/core/groupby/generic.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -512,12 +512,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
512512
elif func not in base.transform_kernel_allowlist:
513513
msg = f"'{func}' is not a valid function name for transform(name)"
514514
raise ValueError(msg)
515-
elif func in base.cythonized_kernels:
515+
elif func in base.cythonized_kernels or func in base.transformation_kernels:
516516
# cythonized transform or canned "agg+broadcast"
517517
return getattr(self, func)(*args, **kwargs)
518-
elif func in base.transformation_kernels:
519-
return getattr(self, func)(*args, **kwargs)
520-
521518
# If func is a reduction, we need to broadcast the
522519
# result to the whole group. Compute func result
523520
# and deal with possible broadcasting below.
@@ -1111,8 +1108,7 @@ def blk_func(bvalues: ArrayLike) -> ArrayLike:
11111108
# unwrap DataFrame to get array
11121109
result = result._mgr.blocks[0].values
11131110

1114-
res_values = cast_agg_result(result, bvalues, how)
1115-
return res_values
1111+
return cast_agg_result(result, bvalues, how)
11161112

11171113
# TypeError -> we may have an exception in trying to aggregate
11181114
# continue and exclude the block
@@ -1368,12 +1364,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
13681364
elif func not in base.transform_kernel_allowlist:
13691365
msg = f"'{func}' is not a valid function name for transform(name)"
13701366
raise ValueError(msg)
1371-
elif func in base.cythonized_kernels:
1367+
elif func in base.cythonized_kernels or func in base.transformation_kernels:
13721368
# cythonized transformation or canned "reduction+broadcast"
13731369
return getattr(self, func)(*args, **kwargs)
1374-
elif func in base.transformation_kernels:
1375-
return getattr(self, func)(*args, **kwargs)
1376-
13771370
# GH 30918
13781371
# Use _transform_fast only when we know func is an aggregation
13791372
if func in base.reduction_kernels:
@@ -1401,9 +1394,10 @@ def _transform_fast(self, result: DataFrame) -> DataFrame:
14011394
# by take operation
14021395
ids, _, ngroup = self.grouper.group_info
14031396
result = result.reindex(self.grouper.result_index, copy=False)
1404-
output = []
1405-
for i, _ in enumerate(result.columns):
1406-
output.append(algorithms.take_1d(result.iloc[:, i].values, ids))
1397+
output = [
1398+
algorithms.take_1d(result.iloc[:, i].values, ids)
1399+
for i, _ in enumerate(result.columns)
1400+
]
14071401

14081402
return self.obj._constructor._from_arrays(
14091403
output, columns=result.columns, index=obj.index
@@ -1462,7 +1456,7 @@ def _transform_item_by_item(self, obj: DataFrame, wrapper) -> DataFrame:
14621456
else:
14631457
inds.append(i)
14641458

1465-
if len(output) == 0:
1459+
if not output:
14661460
raise TypeError("Transform function invalid for data types")
14671461

14681462
columns = obj.columns

pandas/core/groupby/groupby.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -1001,7 +1001,7 @@ def _cython_transform(self, how: str, numeric_only: bool = True, **kwargs):
10011001
key = base.OutputKey(label=name, position=idx)
10021002
output[key] = result
10031003

1004-
if len(output) == 0:
1004+
if not output:
10051005
raise DataError("No numeric types to aggregate")
10061006

10071007
return self._wrap_transformed_output(output)
@@ -1084,7 +1084,7 @@ def _cython_agg_general(
10841084
output[key] = maybe_cast_result(result, obj, how=how)
10851085
idx += 1
10861086

1087-
if len(output) == 0:
1087+
if not output:
10881088
raise DataError("No numeric types to aggregate")
10891089

10901090
return self._wrap_aggregated_output(output, index=self.grouper.result_index)
@@ -1182,7 +1182,7 @@ def _python_agg_general(self, func, *args, **kwargs):
11821182
key = base.OutputKey(label=name, position=idx)
11831183
output[key] = maybe_cast_result(result, obj, numeric_only=True)
11841184

1185-
if len(output) == 0:
1185+
if not output:
11861186
return self._python_apply_general(f, self._selected_obj)
11871187

11881188
if self.grouper._filter_empty_groups:
@@ -2550,9 +2550,8 @@ def _get_cythonized_result(
25502550
"""
25512551
if result_is_index and aggregate:
25522552
raise ValueError("'result_is_index' and 'aggregate' cannot both be True!")
2553-
if post_processing:
2554-
if not callable(post_processing):
2555-
raise ValueError("'post_processing' must be a callable!")
2553+
if post_processing and not callable(post_processing):
2554+
raise ValueError("'post_processing' must be a callable!")
25562555
if pre_processing:
25572556
if not callable(pre_processing):
25582557
raise ValueError("'pre_processing' must be a callable!")
@@ -2631,7 +2630,7 @@ def _get_cythonized_result(
26312630
output[key] = result
26322631

26332632
# error_msg is "" on an frame/series with no rows or columns
2634-
if len(output) == 0 and error_msg != "":
2633+
if not output and error_msg != "":
26352634
raise TypeError(error_msg)
26362635

26372636
if aggregate:

pandas/core/groupby/grouper.py

+18-16
Original file line numberDiff line numberDiff line change
@@ -593,23 +593,25 @@ def group_index(self) -> Index:
593593
return self._group_index
594594

595595
def _make_codes(self) -> None:
596-
if self._codes is None or self._group_index is None:
597-
# we have a list of groupers
598-
if isinstance(self.grouper, ops.BaseGrouper):
599-
codes = self.grouper.codes_info
600-
uniques = self.grouper.result_index
596+
if self._codes is not None and self._group_index is not None:
597+
return
598+
599+
# we have a list of groupers
600+
if isinstance(self.grouper, ops.BaseGrouper):
601+
codes = self.grouper.codes_info
602+
uniques = self.grouper.result_index
603+
else:
604+
# GH35667, replace dropna=False with na_sentinel=None
605+
if not self.dropna:
606+
na_sentinel = None
601607
else:
602-
# GH35667, replace dropna=False with na_sentinel=None
603-
if not self.dropna:
604-
na_sentinel = None
605-
else:
606-
na_sentinel = -1
607-
codes, uniques = algorithms.factorize(
608-
self.grouper, sort=self.sort, na_sentinel=na_sentinel
609-
)
610-
uniques = Index(uniques, name=self.name)
611-
self._codes = codes
612-
self._group_index = uniques
608+
na_sentinel = -1
609+
codes, uniques = algorithms.factorize(
610+
self.grouper, sort=self.sort, na_sentinel=na_sentinel
611+
)
612+
uniques = Index(uniques, name=self.name)
613+
self._codes = codes
614+
self._group_index = uniques
613615

614616
@cache_readonly
615617
def groups(self) -> Dict[Hashable, np.ndarray]:

pandas/core/groupby/ops.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -322,10 +322,9 @@ def result_index(self) -> Index:
322322

323323
codes = self.reconstructed_codes
324324
levels = [ping.result_index for ping in self.groupings]
325-
result = MultiIndex(
325+
return MultiIndex(
326326
levels=levels, codes=codes, verify_integrity=False, names=self.names
327327
)
328-
return result
329328

330329
def get_group_levels(self) -> List[Index]:
331330
if not self.compressed and len(self.groupings) == 1:

0 commit comments

Comments
 (0)