Skip to content

Commit b21a482

Browse files
jbrockmendeljreback
authored andcommitted
consolidate dim checks (#29536)
1 parent b2dc8bf commit b21a482

File tree

1 file changed

+11
-14
lines changed

1 file changed

+11
-14
lines changed

pandas/core/groupby/ops.py

+11-14
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def get_group_levels(self):
343343

344344
_cython_arity = {"ohlc": 4} # OHLC
345345

346-
_name_functions = {"ohlc": lambda *args: ["open", "high", "low", "close"]}
346+
_name_functions = {"ohlc": ["open", "high", "low", "close"]}
347347

348348
def _is_builtin_func(self, arg):
349349
"""
@@ -433,6 +433,13 @@ def _cython_operation(
433433
assert kind in ["transform", "aggregate"]
434434
orig_values = values
435435

436+
if values.ndim > 2:
437+
raise NotImplementedError("number of dimensions is currently limited to 2")
438+
elif values.ndim == 2:
439+
# Note: it is *not* the case that axis is always 0 for 1-dim values,
440+
# as we can have 1D ExtensionArrays that we need to treat as 2D
441+
assert axis == 1, axis
442+
436443
# can we do this operation with our cython functions
437444
# if not raise NotImplementedError
438445

@@ -545,10 +552,7 @@ def _cython_operation(
545552
if vdim == 1 and arity == 1:
546553
result = result[:, 0]
547554

548-
if how in self._name_functions:
549-
names = self._name_functions[how]() # type: Optional[List[str]]
550-
else:
551-
names = None
555+
names = self._name_functions.get(how, None) # type: Optional[List[str]]
552556

553557
if swapped:
554558
result = result.swapaxes(0, axis)
@@ -578,10 +582,7 @@ def _aggregate(
578582
is_datetimelike: bool,
579583
min_count: int = -1,
580584
):
581-
if values.ndim > 2:
582-
# punting for now
583-
raise NotImplementedError("number of dimensions is currently limited to 2")
584-
elif agg_func is libgroupby.group_nth:
585+
if agg_func is libgroupby.group_nth:
585586
# different signature from the others
586587
# TODO: should we be using min_count instead of hard-coding it?
587588
agg_func(result, counts, values, comp_ids, rank=1, min_count=-1)
@@ -595,11 +596,7 @@ def _transform(
595596
):
596597

597598
comp_ids, _, ngroups = self.group_info
598-
if values.ndim > 2:
599-
# punting for now
600-
raise NotImplementedError("number of dimensions is currently limited to 2")
601-
else:
602-
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
599+
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
603600

604601
return result
605602

0 commit comments

Comments
 (0)