Skip to content

Commit 7d73641

Browse files
authored
REF: use BlockManager.apply for DataFrameGroupBy.count (#35924)
1 parent 9dc1fd7 commit 7d73641

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

pandas/core/groupby/generic.py

+13 -9
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,7 @@
7777
from pandas.core.groupby.numba_ import generate_numba_func, split_for_numba
7878
from pandas.core.indexes.api import Index, MultiIndex, all_indexes_same
7979
import pandas.core.indexes.base as ibase
80-
from pandas.core.internals import BlockManager, make_block
80+
from pandas.core.internals import BlockManager
8181
from pandas.core.series import Series
8282
from pandas.core.util.numba_ import NUMBA_FUNC_CACHE, maybe_use_numba
8383

@@ -1750,20 +1750,24 @@ def count(self) -> DataFrame:
17501750
ids, _, ngroups = self.grouper.group_info
17511751
mask = ids != -1
17521752

1753-
# TODO(2DEA): reshape would not be necessary with 2D EAs
1754-
vals = ((mask & ~isna(blk.values).reshape(blk.shape)) for blk in data.blocks)
1755-
locs = (blk.mgr_locs for blk in data.blocks)
1753+
def hfunc(bvalues: ArrayLike) -> ArrayLike:
1754+
# TODO(2DEA): reshape would not be necessary with 2D EAs
1755+
if bvalues.ndim == 1:
1756+
# EA
1757+
masked = mask & ~isna(bvalues).reshape(1, -1)
1758+
else:
1759+
masked = mask & ~isna(bvalues)
17561760

1757-
counted = (
1758-
lib.count_level_2d(x, labels=ids, max_bin=ngroups, axis=1) for x in vals
1759-
)
1760-
blocks = [make_block(val, placement=loc) for val, loc in zip(counted, locs)]
1761+
counted = lib.count_level_2d(masked, labels=ids, max_bin=ngroups, axis=1)
1762+
return counted
1763+
1764+
new_mgr = data.apply(hfunc)
17611765

17621766
# If we are grouping on categoricals we want unobserved categories to
17631767
# return zero, rather than the default of NaN which the reindexing in
17641768
# _wrap_agged_blocks() returns. GH 35028
17651769
with com.temp_setattr(self, "observed", True):
1766-
result = self._wrap_agged_blocks(blocks, items=data.items)
1770+
result = self._wrap_agged_blocks(new_mgr.blocks, items=data.items)
17671771

17681772
return self._reindex_output(result, fill_value=0)
17691773

0 commit comments

Comments
 (0)