diff --git a/pandas/core/groupby/ops.py b/pandas/core/groupby/ops.py index d0fb100fdbd9f..8f8c2ba2981c8 100644 --- a/pandas/core/groupby/ops.py +++ b/pandas/core/groupby/ops.py @@ -299,6 +299,270 @@ def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj: def uses_mask(self) -> bool: return self.how in self._MASKED_CYTHON_FUNCTIONS + @final + def _ea_wrap_cython_operation( + self, + values: ExtensionArray, + min_count: int, + ngroups: int, + comp_ids: np.ndarray, + **kwargs, + ) -> ArrayLike: + """ + If we have an ExtensionArray, unwrap, call _cython_operation, and + re-wrap if appropriate. + """ + # TODO: general case implementation overridable by EAs. + orig_values = values + + if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype): + # All of the functions implemented here are ordinal, so we can + # operate on the tz-naive equivalents + npvalues = values.view("M8[ns]") + res_values = self._cython_op_ndim_compat( + # error: Argument 1 to "_cython_op_ndim_compat" of + # "WrappedCythonOp" has incompatible type + # "Union[ExtensionArray, ndarray]"; expected "ndarray" + npvalues, # type: ignore[arg-type] + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + mask=None, + **kwargs, + ) + if self.how in ["rank"]: + # i.e. how in WrappedCythonOp.cast_blocklist, since + # other cast_blocklist methods dont go through cython_operation + # preserve float64 dtype + return res_values + + res_values = res_values.astype("i8", copy=False) + # error: Too many arguments for "ExtensionArray" + result = type(orig_values)( # type: ignore[call-arg] + res_values, dtype=orig_values.dtype + ) + return result + + elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype): + # IntegerArray or BooleanArray + npvalues = values.to_numpy("float64", na_value=np.nan) + res_values = self._cython_op_ndim_compat( + npvalues, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + mask=None, + **kwargs, + ) + if self.how in ["rank"]: + # i.e. how in WrappedCythonOp.cast_blocklist, since + # other cast_blocklist methods dont go through cython_operation + return res_values + + dtype = self.get_result_dtype(orig_values.dtype) + # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" + # has no attribute "construct_array_type" + cls = dtype.construct_array_type() # type: ignore[union-attr] + return cls._from_sequence(res_values, dtype=dtype) + + elif is_float_dtype(values.dtype): + # FloatingArray + # error: "ExtensionDtype" has no attribute "numpy_dtype" + npvalues = values.to_numpy( + values.dtype.numpy_dtype, # type: ignore[attr-defined] + na_value=np.nan, + ) + res_values = self._cython_op_ndim_compat( + npvalues, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + mask=None, + **kwargs, + ) + if self.how in ["rank"]: + # i.e. how in WrappedCythonOp.cast_blocklist, since + # other cast_blocklist methods dont go through cython_operation + return res_values + + dtype = self.get_result_dtype(orig_values.dtype) + # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" + # has no attribute "construct_array_type" + cls = dtype.construct_array_type() # type: ignore[union-attr] + return cls._from_sequence(res_values, dtype=dtype) + + raise NotImplementedError( + f"function is not implemented for this dtype: {values.dtype}" + ) + + @final + def _masked_ea_wrap_cython_operation( + self, + values: BaseMaskedArray, + min_count: int, + ngroups: int, + comp_ids: np.ndarray, + **kwargs, + ) -> BaseMaskedArray: + """ + Equivalent of `_ea_wrap_cython_operation`, but optimized for masked EA's + and cython algorithms which accept a mask. + """ + orig_values = values + + # Copy to ensure input and result masks don't end up shared + mask = values._mask.copy() + arr = values._data + + res_values = self._cython_op_ndim_compat( + arr, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + mask=mask, + **kwargs, + ) + dtype = self.get_result_dtype(orig_values.dtype) + assert isinstance(dtype, BaseMaskedDtype) + cls = dtype.construct_array_type() + + return cls(res_values.astype(dtype.type, copy=False), mask) + + def _cython_op_ndim_compat( + self, + values: np.ndarray, + *, + min_count: int, + ngroups: int, + comp_ids: np.ndarray, + mask: np.ndarray | None, + **kwargs, + ) -> np.ndarray: + if values.ndim == 1: + # expand to 2d, dispatch, then squeeze if appropriate + values2d = values[None, :] + res = self._call_cython_op( + values2d, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + mask=mask, + **kwargs, + ) + if res.shape[0] == 1: + return res[0] + + # otherwise we have OHLC + return res.T + + return self._call_cython_op( + values, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + mask=mask, + **kwargs, + ) + + @final + def _call_cython_op( + self, + values: np.ndarray, # np.ndarray[ndim=2] + *, + min_count: int, + ngroups: int, + comp_ids: np.ndarray, + mask: np.ndarray | None, + **kwargs, + ) -> np.ndarray: # np.ndarray[ndim=2] + orig_values = values + + dtype = values.dtype + is_numeric = is_numeric_dtype(dtype) + + is_datetimelike = needs_i8_conversion(dtype) + + if is_datetimelike: + values = values.view("int64") + is_numeric = True + elif is_bool_dtype(dtype): + values = values.astype("int64") + elif is_integer_dtype(dtype): + # e.g. uint8 -> uint64, int16 -> int64 + dtype_str = dtype.kind + "8" + values = values.astype(dtype_str, copy=False) + elif is_numeric: + if not is_complex_dtype(dtype): + values = ensure_float64(values) + + values = values.T + + if mask is not None: + mask = mask.reshape(values.shape, order="C") + + out_shape = self.get_output_shape(ngroups, values) + func, values = self.get_cython_func_and_vals(values, is_numeric) + out_dtype = self.get_out_dtype(values.dtype) + + result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) + if self.kind == "aggregate": + counts = np.zeros(ngroups, dtype=np.int64) + if self.how in ["min", "max"]: + func( + result, + counts, + values, + comp_ids, + min_count, + is_datetimelike=is_datetimelike, + ) + else: + func(result, counts, values, comp_ids, min_count) + else: + # TODO: min_count + if self.uses_mask(): + func( + result, + values, + comp_ids, + ngroups, + is_datetimelike, + mask=mask, + **kwargs, + ) + else: + func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) + + if self.kind == "aggregate": + # i.e. counts is defined. Locations where count list[Index]: # ------------------------------------------------------------ # Aggregation functions - @final - def _ea_wrap_cython_operation( - self, - cy_op: WrappedCythonOp, - kind: str, - values, - how: str, - axis: int, - min_count: int = -1, - **kwargs, - ) -> ArrayLike: - """ - If we have an ExtensionArray, unwrap, call _cython_operation, and - re-wrap if appropriate. - """ - # TODO: general case implementation overridable by EAs. - orig_values = values - - if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype): - # All of the functions implemented here are ordinal, so we can - # operate on the tz-naive equivalents - npvalues = values.view("M8[ns]") - res_values = self._cython_operation( - kind, npvalues, how, axis, min_count, **kwargs - ) - if how in ["rank"]: - # i.e. how in WrappedCythonOp.cast_blocklist, since - # other cast_blocklist methods dont go through cython_operation - # preserve float64 dtype - return res_values - - res_values = res_values.astype("i8", copy=False) - result = type(orig_values)(res_values, dtype=orig_values.dtype) - return result - - elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype): - # IntegerArray or BooleanArray - values = values.to_numpy("float64", na_value=np.nan) - res_values = self._cython_operation( - kind, values, how, axis, min_count, **kwargs - ) - if how in ["rank"]: - # i.e. how in WrappedCythonOp.cast_blocklist, since - # other cast_blocklist methods dont go through cython_operation - return res_values - - dtype = cy_op.get_result_dtype(orig_values.dtype) - # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" - # has no attribute "construct_array_type" - cls = dtype.construct_array_type() # type: ignore[union-attr] - return cls._from_sequence(res_values, dtype=dtype) - - elif is_float_dtype(values.dtype): - # FloatingArray - values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan) - res_values = self._cython_operation( - kind, values, how, axis, min_count, **kwargs - ) - if how in ["rank"]: - # i.e. how in WrappedCythonOp.cast_blocklist, since - # other cast_blocklist methods dont go through cython_operation - return res_values - - dtype = cy_op.get_result_dtype(orig_values.dtype) - # error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]" - # has no attribute "construct_array_type" - cls = dtype.construct_array_type() # type: ignore[union-attr] - return cls._from_sequence(res_values, dtype=dtype) - - raise NotImplementedError( - f"function is not implemented for this dtype: {values.dtype}" - ) - - @final - def _masked_ea_wrap_cython_operation( - self, - cy_op: WrappedCythonOp, - kind: str, - values: BaseMaskedArray, - how: str, - axis: int, - min_count: int = -1, - **kwargs, - ) -> BaseMaskedArray: - """ - Equivalent of `_ea_wrap_cython_operation`, but optimized for masked EA's - and cython algorithms which accept a mask. - """ - orig_values = values - - # Copy to ensure input and result masks don't end up shared - mask = values._mask.copy() - arr = values._data - - res_values = self._cython_operation( - kind, arr, how, axis, min_count, mask=mask, **kwargs - ) - dtype = cy_op.get_result_dtype(orig_values.dtype) - assert isinstance(dtype, BaseMaskedDtype) - cls = dtype.construct_array_type() - - return cls(res_values.astype(dtype.type, copy=False), mask) - @final def _cython_operation( self, @@ -713,7 +874,6 @@ def _cython_operation( """ Returns the values of a cython operation. """ - orig_values = values assert kind in ["transform", "aggregate"] if values.ndim > 2: @@ -732,119 +892,36 @@ def _cython_operation( # if not raise NotImplementedError cy_op.disallow_invalid_ops(dtype, is_numeric) + comp_ids, _, _ = self.group_info + ngroups = self.ngroups + func_uses_mask = cy_op.uses_mask() if is_extension_array_dtype(dtype): if isinstance(values, BaseMaskedArray) and func_uses_mask: - return self._masked_ea_wrap_cython_operation( - cy_op, kind, values, how, axis, min_count, **kwargs - ) - else: - return self._ea_wrap_cython_operation( - cy_op, kind, values, how, axis, min_count, **kwargs - ) - - elif values.ndim == 1: - # expand to 2d, dispatch, then squeeze if appropriate - values2d = values[None, :] - res = self._cython_operation( - kind=kind, - values=values2d, - how=how, - axis=1, - min_count=min_count, - mask=mask, - **kwargs, - ) - if res.shape[0] == 1: - return res[0] - - # otherwise we have OHLC - return res.T - - is_datetimelike = needs_i8_conversion(dtype) - - if is_datetimelike: - values = values.view("int64") - is_numeric = True - elif is_bool_dtype(dtype): - values = values.astype("int64") - elif is_integer_dtype(dtype): - # e.g. uint8 -> uint64, int16 -> int64 - dtype = dtype.kind + "8" - values = values.astype(dtype, copy=False) - elif is_numeric: - if not is_complex_dtype(dtype): - values = ensure_float64(values) - - ngroups = self.ngroups - comp_ids, _, _ = self.group_info - - assert axis == 1 - values = values.T - - if mask is not None: - mask = mask.reshape(values.shape, order="C") - - out_shape = cy_op.get_output_shape(ngroups, values) - func, values = cy_op.get_cython_func_and_vals(values, is_numeric) - out_dtype = cy_op.get_out_dtype(values.dtype) - - result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) - if kind == "aggregate": - counts = np.zeros(ngroups, dtype=np.int64) - if how in ["min", "max"]: - func( - result, - counts, + return cy_op._masked_ea_wrap_cython_operation( values, - comp_ids, - min_count, - is_datetimelike=is_datetimelike, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, + **kwargs, ) else: - func(result, counts, values, comp_ids, min_count) - elif kind == "transform": - # TODO: min_count - if func_uses_mask: - func( - result, + return cy_op._ea_wrap_cython_operation( values, - comp_ids, - ngroups, - is_datetimelike, - mask=mask, + min_count=min_count, + ngroups=ngroups, + comp_ids=comp_ids, **kwargs, ) - else: - func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) - if kind == "aggregate": - # i.e. counts is defined. Locations where count 0] - - result = result.T - - if how not in cy_op.cast_blocklist: - # e.g. if we are int64 and need to restore to datetime64/timedelta64 - # "rank" is the only member of cast_blocklist we get here - dtype = cy_op.get_result_dtype(orig_values.dtype) - op_result = maybe_downcast_to_dtype(result, dtype) - else: - op_result = result - - return op_result + return cy_op._cython_op_ndim_compat( + values, + min_count=min_count, + ngroups=self.ngroups, + comp_ids=comp_ids, + mask=mask, + **kwargs, + ) def agg_series(self, obj: Series, func: F): # Caller is responsible for checking ngroups != 0