Skip to content

Commit 6d46116

Browse files
authored
Merge pull request #158 from pandas-dev/master
REF: _cython_operation handle values.ndim==1 case up-front (pandas-dev#40672)
2 parents 5bad13f + 3c3589b commit 6d46116

File tree

1 file changed

+31
-22
lines changed

1 file changed

+31
-22
lines changed

pandas/core/groupby/ops.py

+31-22
Original file line numberDiff line numberDiff line change
@@ -609,6 +609,23 @@ def _cython_operation(
609609
kind, values, how, axis, min_count, **kwargs
610610
)
611611

612+
elif values.ndim == 1:
613+
# expand to 2d, dispatch, then squeeze if appropriate
614+
values2d = values[None, :]
615+
res = self._cython_operation(
616+
kind=kind,
617+
values=values2d,
618+
how=how,
619+
axis=1,
620+
min_count=min_count,
621+
**kwargs,
622+
)
623+
if res.shape[0] == 1:
624+
return res[0]
625+
626+
# otherwise we have OHLC
627+
return res.T
628+
612629
is_datetimelike = needs_i8_conversion(dtype)
613630

614631
if is_datetimelike:
@@ -629,22 +646,20 @@ def _cython_operation(
629646
values = values.astype(object)
630647

631648
arity = self._cython_arity.get(how, 1)
649+
ngroups = self.ngroups
632650

633-
vdim = values.ndim
634-
swapped = False
635-
if vdim == 1:
636-
values = values[:, None]
637-
out_shape = (self.ngroups, arity)
651+
assert axis == 1
652+
values = values.T
653+
if how == "ohlc":
654+
out_shape = (ngroups, 4)
655+
elif arity > 1:
656+
raise NotImplementedError(
657+
"arity of more than 1 is not supported for the 'how' argument"
658+
)
659+
elif kind == "transform":
660+
out_shape = values.shape
638661
else:
639-
if axis > 0:
640-
swapped = True
641-
assert axis == 1, axis
642-
values = values.T
643-
if arity > 1:
644-
raise NotImplementedError(
645-
"arity of more than 1 is not supported for the 'how' argument"
646-
)
647-
out_shape = (self.ngroups,) + values.shape[1:]
662+
out_shape = (ngroups,) + values.shape[1:]
648663

649664
func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric)
650665

@@ -658,13 +673,11 @@ def _cython_operation(
658673

659674
codes, _, _ = self.group_info
660675

676+
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
661677
if kind == "aggregate":
662-
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
663678
counts = np.zeros(self.ngroups, dtype=np.int64)
664679
result = self._aggregate(result, counts, values, codes, func, min_count)
665680
elif kind == "transform":
666-
result = maybe_fill(np.empty(values.shape, dtype=out_dtype))
667-
668681
# TODO: min_count
669682
result = self._transform(
670683
result, values, codes, func, is_datetimelike, **kwargs
@@ -680,11 +693,7 @@ def _cython_operation(
680693
assert result.ndim != 2
681694
result = result[counts > 0]
682695

683-
if vdim == 1 and arity == 1:
684-
result = result[:, 0]
685-
686-
if swapped:
687-
result = result.swapaxes(0, axis)
696+
result = result.T
688697

689698
if how not in base.cython_cast_blocklist:
690699
# e.g. if we are int64 and need to restore to datetime64/timedelta64

0 commit comments

Comments
 (0)