diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index 99b9aea4f82df..746e27c31f423 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -609,6 +609,23 @@ def _cython_operation( kind, values, how, axis, min_count, **kwargs ) + elif values.ndim == 1: + # expand to 2d, dispatch, then squeeze if appropriate + values2d = values[None, :] + res = self._cython_operation( + kind=kind, + values=values2d, + how=how, + axis=1, + min_count=min_count, + **kwargs, + ) + if res.shape[0] == 1: + return res[0] + + # otherwise we have OHLC + return res.T + is_datetimelike = needs_i8_conversion(dtype) if is_datetimelike: @@ -629,22 +646,20 @@ def _cython_operation( values = values.astype(object) arity = self._cython_arity.get(how, 1) + ngroups = self.ngroups - vdim = values.ndim - swapped = False - if vdim == 1: - values = values[:, None] - out_shape = (self.ngroups, arity) + assert axis == 1 + values = values.T + if how == "ohlc": + out_shape = (ngroups, 4) + elif arity > 1: + raise NotImplementedError( + "arity of more than 1 is not supported for the 'how' argument" + ) + elif kind == "transform": + out_shape = values.shape else: - if axis > 0: - swapped = True - assert axis == 1, axis - values = values.T - if arity > 1: - raise NotImplementedError( - "arity of more than 1 is not supported for the 'how' argument" - ) - out_shape = (self.ngroups,) + values.shape[1:] + out_shape = (ngroups,) + values.shape[1:] func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric) @@ -658,13 +673,11 @@ def _cython_operation( codes, _, _ = self.group_info + result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) if kind == "aggregate": - result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) counts = np.zeros(self.ngroups, dtype=np.int64) result = self._aggregate(result, counts, values, codes, func, min_count) elif kind == "transform": - result = maybe_fill(np.empty(values.shape, dtype=out_dtype)) - # TODO: min_count result = self._transform( result, values, codes, func, is_datetimelike, **kwargs @@ -680,11 +693,7 @@ def _cython_operation( assert result.ndim != 2 result = result[counts > 0] - if vdim == 1 and arity == 1: - result = result[:, 0] - - if swapped: - result = result.swapaxes(0, axis) + result = result.T if how not in base.cython_cast_blocklist: # e.g. if we are int64 and need to restore to datetime64/timedelta64