Skip to content

REF: uses_mask in group_any_all #52043

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -565,11 +565,11 @@ def group_any_all(
const uint8_t[:, :] mask,
str val_test,
bint skipna,
bint nullable,
uint8_t[:, ::1] result_mask,
) -> None:
"""
Aggregated boolean values to show truthfulness of group elements. If the
input is a nullable type (nullable=True), the result will be computed
input is a nullable type (result_mask is not None), the result will be computed
using Kleene logic.

Parameters
Expand All @@ -587,9 +587,9 @@ def group_any_all(
String object dictating whether to use any or all truth testing
skipna : bool
Flag to ignore nan values during truth testing
nullable : bool
Whether or not the input is a nullable type. If True, the
result will be computed using Kleene logic
result_mask : ndarray[bool, ndim=2], optional
If not None, these specify locations in the output that are NA.
Modified in-place.

Notes
-----
Expand All @@ -601,6 +601,7 @@ def group_any_all(
Py_ssize_t i, j, N = len(labels), K = out.shape[1]
intp_t lab
int8_t flag_val, val
bint uses_mask = result_mask is not None

if val_test == "all":
# Because the 'all' value of an empty iterable in Python is True we can
Expand All @@ -627,12 +628,12 @@ def group_any_all(
if skipna and mask[i, j]:
continue

if nullable and mask[i, j]:
if uses_mask and mask[i, j]:
# Set the position as masked if `out[lab] != flag_val`, which
# would indicate True/False has not yet been seen for any/all,
# so by Kleene logic the result is currently unknown
if out[lab, j] != flag_val:
out[lab, j] = -1
result_mask[lab, j] = 1
Copy link
Member

@phofl phofl Mar 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For `any`: if NA is encountered as the first value in a group, you set `result_mask` to 1 here, but you never reset it when a later value in the same group is not NA. You'll have to update `result_mask` when you find such a value.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh OK. I missed that we are checking out[lab, j] != ... here as opposed to values[i, j] != .... Thanks.

continue

val = values[i, j]
Expand All @@ -641,6 +642,8 @@ def group_any_all(
# already determined
if val == flag_val:
out[lab, j] = flag_val
if uses_mask:
result_mask[lab, j] = 0


# ----------------------------------------------------------------------
Expand Down
32 changes: 13 additions & 19 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,10 +1658,10 @@ def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
def result_to_bool(
result: np.ndarray,
inference: type,
nullable: bool = False,
result_mask,
) -> ArrayLike:
if nullable:
return BooleanArray(result.astype(bool, copy=False), result == -1)
if result_mask is not None:
return BooleanArray(result.astype(bool, copy=False), result_mask)
else:
return result.astype(inference, copy=False)

Expand Down Expand Up @@ -1939,10 +1939,8 @@ def _preprocessing(values):
return values._data, None
return values, None

def _postprocessing(
vals, inference, nullable: bool = False, result_mask=None
) -> ArrayLike:
if nullable:
def _postprocessing(vals, inference, result_mask=None) -> ArrayLike:
if result_mask is not None:
if result_mask.ndim == 2:
result_mask = result_mask[:, 0]
return FloatingArray(np.sqrt(vals), result_mask.view(np.bool_))
Expand Down Expand Up @@ -3716,13 +3714,11 @@ def blk_func(values: ArrayLike) -> ArrayLike:
mask = mask.reshape(-1, 1)
func = partial(func, mask=mask)

if how != "std":
is_nullable = isinstance(values, BaseMaskedArray)
func = partial(func, nullable=is_nullable)

elif isinstance(values, BaseMaskedArray):
result_mask = None
if isinstance(values, BaseMaskedArray):
result_mask = np.zeros(result.shape, dtype=np.bool_)
func = partial(func, result_mask=result_mask)

func = partial(func, result_mask=result_mask)

# Call func to modify result in place
if how == "std":
Expand All @@ -3733,14 +3729,12 @@ def blk_func(values: ArrayLike) -> ArrayLike:
if values.ndim == 1:
assert result.shape[1] == 1, result.shape
result = result[:, 0]
if result_mask is not None:
assert result_mask.shape[1] == 1, result_mask.shape
result_mask = result_mask[:, 0]

if post_processing:
pp_kwargs: dict[str, bool | np.ndarray] = {}
pp_kwargs["nullable"] = isinstance(values, BaseMaskedArray)
if how == "std" and pp_kwargs["nullable"]:
pp_kwargs["result_mask"] = result_mask

result = post_processing(result, inferences, **pp_kwargs)
result = post_processing(result, inferences, result_mask=result_mask)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You'll probably have to remove the `is_nullable` flag from the std `post_processing` function (its default is False) and check for `result_mask is not None` instead.


if how == "std" and is_datetimelike:
values = cast("DatetimeArray | TimedeltaArray", values)
Expand Down