Skip to content

Commit edbb516

Browse files
phofl authored and noatamir committed
CLN: Clean groupby ops from unreached code paths (pandas-dev#48698)
* CLN: Clean groupby ops from unreached code paths
* Refactor
1 parent e7302e0 commit edbb516

File tree

1 file changed

+23
-82
lines changed

1 file changed

+23
-82
lines changed

pandas/core/groupby/ops.py

+23-82
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,6 @@
7272
PeriodArray,
7373
TimedeltaArray,
7474
)
75-
from pandas.core.arrays.boolean import BooleanDtype
76-
from pandas.core.arrays.floating import FloatingDtype
77-
from pandas.core.arrays.integer import IntegerDtype
7875
from pandas.core.arrays.masked import (
7976
BaseMaskedArray,
8077
BaseMaskedDtype,
@@ -147,26 +144,6 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
147144
},
148145
}
149146

150-
# "group_any" and "group_all" are also support masks, but don't go
151-
# through WrappedCythonOp
152-
_MASKED_CYTHON_FUNCTIONS = {
153-
"cummin",
154-
"cummax",
155-
"min",
156-
"max",
157-
"last",
158-
"first",
159-
"rank",
160-
"sum",
161-
"ohlc",
162-
"cumprod",
163-
"cumsum",
164-
"prod",
165-
"mean",
166-
"var",
167-
"median",
168-
}
169-
170147
_cython_arity = {"ohlc": 4} # OHLC
171148

