Skip to content

Commit bbbd15c

Browse files
authored
ENH/BUG: Use Kleene logic for groupby any/all (#40819)
1 parent ad5ee33 commit bbbd15c

File tree

6 files changed

+188
-33
lines changed

6 files changed

+188
-33
lines changed

asv_bench/benchmarks/groupby.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,19 @@ class GroupByCythonAgg:
480480
param_names = ["dtype", "method"]
481481
params = [
482482
["float64"],
483-
["sum", "prod", "min", "max", "mean", "median", "var", "first", "last"],
483+
[
484+
"sum",
485+
"prod",
486+
"min",
487+
"max",
488+
"mean",
489+
"median",
490+
"var",
491+
"first",
492+
"last",
493+
"any",
494+
"all",
495+
],
484496
]
485497

486498
def setup(self, dtype, method):

doc/source/whatsnew/v1.3.0.rst

+5
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,9 @@ Other enhancements
217217
- :class:`RangeIndex` can now be constructed by passing a ``range`` object directly e.g. ``pd.RangeIndex(range(3))`` (:issue:`12067`)
218218
- :meth:`round` being enabled for the nullable integer and floating dtypes (:issue:`38844`)
219219
- :meth:`pandas.read_csv` and :meth:`pandas.read_json` expose the argument ``encoding_errors`` to control how encoding errors are handled (:issue:`39450`)
220+
- :meth:`.GroupBy.any` and :meth:`.GroupBy.all` use Kleene logic with nullable data types (:issue:`37506`)
221+
- :meth:`.GroupBy.any` and :meth:`.GroupBy.all` return a ``BooleanDtype`` for columns with nullable data types (:issue:`33449`)
222+
-
220223

221224
.. ---------------------------------------------------------------------------
222225
@@ -787,6 +790,8 @@ Groupby/resample/rolling
787790
- Bug in :meth:`Series.asfreq` and :meth:`DataFrame.asfreq` dropping rows when the index is not sorted (:issue:`39805`)
788791
- Bug in aggregation functions for :class:`DataFrame` not respecting ``numeric_only`` argument when ``level`` keyword was given (:issue:`40660`)
789792
- Bug in :class:`core.window.RollingGroupby` where ``as_index=False`` argument in ``groupby`` was ignored (:issue:`39433`)
793+
- Bug in :meth:`.GroupBy.any` and :meth:`.GroupBy.all` raising ``ValueError`` when using with nullable type columns holding ``NA`` even with ``skipna=True`` (:issue:`40585`)
794+
790795

791796
Reshaping
792797
^^^^^^^^^

pandas/_libs/groupby.pyx

+25-8
Original file line numberDiff line numberDiff line change
@@ -388,40 +388,47 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels,
388388

389389
@cython.boundscheck(False)
390390
@cython.wraparound(False)
391-
def group_any_all(uint8_t[::1] out,
392-
const uint8_t[::1] values,
391+
def group_any_all(int8_t[::1] out,
392+
const int8_t[::1] values,
393393
const intp_t[:] labels,
394394
const uint8_t[::1] mask,
395395
str val_test,
396-
bint skipna) -> None:
396+
bint skipna,
397+
bint nullable) -> None:
397398
"""
398-
Aggregated boolean values to show truthfulness of group elements.
399+
Aggregated boolean values to show truthfulness of group elements. If the
400+
input is a nullable type (nullable=True), the result will be computed
401+
using Kleene logic.
399402

400403
Parameters
401404
----------
402-
out : np.ndarray[np.uint8]
405+
out : np.ndarray[np.int8]
403406
Values into which this method will write its results.
404407
labels : np.ndarray[np.intp]
405408
Array containing unique label for each group, with its
406409
ordering matching up to the corresponding record in `values`
407-
values : np.ndarray[np.uint8]
410+
values : np.ndarray[np.int8]
408411
Containing the truth value of each element.
409412
mask : np.ndarray[np.uint8]
410413
Indicating whether a value is na or not.
411414
val_test : {'any', 'all'}
412415
String object dictating whether to use any or all truth testing
413416
skipna : bool
414417
Flag to ignore nan values during truth testing
418+
nullable : bool
419+
Whether or not the input is a nullable type. If True, the
420+
result will be computed using Kleene logic
415421

416422
Notes
417423
-----
418424
This method modifies the `out` parameter rather than returning an object.
419-
The returned values will either be 0 or 1 (False or True, respectively).
425+
The returned values will either be 0, 1 (False or True, respectively), or
426+
-1 to signify a masked position in the case of a nullable input.
420427
"""
421428
cdef:
422429
Py_ssize_t i, N = len(labels)
423430
intp_t lab
424-
uint8_t flag_val
431+
int8_t flag_val
425432

426433
if val_test == 'all':
427434
# Because the 'all' value of an empty iterable in Python is True we can
@@ -444,6 +451,16 @@ def group_any_all(uint8_t[::1] out,
444451
if lab < 0 or (skipna and mask[i]):
445452
continue
446453

454+
if nullable and mask[i]:
455+
# Set the position as masked if `out[lab] != flag_val`, which
456+
# would indicate True/False has not yet been seen for any/all,
457+
# so by Kleene logic the result is currently unknown
458+
if out[lab] != flag_val:
459+
out[lab] = -1
460+
continue
461+
462+
# If True and 'any' or False and 'all', the result is
463+
# already determined
447464
if values[i] == flag_val:
448465
out[lab] = flag_val
449466

pandas/core/groupby/groupby.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ class providing the base-class of operations.
7777
from pandas.core import nanops
7878
import pandas.core.algorithms as algorithms
7979
from pandas.core.arrays import (
80+
BaseMaskedArray,
81+
BooleanArray,
8082
Categorical,
8183
ExtensionArray,
8284
)
@@ -1413,24 +1415,34 @@ def _bool_agg(self, val_test, skipna):
14131415
Shared func to call any / all Cython GroupBy implementations.
14141416
"""
14151417

1416-
def objs_to_bool(vals: np.ndarray) -> tuple[np.ndarray, type]:
1418+
def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
14171419
if is_object_dtype(vals):
14181420
vals = np.array([bool(x) for x in vals])
1421+
elif isinstance(vals, BaseMaskedArray):
1422+
vals = vals._data.astype(bool, copy=False)
14191423
else:
14201424
vals = vals.astype(bool)
14211425

1422-
return vals.view(np.uint8), bool
1426+
return vals.view(np.int8), bool
14231427

1424-
def result_to_bool(result: np.ndarray, inference: type) -> np.ndarray:
1425-
return result.astype(inference, copy=False)
1428+
def result_to_bool(
1429+
result: np.ndarray,
1430+
inference: type,
1431+
nullable: bool = False,
1432+
) -> ArrayLike:
1433+
if nullable:
1434+
return BooleanArray(result.astype(bool, copy=False), result == -1)
1435+
else:
1436+
return result.astype(inference, copy=False)
14261437

14271438
return self._get_cythonized_result(
14281439
"group_any_all",
14291440
aggregate=True,
14301441
numeric_only=False,
1431-
cython_dtype=np.dtype(np.uint8),
1442+
cython_dtype=np.dtype(np.int8),
14321443
needs_values=True,
14331444
needs_mask=True,
1445+
needs_nullable=True,
14341446
pre_processing=objs_to_bool,
14351447
post_processing=result_to_bool,
14361448
val_test=val_test,
@@ -2613,6 +2625,7 @@ def _get_cythonized_result(
26132625
needs_counts: bool = False,
26142626
needs_values: bool = False,
26152627
needs_2d: bool = False,
2628+
needs_nullable: bool = False,
26162629
min_count: int | None = None,
26172630
needs_mask: bool = False,
26182631
needs_ngroups: bool = False,
@@ -2649,6 +2662,9 @@ def _get_cythonized_result(
26492662
signature
26502663
needs_ngroups : bool, default False
26512664
Whether number of groups is part of the Cython call signature
2665+
needs_nullable : bool, default False
2666+
Whether a bool specifying if the input is nullable is part
2667+
of the Cython call signature
26522668
result_is_index : bool, default False
26532669
Whether the result of the Cython operation is an index of
26542670
values to be retrieved, instead of the actual values themselves
@@ -2664,7 +2680,8 @@ def _get_cythonized_result(
26642680
Function to be applied to result of Cython function. Should accept
26652681
an array of values as the first argument and type inferences as its
26662682
second argument, i.e. the signature should be
2667-
(ndarray, Type).
2683+
(ndarray, Type). If `needs_nullable=True`, a third argument should be
2684+
`nullable`, to allow for processing specific to nullable values.
26682685
**kwargs : dict
26692686
Extra arguments to be passed back to Cython funcs
26702687
@@ -2739,6 +2756,12 @@ def _get_cythonized_result(
27392756
if needs_ngroups:
27402757
func = partial(func, ngroups)
27412758

2759+
if needs_nullable:
2760+
is_nullable = isinstance(values, BaseMaskedArray)
2761+
func = partial(func, nullable=is_nullable)
2762+
if post_processing:
2763+
post_processing = partial(post_processing, nullable=is_nullable)
2764+
27422765
func(**kwargs) # Call func to modify indexer values in place
27432766

27442767
if needs_2d:

pandas/tests/groupby/test_any_all.py

+84
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
import numpy as np
44
import pytest
55

6+
import pandas as pd
67
from pandas import (
78
DataFrame,
89
Index,
10+
Series,
911
isna,
1012
)
1113
import pandas._testing as tm
@@ -68,3 +70,85 @@ def test_bool_aggs_dup_column_labels(bool_agg_func):
6870

6971
expected = df
7072
tm.assert_frame_equal(result, expected)
73+
74+
75+
@pytest.mark.parametrize("bool_agg_func", ["any", "all"])
76+
@pytest.mark.parametrize("skipna", [True, False])
77+
@pytest.mark.parametrize(
78+
"data",
79+
[
80+
[False, False, False],
81+
[True, True, True],
82+
[pd.NA, pd.NA, pd.NA],
83+
[False, pd.NA, False],
84+
[True, pd.NA, True],
85+
[True, pd.NA, False],
86+
],
87+
)
88+
def test_masked_kleene_logic(bool_agg_func, skipna, data):
89+
# GH#37506
90+
ser = Series(data, dtype="boolean")
91+
92+
# The result should match aggregating on the whole series. Correctness
93+
# there is verified in test_reductions.py::test_any_all_boolean_kleene_logic
94+
expected_data = getattr(ser, bool_agg_func)(skipna=skipna)
95+
expected = Series(expected_data, dtype="boolean")
96+
97+
result = ser.groupby([0, 0, 0]).agg(bool_agg_func, skipna=skipna)
98+
tm.assert_series_equal(result, expected)
99+
100+
101+
@pytest.mark.parametrize(
102+
"dtype1,dtype2,exp_col1,exp_col2",
103+
[
104+
(
105+
"float",
106+
"Float64",
107+
np.array([True], dtype=bool),
108+
pd.array([pd.NA], dtype="boolean"),
109+
),
110+
(
111+
"Int64",
112+
"float",
113+
pd.array([pd.NA], dtype="boolean"),
114+
np.array([True], dtype=bool),
115+
),
116+
(
117+
"Int64",
118+
"Int64",
119+
pd.array([pd.NA], dtype="boolean"),
120+
pd.array([pd.NA], dtype="boolean"),
121+
),
122+
(
123+
"Float64",
124+
"boolean",
125+
pd.array([pd.NA], dtype="boolean"),
126+
pd.array([pd.NA], dtype="boolean"),
127+
),
128+
],
129+
)
130+
def test_masked_mixed_types(dtype1, dtype2, exp_col1, exp_col2):
131+
# GH#37506
132+
data = [1.0, np.nan]
133+
df = DataFrame(
134+
{"col1": pd.array(data, dtype=dtype1), "col2": pd.array(data, dtype=dtype2)}
135+
)
136+
result = df.groupby([1, 1]).agg("all", skipna=False)
137+
138+
expected = DataFrame({"col1": exp_col1, "col2": exp_col2}, index=[1])
139+
tm.assert_frame_equal(result, expected)
140+
141+
142+
@pytest.mark.parametrize("bool_agg_func", ["any", "all"])
143+
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
144+
@pytest.mark.parametrize("skipna", [True, False])
145+
def test_masked_bool_aggs_skipna(bool_agg_func, dtype, skipna, frame_or_series):
146+
# GH#40585
147+
obj = frame_or_series([pd.NA, 1], dtype=dtype)
148+
expected_res = True
149+
if not skipna and bool_agg_func == "all":
150+
expected_res = pd.NA
151+
expected = frame_or_series([expected_res], index=[1], dtype="boolean")
152+
153+
result = obj.groupby([1, 1]).agg(bool_agg_func, skipna=skipna)
154+
tm.assert_equal(result, expected)

pandas/tests/reductions/test_reductions.py

+32-18
Original file line numberDiff line numberDiff line change
@@ -941,31 +941,45 @@ def test_all_any_params(self):
941941
with pytest.raises(NotImplementedError, match=msg):
942942
s.all(bool_only=True)
943943

944-
def test_all_any_boolean(self):
945-
# Check skipna, with boolean type
946-
s1 = Series([pd.NA, True], dtype="boolean")
947-
s2 = Series([pd.NA, False], dtype="boolean")
948-
assert s1.all(skipna=False) is pd.NA # NA && True => NA
949-
assert s1.all(skipna=True)
950-
assert s2.any(skipna=False) is pd.NA # NA || False => NA
951-
assert not s2.any(skipna=True)
944+
@pytest.mark.parametrize("bool_agg_func", ["any", "all"])
945+
@pytest.mark.parametrize("skipna", [True, False])
946+
@pytest.mark.parametrize(
947+
# expected_data indexed as [[skipna=False/any, skipna=False/all],
948+
# [skipna=True/any, skipna=True/all]]
949+
"data,expected_data",
950+
[
951+
([False, False, False], [[False, False], [False, False]]),
952+
([True, True, True], [[True, True], [True, True]]),
953+
([pd.NA, pd.NA, pd.NA], [[pd.NA, pd.NA], [False, True]]),
954+
([False, pd.NA, False], [[pd.NA, False], [False, False]]),
955+
([True, pd.NA, True], [[True, pd.NA], [True, True]]),
956+
([True, pd.NA, False], [[True, False], [True, False]]),
957+
],
958+
)
959+
def test_any_all_boolean_kleene_logic(
960+
self, bool_agg_func, skipna, data, expected_data
961+
):
962+
ser = Series(data, dtype="boolean")
963+
expected = expected_data[skipna][bool_agg_func == "all"]
952964

953-
# GH-33253: all True / all False values buggy with skipna=False
954-
s3 = Series([True, True], dtype="boolean")
955-
s4 = Series([False, False], dtype="boolean")
956-
assert s3.all(skipna=False)
957-
assert not s4.any(skipna=False)
965+
result = getattr(ser, bool_agg_func)(skipna=skipna)
966+
assert (result is pd.NA and expected is pd.NA) or result == expected
958967

959-
# Check level TODO(GH-33449) result should also be boolean
960-
s = Series(
968+
@pytest.mark.parametrize(
969+
"bool_agg_func,expected",
970+
[("all", [False, True, False]), ("any", [False, True, True])],
971+
)
972+
def test_any_all_boolean_level(self, bool_agg_func, expected):
973+
# GH#33449
974+
ser = Series(
961975
[False, False, True, True, False, True],
962976
index=[0, 0, 1, 1, 2, 2],
963977
dtype="boolean",
964978
)
965979
with tm.assert_produces_warning(FutureWarning):
966-
tm.assert_series_equal(s.all(level=0), Series([False, True, False]))
967-
with tm.assert_produces_warning(FutureWarning):
968-
tm.assert_series_equal(s.any(level=0), Series([False, True, True]))
980+
result = getattr(ser, bool_agg_func)(level=0)
981+
expected = Series(expected, dtype="boolean")
982+
tm.assert_series_equal(result, expected)
969983

970984
def test_any_axis1_bool_only(self):
971985
# GH#32432

0 commit comments

Comments
 (0)