diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index e438db6c620ec..b59c7b4cb1cea 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -343,7 +343,7 @@ def get_group_levels(self): _cython_arity = {"ohlc": 4} # OHLC - _name_functions = {"ohlc": lambda *args: ["open", "high", "low", "close"]} + _name_functions = {"ohlc": ["open", "high", "low", "close"]} def _is_builtin_func(self, arg): """ @@ -399,6 +399,13 @@ def _cython_operation( assert kind in ["transform", "aggregate"] orig_values = values + if values.ndim > 2: + raise NotImplementedError("number of dimensions is currently limited to 2") + elif values.ndim == 2: + # Note: it is *not* the case that axis is always 0 for 1-dim values, + # as we can have 1D ExtensionArrays that we need to treat as 2D + assert axis == 1, axis + # can we do this operation with our cython functions # if not raise NotImplementedError @@ -524,10 +531,7 @@ def _cython_operation( if vdim == 1 and arity == 1: result = result[:, 0] - if how in self._name_functions: - names = self._name_functions[how]() # type: Optional[List[str]] - else: - names = None + names = self._name_functions.get(how, None) # type: Optional[List[str]] if swapped: result = result.swapaxes(0, axis) @@ -557,10 +561,7 @@ def _aggregate( is_datetimelike: bool, min_count: int = -1, ): - if values.ndim > 2: - # punting for now - raise NotImplementedError("number of dimensions is currently limited to 2") - elif agg_func is libgroupby.group_nth: + if agg_func is libgroupby.group_nth: # different signature from the others # TODO: should we be using min_count instead of hard-coding it? agg_func(result, counts, values, comp_ids, rank=1, min_count=-1) @@ -574,11 +575,7 @@ def _transform( ): comp_ids, _, ngroups = self.group_info - if values.ndim > 2: - # punting for now - raise NotImplementedError("number of dimensions is currently limited to 2") - else: - transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) + transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) return result