Skip to content

Commit 1f7a7f2

Browse files
authored
REF: uses_mask in group_any_all (#52043)
1 parent fb282b6 commit 1f7a7f2

File tree

2 files changed

+23
-26
lines changed

2 files changed

+23
-26
lines changed

pandas/_libs/groupby.pyx

+10-7
Original file line numberDiff line numberDiff line change
@@ -565,11 +565,11 @@ def group_any_all(
565565
const uint8_t[:, :] mask,
566566
str val_test,
567567
bint skipna,
568-
bint nullable,
568+
uint8_t[:, ::1] result_mask,
569569
) -> None:
570570
"""
571571
Aggregated boolean values to show truthfulness of group elements. If the
572-
input is a nullable type (nullable=True), the result will be computed
572+
input is a nullable type (result_mask is not None), the result will be computed
573573
using Kleene logic.
574574

575575
Parameters
@@ -587,9 +587,9 @@ def group_any_all(
587587
String object dictating whether to use any or all truth testing
588588
skipna : bool
589589
Flag to ignore nan values during truth testing
590-
nullable : bool
591-
Whether or not the input is a nullable type. If True, the
592-
result will be computed using Kleene logic
590+
result_mask : ndarray[bool, ndim=2], optional
591+
If not None, this specifies the locations in the output that are NA.
592+
Modified in-place.
593593

594594
Notes
595595
-----
@@ -601,6 +601,7 @@ def group_any_all(
601601
Py_ssize_t i, j, N = len(labels), K = out.shape[1]
602602
intp_t lab
603603
int8_t flag_val, val
604+
bint uses_mask = result_mask is not None
604605

605606
if val_test == "all":
606607
# Because the 'all' value of an empty iterable in Python is True we can
@@ -627,12 +628,12 @@ def group_any_all(
627628
if skipna and mask[i, j]:
628629
continue
629630

630-
if nullable and mask[i, j]:
631+
if uses_mask and mask[i, j]:
631632
# Set the position as masked if `out[lab, j] != flag_val`, which
632633
# would indicate True/False has not yet been seen for any/all,
633634
# so by Kleene logic the result is currently unknown
634635
if out[lab, j] != flag_val:
635-
out[lab, j] = -1
636+
result_mask[lab, j] = 1
636637
continue
637638

638639
val = values[i, j]
@@ -641,6 +642,8 @@ def group_any_all(
641642
# already determined
642643
if val == flag_val:
643644
out[lab, j] = flag_val
645+
if uses_mask:
646+
result_mask[lab, j] = 0
644647

645648

646649
# ----------------------------------------------------------------------

pandas/core/groupby/groupby.py

+13-19
Original file line numberDiff line numberDiff line change
@@ -1704,10 +1704,10 @@ def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
17041704
def result_to_bool(
17051705
result: np.ndarray,
17061706
inference: type,
1707-
nullable: bool = False,
1707+
result_mask,
17081708
) -> ArrayLike:
1709-
if nullable:
1710-
return BooleanArray(result.astype(bool, copy=False), result == -1)
1709+
if result_mask is not None:
1710+
return BooleanArray(result.astype(bool, copy=False), result_mask)
17111711
else:
17121712
return result.astype(inference, copy=False)
17131713

@@ -1985,10 +1985,8 @@ def _preprocessing(values):
19851985
return values._data, None
19861986
return values, None
19871987

1988-
def _postprocessing(
1989-
vals, inference, nullable: bool = False, result_mask=None
1990-
) -> ArrayLike:
1991-
if nullable:
1988+
def _postprocessing(vals, inference, result_mask=None) -> ArrayLike:
1989+
if result_mask is not None:
19921990
if result_mask.ndim == 2:
19931991
result_mask = result_mask[:, 0]
19941992
return FloatingArray(np.sqrt(vals), result_mask.view(np.bool_))
@@ -3808,13 +3806,11 @@ def blk_func(values: ArrayLike) -> ArrayLike:
38083806
mask = mask.reshape(-1, 1)
38093807
func = partial(func, mask=mask)
38103808

3811-
if how != "std":
3812-
is_nullable = isinstance(values, BaseMaskedArray)
3813-
func = partial(func, nullable=is_nullable)
3814-
3815-
elif isinstance(values, BaseMaskedArray):
3809+
result_mask = None
3810+
if isinstance(values, BaseMaskedArray):
38163811
result_mask = np.zeros(result.shape, dtype=np.bool_)
3817-
func = partial(func, result_mask=result_mask)
3812+
3813+
func = partial(func, result_mask=result_mask)
38183814

38193815
# Call func to modify result in place
38203816
if how == "std":
@@ -3825,14 +3821,12 @@ def blk_func(values: ArrayLike) -> ArrayLike:
38253821
if values.ndim == 1:
38263822
assert result.shape[1] == 1, result.shape
38273823
result = result[:, 0]
3824+
if result_mask is not None:
3825+
assert result_mask.shape[1] == 1, result_mask.shape
3826+
result_mask = result_mask[:, 0]
38283827

38293828
if post_processing:
3830-
pp_kwargs: dict[str, bool | np.ndarray] = {}
3831-
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
3832-
if how == "std" and pp_kwargs["nullable"]:
3833-
pp_kwargs["result_mask"] = result_mask
3834-
3835-
result = post_processing(result, inferences, **pp_kwargs)
3829+
result = post_processing(result, inferences, result_mask=result_mask)
38363830

38373831
if how == "std" and is_datetimelike:
38383832
values = cast("DatetimeArray | TimedeltaArray", values)

0 commit comments

Comments
 (0)