Commit 188ce73

REF: Avoid post-processing in blockwise op (#35356)
1 parent a9cb64a commit 188ce73

File tree

1 file changed: +47 −50 lines


pandas/core/groupby/generic.py

+47 −50

@@ -1029,11 +1029,36 @@ def _cython_agg_blocks(
         agg_blocks: List[Block] = []
         new_items: List[np.ndarray] = []
         deleted_items: List[np.ndarray] = []
-        # Some object-dtype blocks might be split into List[Block[T], Block[U]]
-        split_items: List[np.ndarray] = []
-        split_frames: List[DataFrame] = []

         no_result = object()
+
+        def cast_result_block(result, block: "Block", how: str) -> "Block":
+            # see if we can cast the block to the desired dtype
+            # this may not be the original dtype
+            assert not isinstance(result, DataFrame)
+            assert result is not no_result
+
+            dtype = maybe_cast_result_dtype(block.dtype, how)
+            result = maybe_downcast_numeric(result, dtype)
+
+            if block.is_extension and isinstance(result, np.ndarray):
+                # e.g. block.values was an IntegerArray
+                # (1, N) case can occur if block.values was Categorical
+                # and result is ndarray[object]
+                # TODO(EA2D): special casing not needed with 2D EAs
+                assert result.ndim == 1 or result.shape[0] == 1
+                try:
+                    # Cast back if feasible
+                    result = type(block.values)._from_sequence(
+                        result.ravel(), dtype=block.values.dtype
+                    )
+                except (ValueError, TypeError):
+                    # reshape to be valid for non-Extension Block
+                    result = result.reshape(1, -1)
+
+            agg_block: Block = block.make_block(result)
+            return agg_block
+
         for block in data.blocks:
             # Avoid inheriting result from earlier in the loop
             result = no_result
@@ -1065,9 +1090,9 @@ def _cython_agg_blocks(
                 # not try to add missing categories if grouping over multiple
                 # Categoricals. This will done by later self._reindex_output()
                 # Doing it here creates an error. See GH#34951
-                s = get_groupby(obj, self.grouper, observed=True)
+                sgb = get_groupby(obj, self.grouper, observed=True)
                 try:
-                    result = s.aggregate(lambda x: alt(x, axis=self.axis))
+                    result = sgb.aggregate(lambda x: alt(x, axis=self.axis))
                 except TypeError:
                     # we may have an exception in trying to aggregate
                     # continue and exclude the block
@@ -1081,54 +1106,26 @@ def _cython_agg_blocks(
                         # about a single block input returning a single block output
                         # is a lie. To keep the code-path for the typical non-split case
                         # clean, we choose to clean up this mess later on.
-                        split_items.append(locs)
-                        split_frames.append(result)
-                        continue
-
-                    assert len(result._mgr.blocks) == 1
-                    result = result._mgr.blocks[0].values
-                    if isinstance(result, np.ndarray) and result.ndim == 1:
-                        result = result.reshape(1, -1)
-
-            assert not isinstance(result, DataFrame)
-
-            if result is not no_result:
-                # see if we can cast the block to the desired dtype
-                # this may not be the original dtype
-                dtype = maybe_cast_result_dtype(block.dtype, how)
-                result = maybe_downcast_numeric(result, dtype)
-
-                if block.is_extension and isinstance(result, np.ndarray):
-                    # e.g. block.values was an IntegerArray
-                    # (1, N) case can occur if block.values was Categorical
-                    # and result is ndarray[object]
-                    # TODO(EA2D): special casing not needed with 2D EAs
-                    assert result.ndim == 1 or result.shape[0] == 1
-                    try:
-                        # Cast back if feasible
-                        result = type(block.values)._from_sequence(
-                            result.ravel(), dtype=block.values.dtype
-                        )
-                    except (ValueError, TypeError):
-                        # reshape to be valid for non-Extension Block
-                        result = result.reshape(1, -1)
-
-                agg_block: Block = block.make_block(result)
-
-                new_items.append(locs)
-                agg_blocks.append(agg_block)
+                        assert len(locs) == result.shape[1]
+                        for i, loc in enumerate(locs):
+                            new_items.append(np.array([loc], dtype=locs.dtype))
+                            agg_block = result.iloc[:, [i]]._mgr.blocks[0]
+                            agg_blocks.append(agg_block)
+                    else:
+                        result = result._mgr.blocks[0].values
+                        if isinstance(result, np.ndarray) and result.ndim == 1:
+                            result = result.reshape(1, -1)
+                        agg_block = cast_result_block(result, block, how)
+                        new_items.append(locs)
+                        agg_blocks.append(agg_block)
+            else:
+                agg_block = cast_result_block(result, block, how)
+                new_items.append(locs)
+                agg_blocks.append(agg_block)

-        if not (agg_blocks or split_frames):
+        if not agg_blocks:
             raise DataError("No numeric types to aggregate")

-        if split_items:
-            # Clean up the mess left over from split blocks.
-            for locs, result in zip(split_items, split_frames):
-                assert len(locs) == result.shape[1]
-                for i, loc in enumerate(locs):
-                    new_items.append(np.array([loc], dtype=locs.dtype))
-                    agg_blocks.append(result.iloc[:, [i]]._mgr.blocks[0])
-
         # reset the locs in the blocks to correspond to our
         # current ordering
         indexer = np.concatenate(new_items)
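
The net effect of the change, read from the diff: the dtype-casting logic moves into a local cast_result_block helper, and a DataFrame result that splits into several blocks is decomposed into single-column pieces inside the loop, so new_items and agg_blocks no longer need the post-loop cleanup pass over split_items/split_frames. Below is a minimal, self-contained sketch of that control-flow change only, using hypothetical names (agg_before, agg_after, plain NumPy tuples standing in for pandas Blocks); it is an illustration of the pattern, not the pandas implementation, and it omits the dtype casting entirely.

from typing import Callable, List, Tuple

import numpy as np

# A toy "block" is just (column locations, 2D values); in pandas it is a
# Block object inside the BlockManager.
ToyBlock = Tuple[np.ndarray, np.ndarray]


def agg_before(blocks: List[ToyBlock], agg: Callable[[np.ndarray], np.ndarray]):
    # Old shape of the loop: stash awkward multi-row results in side lists
    # and flatten them into new_items/agg_results after the loop.
    new_items: List[np.ndarray] = []
    agg_results: List[np.ndarray] = []
    split_items: List[np.ndarray] = []
    split_results: List[np.ndarray] = []
    for locs, values in blocks:
        result = np.atleast_2d(agg(values))
        if result.shape[0] != 1:
            # one input block produced several output pieces ("split" case)
            split_items.append(locs)
            split_results.append(result)
            continue
        new_items.append(locs)
        agg_results.append(result)
    # the post-processing pass that the commit removes
    for locs, result in zip(split_items, split_results):
        for i, loc in enumerate(locs):
            new_items.append(np.array([loc], dtype=locs.dtype))
            agg_results.append(result[[i]])
    return new_items, agg_results


def agg_after(blocks: List[ToyBlock], agg: Callable[[np.ndarray], np.ndarray]):
    # New shape of the loop: decompose a split result as soon as it is
    # produced, so no side lists and no cleanup pass are needed.
    new_items: List[np.ndarray] = []
    agg_results: List[np.ndarray] = []
    for locs, values in blocks:
        result = np.atleast_2d(agg(values))
        if result.shape[0] != 1:
            for i, loc in enumerate(locs):
                new_items.append(np.array([loc], dtype=locs.dtype))
                agg_results.append(result[[i]])
        else:
            new_items.append(locs)
            agg_results.append(result)
    return new_items, agg_results

Both versions return the same pieces, possibly in a different order across blocks; in the real code that ordering is reconciled afterwards anyway via the indexer built from new_items (indexer = np.concatenate(new_items)), which is why the single in-loop path can replace the post-processing step.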
