diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 0c378acbc6dc3..b3814f666a546 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -565,11 +565,11 @@ def group_any_all( const uint8_t[:, :] mask, str val_test, bint skipna, - bint nullable, + uint8_t[:, ::1] result_mask, ) -> None: """ Aggregated boolean values to show truthfulness of group elements. If the - input is a nullable type (nullable=True), the result will be computed + input is a nullable type (result_mask is not None), the result will be computed using Kleene logic. Parameters @@ -587,9 +587,9 @@ def group_any_all( String object dictating whether to use any or all truth testing skipna : bool Flag to ignore nan values during truth testing - nullable : bool - Whether or not the input is a nullable type. If True, the - result will be computed using Kleene logic + result_mask : ndarray[bool, ndim=2], optional + If not None, these specify locations in the output that are NA. + Modified in-place. Notes ----- @@ -601,6 +601,7 @@ def group_any_all( Py_ssize_t i, j, N = len(labels), K = out.shape[1] intp_t lab int8_t flag_val, val + bint uses_mask = result_mask is not None if val_test == "all": # Because the 'all' value of an empty iterable in Python is True we can @@ -627,12 +628,12 @@ def group_any_all( if skipna and mask[i, j]: continue - if nullable and mask[i, j]: + if uses_mask and mask[i, j]: # Set the position as masked if `out[lab] != flag_val`, which # would indicate True/False has not yet been seen for any/all, # so by Kleene logic the result is currently unknown if out[lab, j] != flag_val: - out[lab, j] = -1 + result_mask[lab, j] = 1 continue val = values[i, j] @@ -641,6 +642,8 @@ def group_any_all( # already determined if val == flag_val: out[lab, j] = flag_val + if uses_mask: + result_mask[lab, j] = 0 # ---------------------------------------------------------------------- diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 96f39bb99e544..d9eaf2304d619 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -1658,10 +1658,10 @@ def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]: def result_to_bool( result: np.ndarray, inference: type, - nullable: bool = False, + result_mask, ) -> ArrayLike: - if nullable: - return BooleanArray(result.astype(bool, copy=False), result == -1) + if result_mask is not None: + return BooleanArray(result.astype(bool, copy=False), result_mask) else: return result.astype(inference, copy=False) @@ -1939,10 +1939,8 @@ def _preprocessing(values): return values._data, None return values, None - def _postprocessing( - vals, inference, nullable: bool = False, result_mask=None - ) -> ArrayLike: - if nullable: + def _postprocessing(vals, inference, result_mask=None) -> ArrayLike: + if result_mask is not None: if result_mask.ndim == 2: result_mask = result_mask[:, 0] return FloatingArray(np.sqrt(vals), result_mask.view(np.bool_)) @@ -3716,13 +3714,11 @@ def blk_func(values: ArrayLike) -> ArrayLike: mask = mask.reshape(-1, 1) func = partial(func, mask=mask) - if how != "std": - is_nullable = isinstance(values, BaseMaskedArray) - func = partial(func, nullable=is_nullable) - - elif isinstance(values, BaseMaskedArray): + result_mask = None + if isinstance(values, BaseMaskedArray): result_mask = np.zeros(result.shape, dtype=np.bool_) - func = partial(func, result_mask=result_mask) + + func = partial(func, result_mask=result_mask) # Call func to modify result in place if how == "std": @@ -3733,14 +3729,12 @@ def blk_func(values: ArrayLike) -> ArrayLike: if values.ndim == 1: assert result.shape[1] == 1, result.shape result = result[:, 0] + if result_mask is not None: + assert result_mask.shape[1] == 1, result_mask.shape + result_mask = result_mask[:, 0] if post_processing: - pp_kwargs: dict[str, bool | np.ndarray] = {} - pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray) - if how == "std" and pp_kwargs["nullable"]: - pp_kwargs["result_mask"] = result_mask - - result = post_processing(result, inferences, **pp_kwargs) + result = post_processing(result, inferences, result_mask=result_mask) if how == "std" and is_datetimelike: values = cast("DatetimeArray | TimedeltaArray", values)