# Patch summary (pandas/core/groupby/ops.py)
#
# * _ea_wrap_cython_operation: every dtype branch now only computes `npvalues`;
#   the single `_cython_op_ndim_compat(..., mask=None)` call and the
#   `self.how in ["rank"]` early return follow the if/elif chain once.
#   TimedeltaArray joins the DatetimeArray/PeriodArray branch (viewed as
#   "M8[ns]"), and the unsupported-dtype NotImplementedError moves into the
#   new `else` branch.
# * _reconstruct_ea_result (new): rebuilds the ExtensionArray from the ndarray
#   result -- `_from_sequence` for Boolean/Integer/Floating dtypes, an "i8"
#   view passed to `type(values)` where `needs_i8_conversion` is true.
# * _cython_op_ndim_compat: a non-None `mask` is expanded to 2D together with
#   the 1D `values`.
# * _call_cython_op: `mask` is transposed (`mask.T`) instead of reshaped.
#
# NOTE(review): this patch had been collapsed onto three physical lines. Line
# breaks, indentation and the position of blank context lines were rebuilt
# from the Python structure and the hunk line counts (95/54, 6/8, 9/8);
# confirm them against the target file before applying.
# NOTE(review): in the last hunk the collapsed text "- if mask is not None:"
# is read as a removed blank line followed by an unchanged `if` line, so that
# `mask.T` stays guarded against a None mask -- confirm.
# NOTE(review): the final `raise NotImplementedError` in
# _reconstruct_ea_result carries no message, unlike the dtype message raised
# in _ea_wrap_cython_operation.
diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py
index 874d7395b1950..84d7f2a0b8cb9 100644
--- a/pandas/core/groupby/ops.py
+++ b/pandas/core/groupby/ops.py
@@ -341,95 +341,54 @@ def _ea_wrap_cython_operation(
                 comp_ids=comp_ids,
                 **kwargs,
             )
-        orig_values = values
 
-        if isinstance(orig_values, (DatetimeArray, PeriodArray)):
+        if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
             # All of the functions implemented here are ordinal, so we can
             #  operate on the tz-naive equivalents
-            npvalues = orig_values._ndarray.view("M8[ns]")
-            res_values = self._cython_op_ndim_compat(
-                npvalues,
-                min_count=min_count,
-                ngroups=ngroups,
-                comp_ids=comp_ids,
-                mask=None,
-                **kwargs,
-            )
-            if self.how in ["rank"]:
-                # i.e. how in WrappedCythonOp.cast_blocklist, since
-                #  other cast_blocklist methods dont go through cython_operation
-                #  preserve float64 dtype
-                return res_values
-
-            res_values = res_values.view("i8")
-            result = type(orig_values)(res_values, dtype=orig_values.dtype)
-            return result
-
-        elif isinstance(orig_values, TimedeltaArray):
-            # We have an ExtensionArray but not ExtensionDtype
-            res_values = self._cython_op_ndim_compat(
-                orig_values._ndarray,
-                min_count=min_count,
-                ngroups=ngroups,
-                comp_ids=comp_ids,
-                mask=None,
-                **kwargs,
-            )
-            if self.how in ["rank"]:
-                # i.e. how in WrappedCythonOp.cast_blocklist, since
-                #  other cast_blocklist methods dont go through cython_operation
-                #  preserve float64 dtype
-                return res_values
-
-            # otherwise res_values has the same dtype as original values
-            return type(orig_values)(res_values)
-
+            npvalues = values._ndarray.view("M8[ns]")
         elif isinstance(values.dtype, (BooleanDtype, _IntegerDtype)):
             # IntegerArray or BooleanArray
             npvalues = values.to_numpy("float64", na_value=np.nan)
-            res_values = self._cython_op_ndim_compat(
-                npvalues,
-                min_count=min_count,
-                ngroups=ngroups,
-                comp_ids=comp_ids,
-                mask=None,
-                **kwargs,
-            )
-            if self.how in ["rank"]:
-                # i.e. how in WrappedCythonOp.cast_blocklist, since
-                #  other cast_blocklist methods dont go through cython_operation
-                return res_values
-
-            dtype = self._get_result_dtype(orig_values.dtype)
-            cls = dtype.construct_array_type()
-            return cls._from_sequence(res_values, dtype=dtype)
-
         elif isinstance(values.dtype, FloatingDtype):
             # FloatingArray
-            npvalues = values.to_numpy(
-                values.dtype.numpy_dtype,
-                na_value=np.nan,
-            )
-            res_values = self._cython_op_ndim_compat(
-                npvalues,
-                min_count=min_count,
-                ngroups=ngroups,
-                comp_ids=comp_ids,
-                mask=None,
-                **kwargs,
+            npvalues = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
+        else:
+            raise NotImplementedError(
+                f"function is not implemented for this dtype: {values.dtype}"
             )
-            if self.how in ["rank"]:
-                # i.e. how in WrappedCythonOp.cast_blocklist, since
-                #  other cast_blocklist methods dont go through cython_operation
-                return res_values
 
-            dtype = self._get_result_dtype(orig_values.dtype)
+        res_values = self._cython_op_ndim_compat(
+            npvalues,
+            min_count=min_count,
+            ngroups=ngroups,
+            comp_ids=comp_ids,
+            mask=None,
+            **kwargs,
+        )
+
+        if self.how in ["rank"]:
+            # i.e. how in WrappedCythonOp.cast_blocklist, since
+            #  other cast_blocklist methods dont go through cython_operation
+            return res_values
+
+        return self._reconstruct_ea_result(values, res_values)
+
+    def _reconstruct_ea_result(self, values, res_values):
+        """
+        Construct an ExtensionArray result from an ndarray result.
+        """
+        # TODO: allow EAs to override this logic
+
+        if isinstance(values.dtype, (BooleanDtype, _IntegerDtype, FloatingDtype)):
+            dtype = self._get_result_dtype(values.dtype)
             cls = dtype.construct_array_type()
             return cls._from_sequence(res_values, dtype=dtype)
 
-        raise NotImplementedError(
-            f"function is not implemented for this dtype: {values.dtype}"
-        )
+        elif needs_i8_conversion(values.dtype):
+            i8values = res_values.view("i8")
+            return type(values)(i8values, dtype=values.dtype)
+
+        raise NotImplementedError
 
     @final
     def _masked_ea_wrap_cython_operation(
@@ -478,6 +437,8 @@ def _cython_op_ndim_compat(
         if values.ndim == 1:
             # expand to 2d, dispatch, then squeeze if appropriate
             values2d = values[None, :]
+            if mask is not None:
+                mask = mask[None, :]
             res = self._call_cython_op(
                 values2d,
                 min_count=min_count,
@@ -533,9 +494,8 @@ def _call_cython_op(
                 values = ensure_float64(values)
 
         values = values.T
-
         if mask is not None:
-            mask = mask.reshape(values.shape, order="C")
+            mask = mask.T
 
         out_shape = self._get_output_shape(ngroups, values)
         func, values = self.get_cython_func_and_vals(values, is_numeric)