Skip to content

Commit 0cfb5e9

Browse files
authored
REF: simplify _cython_operation return (pandas-dev#38253)
1 parent 8ac84fa commit 0cfb5e9

File tree

3 files changed

+13
-19
lines changed

3 files changed

+13
-19
lines changed

pandas/core/groupby/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ def py_fallback(bvalues: ArrayLike) -> ArrayLike:
10871087
def blk_func(bvalues: ArrayLike) -> ArrayLike:
10881088

10891089
try:
1090-
result, _ = self.grouper._cython_operation(
1090+
result = self.grouper._cython_operation(
10911091
"aggregate", bvalues, how, axis=1, min_count=min_count
10921092
)
10931093
except NotImplementedError:

pandas/core/groupby/groupby.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -994,7 +994,7 @@ def _cython_transform(
994994
continue
995995

996996
try:
997-
result, _ = self.grouper._cython_operation(
997+
result = self.grouper._cython_operation(
998998
"transform", obj._values, how, axis, **kwargs
999999
)
10001000
except NotImplementedError:
@@ -1069,12 +1069,13 @@ def _cython_agg_general(
10691069
if numeric_only and not is_numeric:
10701070
continue
10711071

1072-
result, agg_names = self.grouper._cython_operation(
1072+
result = self.grouper._cython_operation(
10731073
"aggregate", obj._values, how, axis=0, min_count=min_count
10741074
)
10751075

1076-
if agg_names:
1076+
if how == "ohlc":
10771077
# e.g. ohlc
1078+
agg_names = ["open", "high", "low", "close"]
10781079
assert len(agg_names) == result.shape[1]
10791080
for result_column, result_name in zip(result.T, agg_names):
10801081
key = base.OutputKey(label=result_name, position=idx)

pandas/core/groupby/ops.py

+8-15
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,6 @@ def get_group_levels(self) -> List[Index]:
370370

371371
_cython_arity = {"ohlc": 4} # OHLC
372372

373-
_name_functions = {"ohlc": ["open", "high", "low", "close"]}
374-
375373
def _is_builtin_func(self, arg):
376374
"""
377375
if we define a builtin function for this argument, return it,
@@ -492,36 +490,33 @@ def _ea_wrap_cython_operation(
492490
# All of the functions implemented here are ordinal, so we can
493491
# operate on the tz-naive equivalents
494492
values = values.view("M8[ns]")
495-
res_values, names = self._cython_operation(
493+
res_values = self._cython_operation(
496494
kind, values, how, axis, min_count, **kwargs
497495
)
498496
if how in ["rank"]:
499497
# preserve float64 dtype
500-
return res_values, names
498+
return res_values
501499

502500
res_values = res_values.astype("i8", copy=False)
503501
result = type(orig_values)._simple_new(res_values, dtype=orig_values.dtype)
504-
return result, names
502+
return result
505503

506504
elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
507505
# IntegerArray or BooleanArray
508506
values = ensure_int_or_float(values)
509-
res_values, names = self._cython_operation(
507+
res_values = self._cython_operation(
510508
kind, values, how, axis, min_count, **kwargs
511509
)
512510
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
513-
return result, names
511+
return result
514512

515513
raise NotImplementedError(values.dtype)
516514

517515
def _cython_operation(
518516
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
519-
) -> Tuple[np.ndarray, Optional[List[str]]]:
517+
) -> np.ndarray:
520518
"""
521-
Returns the values of a cython operation as a Tuple of [data, names].
522-
523-
Names is only useful when dealing with 2D results, like ohlc
524-
(see self._name_functions).
519+
Returns the values of a cython operation.
525520
"""
526521
orig_values = values
527522
assert kind in ["transform", "aggregate"]
@@ -619,8 +614,6 @@ def _cython_operation(
619614
if vdim == 1 and arity == 1:
620615
result = result[:, 0]
621616

622-
names: Optional[List[str]] = self._name_functions.get(how, None)
623-
624617
if swapped:
625618
result = result.swapaxes(0, axis)
626619

@@ -630,7 +623,7 @@ def _cython_operation(
630623
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
631624
result = maybe_downcast_to_dtype(result, dtype)
632625

633-
return result, names
626+
return result
634627

635628
def _aggregate(
636629
self, result, counts, values, comp_ids, agg_func, min_count: int = -1

0 commit comments

Comments
 (0)