172149
# Note: we make this a classmethod and pass kind+how so that caching
@@ -220,8 +197,8 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
220197
"""
221198
how = self.how
222199

223-
if how in ["median"]:
224-
# these two only have float64 implementations
200+
if how == "median":
201+
# median only has a float64 implementation
225202
# We should only get here with is_numeric, as non-numeric cases
226203
# should raise in _get_cython_function
227204
values = ensure_float64(values)
@@ -293,7 +270,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
293270

294271
out_shape: Shape
295272
if how == "ohlc":
296-
out_shape = (ngroups, 4)
273+
out_shape = (ngroups, arity)
297274
elif arity > 1:
298275
raise NotImplementedError(
299276
"arity of more than 1 is not supported for the 'how' argument"
@@ -342,9 +319,6 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
342319
return np.dtype(np.float64)
343320
return dtype
344321

345-
def uses_mask(self) -> bool:
346-
return self.how in self._MASKED_CYTHON_FUNCTIONS
347-
348322
@final
349323
def _ea_wrap_cython_operation(
350324
self,
@@ -358,7 +332,7 @@ def _ea_wrap_cython_operation(
358332
If we have an ExtensionArray, unwrap, call _cython_operation, and
359333
re-wrap if appropriate.
360334
"""
361-
if isinstance(values, BaseMaskedArray) and self.uses_mask():
335+
if isinstance(values, BaseMaskedArray):
362336
return self._masked_ea_wrap_cython_operation(
363337
values,
364338
min_count=min_count,
@@ -367,7 +341,7 @@ def _ea_wrap_cython_operation(
367341
**kwargs,
368342
)
369343

370-
elif isinstance(values, Categorical) and self.uses_mask():
344+
elif isinstance(values, Categorical):
371345
assert self.how == "rank" # the only one implemented ATM
372346
assert values.ordered # checked earlier
373347
mask = values.isna()
@@ -398,7 +372,7 @@ def _ea_wrap_cython_operation(
398372
)
399373

400374
if self.how in self.cast_blocklist:
401-
# i.e. how in ["rank"], since other cast_blocklist methods dont go
375+
# i.e. how in ["rank"], since other cast_blocklist methods don't go
402376
# through cython_operation
403377
return res_values
404378

@@ -411,12 +385,6 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
411385
# All of the functions implemented here are ordinal, so we can
412386
# operate on the tz-naive equivalents
413387
npvalues = values._ndarray.view("M8[ns]")
414-
elif isinstance(values.dtype, (BooleanDtype, IntegerDtype)):
415-
# IntegerArray or BooleanArray
416-
npvalues = values.to_numpy("float64", na_value=np.nan)
417-
elif isinstance(values.dtype, FloatingDtype):
418-
# FloatingArray
419-
npvalues = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
420388
elif isinstance(values.dtype, StringDtype):
421389
# StringArray
422390
npvalues = values.to_numpy(object, na_value=np.nan)
@@ -440,12 +408,6 @@ def _reconstruct_ea_result(
440408
string_array_cls = dtype.construct_array_type()
441409
return string_array_cls._from_sequence(res_values, dtype=dtype)
442410

443-
elif isinstance(values.dtype, BaseMaskedDtype):
444-
new_dtype = self._get_result_dtype(values.dtype.numpy_dtype)
445-
dtype = BaseMaskedDtype.from_numpy_dtype(new_dtype)
446-
masked_array_cls = dtype.construct_array_type()
447-
return masked_array_cls._from_sequence(res_values, dtype=dtype)
448-
449411
elif isinstance(values, (DatetimeArray, TimedeltaArray, PeriodArray)):
450412
# In to_cython_values we took a view as M8[ns]
451413
assert res_values.dtype == "M8[ns]"
@@ -489,7 +451,8 @@ def _masked_ea_wrap_cython_operation(
489451
)
490452

491453
if self.how == "ohlc":
492-
result_mask = np.tile(result_mask, (4, 1)).T
454+
arity = self._cython_arity.get(self.how, 1)
455+
result_mask = np.tile(result_mask, (arity, 1)).T
493456

494457
# res_values should already have the correct dtype, we just need to
495458
# wrap in a MaskedArray
@@ -580,7 +543,7 @@ def _call_cython_op(
580543
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
581544
if self.kind == "aggregate":
582545
counts = np.zeros(ngroups, dtype=np.int64)
583-
if self.how in ["min", "max", "mean", "last", "first"]:
546+
if self.how in ["min", "max", "mean", "last", "first", "sum"]:
584547
func(
585548
out=result,
586549
counts=counts,
@@ -591,18 +554,6 @@ def _call_cython_op(
591554
result_mask=result_mask,
592555
is_datetimelike=is_datetimelike,
593556
)
594-
elif self.how in ["sum"]:
595-
# We support datetimelike
596-
func(
597-
out=result,
598-
counts=counts,
599-
values=values,
600-
labels=comp_ids,
601-
mask=mask,
602-
result_mask=result_mask,
603-
min_count=min_count,
604-
is_datetimelike=is_datetimelike,
605-
)
606557
elif self.how in ["var", "ohlc", "prod", "median"]:
607558
func(
608559
result,
@@ -615,31 +566,21 @@ def _call_cython_op(
615566
**kwargs,
616567
)
617568
else:
618-
func(result, counts, values, comp_ids, min_count)
569+
raise NotImplementedError(f"{self.how} is not implemented")
619570
else:
620571
# TODO: min_count
621-
if self.uses_mask():
622-
if self.how != "rank":
623-
# TODO: should rank take result_mask?
624-
kwargs["result_mask"] = result_mask
625-
func(
626-
out=result,
627-
values=values,
628-
labels=comp_ids,
629-
ngroups=ngroups,
630-
is_datetimelike=is_datetimelike,
631-
mask=mask,
632-
**kwargs,
633-
)
634-
else:
635-
func(
636-
out=result,
637-
values=values,
638-
labels=comp_ids,
639-
ngroups=ngroups,
640-
is_datetimelike=is_datetimelike,
641-
**kwargs,
642-
)
572+
if self.how != "rank":
573+
# TODO: should rank take result_mask?
574+
kwargs["result_mask"] = result_mask
575+
func(
576+
out=result,
577+
values=values,
578+
labels=comp_ids,
579+
ngroups=ngroups,
580+
is_datetimelike=is_datetimelike,
581+
mask=mask,
582+
**kwargs,
583+
)
643584

644585
if self.kind == "aggregate":
645586
# i.e. counts is defined. Locations where count<min_count
@@ -650,7 +591,7 @@ def _call_cython_op(
650591
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
651592
empty_groups = counts < cutoff
652593
if empty_groups.any():
653-
if result_mask is not None and self.uses_mask():
594+
if result_mask is not None:
654595
assert result_mask[empty_groups].all()
655596
else:
656597
# Note: this conversion could be lossy, see GH#40767

0 commit comments

Comments
 (0)