Skip to content

Commit ecfbb26

Browse files
authored
REF: de-duplicate WrappedCythonOp._ea_wrap_cython_operation (#42521)
1 parent 0d393d7 commit ecfbb26

File tree

1 file changed

+38
-78
lines changed

1 file changed

+38
-78
lines changed

pandas/core/groupby/ops.py

+38 -78
Original file line number | Diff line number | Diff line change
@@ -341,95 +341,54 @@ def _ea_wrap_cython_operation(
341341
comp_ids=comp_ids,
342342
**kwargs,
343343
)
344-
orig_values = values
345344

346-
if isinstance(orig_values, (DatetimeArray, PeriodArray)):
345+
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
347346
# All of the functions implemented here are ordinal, so we can
348347
# operate on the tz-naive equivalents
349-
npvalues = orig_values._ndarray.view("M8[ns]")
350-
res_values = self._cython_op_ndim_compat(
351-
npvalues,
352-
min_count=min_count,
353-
ngroups=ngroups,
354-
comp_ids=comp_ids,
355-
mask=None,
356-
**kwargs,
357-
)
358-
if self.how in ["rank"]:
359-
# i.e. how in WrappedCythonOp.cast_blocklist, since
360-
# other cast_blocklist methods dont go through cython_operation
361-
# preserve float64 dtype
362-
return res_values
363-
364-
res_values = res_values.view("i8")
365-
result = type(orig_values)(res_values, dtype=orig_values.dtype)
366-
return result
367-
368-
elif isinstance(orig_values, TimedeltaArray):
369-
# We have an ExtensionArray but not ExtensionDtype
370-
res_values = self._cython_op_ndim_compat(
371-
orig_values._ndarray,
372-
min_count=min_count,
373-
ngroups=ngroups,
374-
comp_ids=comp_ids,
375-
mask=None,
376-
**kwargs,
377-
)
378-
if self.how in ["rank"]:
379-
# i.e. how in WrappedCythonOp.cast_blocklist, since
380-
# other cast_blocklist methods dont go through cython_operation
381-
# preserve float64 dtype
382-
return res_values
383-
384-
# otherwise res_values has the same dtype as original values
385-
return type(orig_values)(res_values)
386-
348+
npvalues = values._ndarray.view("M8[ns]")
387349
elif isinstance(values.dtype, (BooleanDtype, _IntegerDtype)):
388350
# IntegerArray or BooleanArray
389351
npvalues = values.to_numpy("float64", na_value=np.nan)
390-
res_values = self._cython_op_ndim_compat(
391-
npvalues,
392-
min_count=min_count,
393-
ngroups=ngroups,
394-
comp_ids=comp_ids,
395-
mask=None,
396-
**kwargs,
397-
)
398-
if self.how in ["rank"]:
399-
# i.e. how in WrappedCythonOp.cast_blocklist, since
400-
# other cast_blocklist methods dont go through cython_operation
401-
return res_values
402-
403-
dtype = self._get_result_dtype(orig_values.dtype)
404-
cls = dtype.construct_array_type()
405-
return cls._from_sequence(res_values, dtype=dtype)
406-
407352
elif isinstance(values.dtype, FloatingDtype):
408353
# FloatingArray
409-
npvalues = values.to_numpy(
410-
values.dtype.numpy_dtype,
411-
na_value=np.nan,
412-
)
413-
res_values = self._cython_op_ndim_compat(
414-
npvalues,
415-
min_count=min_count,
416-
ngroups=ngroups,
417-
comp_ids=comp_ids,
418-
mask=None,
419-
**kwargs,
354+
npvalues = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
355+
else:
356+
raise NotImplementedError(
357+
f"function is not implemented for this dtype: {values.dtype}"
420358
)
421-
if self.how in ["rank"]:
422-
# i.e. how in WrappedCythonOp.cast_blocklist, since
423-
# other cast_blocklist methods dont go through cython_operation
424-
return res_values
425359

426-
dtype = self._get_result_dtype(orig_values.dtype)
360+
res_values = self._cython_op_ndim_compat(
361+
npvalues,
362+
min_count=min_count,
363+
ngroups=ngroups,
364+
comp_ids=comp_ids,
365+
mask=None,
366+
**kwargs,
367+
)
368+
369+
if self.how in ["rank"]:
370+
# i.e. how in WrappedCythonOp.cast_blocklist, since
371+
# other cast_blocklist methods don't go through cython_operation
372+
return res_values
373+
374+
return self._reconstruct_ea_result(values, res_values)
375+
376+
def _reconstruct_ea_result(self, values, res_values):
377+
"""
378+
Construct an ExtensionArray result from an ndarray result.
379+
"""
380+
# TODO: allow EAs to override this logic
381+
382+
if isinstance(values.dtype, (BooleanDtype, _IntegerDtype, FloatingDtype)):
383+
dtype = self._get_result_dtype(values.dtype)
427384
cls = dtype.construct_array_type()
428385
return cls._from_sequence(res_values, dtype=dtype)
429386

430-
raise NotImplementedError(
431-
f"function is not implemented for this dtype: {values.dtype}"
432-
)
387+
elif needs_i8_conversion(values.dtype):
388+
i8values = res_values.view("i8")
389+
return type(values)(i8values, dtype=values.dtype)
390+
391+
raise NotImplementedError
433392

434393
@final
435394
def _masked_ea_wrap_cython_operation(
@@ -478,6 +437,8 @@ def _cython_op_ndim_compat(
478437
if values.ndim == 1:
479438
# expand to 2d, dispatch, then squeeze if appropriate
480439
values2d = values[None, :]
440+
if mask is not None:
441+
mask = mask[None, :]
481442
res = self._call_cython_op(
482443
values2d,
483444
min_count=min_count,
@@ -533,9 +494,8 @@ def _call_cython_op(
533494
values = ensure_float64(values)
534495

535496
values = values.T
536-
537497
if mask is not None:
538-
mask = mask.reshape(values.shape, order="C")
498+
mask = mask.T
539499

540500
out_shape = self._get_output_shape(ngroups, values)
541501
func, values = self.get_cython_func_and_vals(values, is_numeric)

0 commit comments

Comments
 (0)