Skip to content

Commit c2cf93f

Browse files
authored
REF: simplify groupby.ops (#46196)
1 parent 4985b89 commit c2cf93f

File tree

1 file changed

+45
-36
lines changed

1 file changed

+45
-36
lines changed

pandas/core/groupby/ops.py

+45 -36
Original file line number | Diff line number | Diff line change
@@ -123,14 +123,14 @@ def __init__(self, kind: str, how: str):
123123
"min": "group_min",
124124
"max": "group_max",
125125
"mean": "group_mean",
126-
"median": "group_median",
126+
"median": "group_median_float64",
127127
"var": "group_var",
128128
"first": "group_nth",
129129
"last": "group_last",
130130
"ohlc": "group_ohlc",
131131
},
132132
"transform": {
133-
"cumprod": "group_cumprod",
133+
"cumprod": "group_cumprod_float64",
134134
"cumsum": "group_cumsum",
135135
"cummin": "group_cummin",
136136
"cummax": "group_cummax",
@@ -161,52 +161,54 @@ def _get_cython_function(
161161
if is_numeric:
162162
return f
163163
elif dtype == object:
164-
if "object" not in f.__signatures__:
164+
if how in ["median", "cumprod"]:
165+
# no fused types -> no __signatures__
166+
raise NotImplementedError(
167+
f"function is not implemented for this dtype: "
168+
f"[how->{how},dtype->{dtype_str}]"
169+
)
170+
elif "object" not in f.__signatures__:
165171
# raise NotImplementedError here rather than TypeError later
166172
raise NotImplementedError(
167173
f"function is not implemented for this dtype: "
168174
f"[how->{how},dtype->{dtype_str}]"
169175
)
170176
return f
177+
else:
178+
raise NotImplementedError(
179+
"This should not be reached. Please report a bug at "
180+
"github.com/pandas-dev/pandas/",
181+
dtype,
182+
)
171183

172-
def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
184+
def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
173185
"""
174-
Find the appropriate cython function, casting if necessary.
186+
Cast numeric dtypes to float64 for functions that only support that.
175187
176188
Parameters
177189
----------
178190
values : np.ndarray
179-
is_numeric : bool
180191
181192
Returns
182193
-------
183-
func : callable
184194
values : np.ndarray
185195
"""
186196
how = self.how
187-
kind = self.kind
188197

189198
if how in ["median", "cumprod"]:
190199
# these two only have float64 implementations
191-
if is_numeric:
192-
values = ensure_float64(values)
193-
else:
194-
raise NotImplementedError(
195-
f"function is not implemented for this dtype: "
196-
f"[how->{how},dtype->{values.dtype.name}]"
197-
)
198-
func = getattr(libgroupby, f"group_{how}_float64")
199-
return func, values
200-
201-
func = self._get_cython_function(kind, how, values.dtype, is_numeric)
200+
# We should only get here with is_numeric, as non-numeric cases
201+
# should raise in _get_cython_function
202+
values = ensure_float64(values)
202203

203-
if values.dtype.kind in ["i", "u"]:
204+
elif values.dtype.kind in ["i", "u"]:
204205
if how in ["add", "var", "prod", "mean", "ohlc"]:
205206
# result may still include NaN, so we have to cast
206207
values = ensure_float64(values)
207208

208-
return func, values
209+
return values
209210

211+
# TODO: general case implementation overridable by EAs.
210212
def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
211213
"""
212214
Check if we can do this operation with our cython functions.
@@ -235,6 +237,7 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
235237
# are not setup for dim transforming
236238
raise NotImplementedError(f"{dtype} dtype not supported")
237239
elif is_datetime64_any_dtype(dtype):
240+
# TODO: same for period_dtype? no for these methods with Period
238241
# we raise NotImplemented if this is an invalid operation
239242
# entirely, e.g. adding datetimes
240243
if how in ["add", "prod", "cumsum", "cumprod"]:
@@ -262,7 +265,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
262265
out_shape = (ngroups,) + values.shape[1:]
263266
return out_shape
264267

265-
def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
268+
def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
266269
how = self.how
267270

268271
if how == "rank":
@@ -282,6 +285,7 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
282285
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
283286
... # pragma: no cover
284287

288+
# TODO: general case implementation overridable by EAs.
285289
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
286290
"""
287291
Get the desired dtype of a result based on the
@@ -329,7 +333,6 @@ def _ea_wrap_cython_operation(
329333
If we have an ExtensionArray, unwrap, call _cython_operation, and
330334
re-wrap if appropriate.
331335
"""
332-
# TODO: general case implementation overridable by EAs.
333336
if isinstance(values, BaseMaskedArray) and self.uses_mask():
334337
return self._masked_ea_wrap_cython_operation(
335338
values,
@@ -357,7 +360,8 @@ def _ea_wrap_cython_operation(
357360

358361
return self._reconstruct_ea_result(values, res_values)
359362

360-
def _ea_to_cython_values(self, values: ExtensionArray):
363+
# TODO: general case implementation overridable by EAs.
364+
def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
361365
# GH#43682
362366
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
363367
# All of the functions implemented here are ordinal, so we can
@@ -378,23 +382,24 @@ def _ea_to_cython_values(self, values: ExtensionArray):
378382
)
379383
return npvalues
380384

381-
def _reconstruct_ea_result(self, values, res_values):
385+
# TODO: general case implementation overridable by EAs.
386+
def _reconstruct_ea_result(
387+
self, values: ExtensionArray, res_values: np.ndarray
388+
) -> ExtensionArray:
382389
"""
383390
Construct an ExtensionArray result from an ndarray result.
384391
"""
385-
# TODO: allow EAs to override this logic
386392

387-
if isinstance(
388-
values.dtype, (BooleanDtype, IntegerDtype, FloatingDtype, StringDtype)
389-
):
393+
if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)):
390394
dtype = self._get_result_dtype(values.dtype)
391395
cls = dtype.construct_array_type()
392396
return cls._from_sequence(res_values, dtype=dtype)
393397

394398
elif needs_i8_conversion(values.dtype):
395399
assert res_values.dtype.kind != "f" # just to be on the safe side
396400
i8values = res_values.view("i8")
397-
return type(values)(i8values, dtype=values.dtype)
401+
# error: Too many arguments for "ExtensionArray"
402+
return type(values)(i8values, dtype=values.dtype) # type: ignore[call-arg]
398403

399404
raise NotImplementedError
400405

@@ -429,13 +434,16 @@ def _masked_ea_wrap_cython_operation(
429434
)
430435

431436
dtype = self._get_result_dtype(orig_values.dtype)
432-
assert isinstance(dtype, BaseMaskedDtype)
433-
cls = dtype.construct_array_type()
437+
# TODO: avoid cast as res_values *should* already have the right
438+
# dtype; last attempt ran into trouble on 32bit linux build
439+
res_values = res_values.astype(dtype.type, copy=False)
434440

435441
if self.kind != "aggregate":
436-
return cls(res_values.astype(dtype.type, copy=False), mask)
442+
out_mask = mask
437443
else:
438-
return cls(res_values.astype(dtype.type, copy=False), result_mask)
444+
out_mask = result_mask
445+
446+
return orig_values._maybe_mask_result(res_values, out_mask)
439447

440448
@final
441449
def _cython_op_ndim_compat(
@@ -521,8 +529,9 @@ def _call_cython_op(
521529
result_mask = result_mask.T
522530

523531
out_shape = self._get_output_shape(ngroups, values)
524-
func, values = self.get_cython_func_and_vals(values, is_numeric)
525-
out_dtype = self.get_out_dtype(values.dtype)
532+
func = self._get_cython_function(self.kind, self.how, values.dtype, is_numeric)
533+
values = self._get_cython_vals(values)
534+
out_dtype = self._get_out_dtype(values.dtype)
526535

527536
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
528537
if self.kind == "aggregate":

0 commit comments

Comments
 (0)