Skip to content

Commit 079ef5c

Browse files
mzeitlin11 authored and JulianWgs committed
ENH/BUG: Use Kleene logic for groupby any/all (pandas-dev#40819)
1 parent c1068f3 commit 079ef5c

File tree

6 files changed

+188
-33
lines changed

6 files changed

+188
-33
lines changed

asv_bench/benchmarks/groupby.py

+13-1
Original file line number | Diff line number | Diff line change
@@ -480,7 +480,19 @@ class GroupByCythonAgg:
480480
param_names = ["dtype", "method"]
481481
params = [
482482
["float64"],
483-
["sum", "prod", "min", "max", "mean", "median", "var", "first", "last"],
483+
[
484+
"sum",
485+
"prod",
486+
"min",
487+
"max",
488+
"mean",
489+
"median",
490+
"var",
491+
"first",
492+
"last",
493+
"any",
494+
"all",
495+
],
484496
]
485497

486498
def setup(self, dtype, method):

doc/source/whatsnew/v1.3.0.rst

+5
Original file line number | Diff line number | Diff line change
@@ -217,6 +217,9 @@ Other enhancements
217217
- :class:`RangeIndex` can now be constructed by passing a ``range`` object directly e.g. ``pd.RangeIndex(range(3))`` (:issue:`12067`)
218218
- :meth:`round` being enabled for the nullable integer and floating dtypes (:issue:`38844`)
219219
- :meth:`pandas.read_csv` and :meth:`pandas.read_json` expose the argument ``encoding_errors`` to control how encoding errors are handled (:issue:`39450`)
220+
- :meth:`.GroupBy.any` and :meth:`.GroupBy.all` use Kleene logic with nullable data types (:issue:`37506`)
221+
- :meth:`.GroupBy.any` and :meth:`.GroupBy.all` return a ``BooleanDtype`` for columns with nullable data types (:issue:`33449`)
222+
-
220223

221224
.. ---------------------------------------------------------------------------
222225
@@ -787,6 +790,8 @@ Groupby/resample/rolling
787790
- Bug in :meth:`Series.asfreq` and :meth:`DataFrame.asfreq` dropping rows when the index is not sorted (:issue:`39805`)
788791
- Bug in aggregation functions for :class:`DataFrame` not respecting ``numeric_only`` argument when ``level`` keyword was given (:issue:`40660`)
789792
- Bug in :class:`core.window.RollingGroupby` where ``as_index=False`` argument in ``groupby`` was ignored (:issue:`39433`)
793+
- Bug in :meth:`.GroupBy.any` and :meth:`.GroupBy.all` raising ``ValueError`` when used with nullable type columns holding ``NA`` even with ``skipna=True`` (:issue:`40585`)
794+
790795

791796
Reshaping
792797
^^^^^^^^^

pandas/_libs/groupby.pyx

+25-8
Original file line number | Diff line number | Diff line change
@@ -388,40 +388,47 @@ def group_fillna_indexer(ndarray[int64_t] out, ndarray[intp_t] labels,
388388

389389
@cython.boundscheck(False)
390390
@cython.wraparound(False)
391-
def group_any_all(uint8_t[::1] out,
392-
const uint8_t[::1] values,
391+
def group_any_all(int8_t[::1] out,
392+
const int8_t[::1] values,
393393
const intp_t[:] labels,
394394
const uint8_t[::1] mask,
395395
str val_test,
396-
bint skipna) -> None:
396+
bint skipna,
397+
bint nullable) -> None:
397398
"""
398-
Aggregated boolean values to show truthfulness of group elements.
399+
Aggregated boolean values to show truthfulness of group elements. If the
400+
input is a nullable type (nullable=True), the result will be computed
401+
using Kleene logic.
399402

400403
Parameters
401404
----------
402-
out : np.ndarray[np.uint8]
405+
out : np.ndarray[np.int8]
403406
Values into which this method will write its results.
404407
labels : np.ndarray[np.intp]
405408
Array containing unique label for each group, with its
406409
ordering matching up to the corresponding record in `values`
407-
values : np.ndarray[np.uint8]
410+
values : np.ndarray[np.int8]
408411
Containing the truth value of each element.
409412
mask : np.ndarray[np.uint8]
410413
Indicating whether a value is na or not.
411414
val_test : {'any', 'all'}
412415
String object dictating whether to use any or all truth testing
413416
skipna : bool
414417
Flag to ignore nan values during truth testing
418+
nullable : bool
419+
Whether or not the input is a nullable type. If True, the
420+
result will be computed using Kleene logic
415421

416422
Notes
417423
-----
418424
This method modifies the `out` parameter rather than returning an object.
419-
The returned values will either be 0 or 1 (False or True, respectively).
425+
The returned values will either be 0, 1 (False or True, respectively), or
426+
-1 to signify a masked position in the case of a nullable input.
420427
"""
421428
cdef:
422429
Py_ssize_t i, N = len(labels)
423430
intp_t lab
424-
uint8_t flag_val
431+
int8_t flag_val
425432

426433
if val_test == 'all':
427434
# Because the 'all' value of an empty iterable in Python is True we can
@@ -444,6 +451,16 @@ def group_any_all(uint8_t[::1] out,
444451
if lab < 0 or (skipna and mask[i]):
445452
continue
446453

454+
if nullable and mask[i]:
455+
# Set the position as masked if `out[lab] != flag_val`, which
456+
# would indicate True/False has not yet been seen for any/all,
457+
# so by Kleene logic the result is currently unknown
458+
if out[lab] != flag_val:
459+
out[lab] = -1
460+
continue
461+
462+
# If True and 'any' or False and 'all', the result is
463+
# already determined
447464
if values[i] == flag_val:
448465
out[lab] = flag_val
449466

pandas/core/groupby/groupby.py

+29-6
Original file line number | Diff line number | Diff line change
@@ -77,6 +77,8 @@ class providing the base-class of operations.
7777
from pandas.core import nanops
7878
import pandas.core.algorithms as algorithms
7979
from pandas.core.arrays import (
80+
BaseMaskedArray,
81+
BooleanArray,
8082
Categorical,
8183
ExtensionArray,
8284
)
@@ -1418,24 +1420,34 @@ def _bool_agg(self, val_test, skipna):
14181420
Shared func to call any / all Cython GroupBy implementations.
14191421
"""
14201422

1421-
def objs_to_bool(vals: np.ndarray) -> tuple[np.ndarray, type]:
1423+
def objs_to_bool(vals: ArrayLike) -> tuple[np.ndarray, type]:
14221424
if is_object_dtype(vals):
14231425
vals = np.array([bool(x) for x in vals])
1426+
elif isinstance(vals, BaseMaskedArray):
1427+
vals = vals._data.astype(bool, copy=False)
14241428
else:
14251429
vals = vals.astype(bool)
14261430

1427-
return vals.view(np.uint8), bool
1431+
return vals.view(np.int8), bool
14281432

1429-
def result_to_bool(result: np.ndarray, inference: type) -> np.ndarray:
1430-
return result.astype(inference, copy=False)
1433+
def result_to_bool(
1434+
result: np.ndarray,
1435+
inference: type,
1436+
nullable: bool = False,
1437+
) -> ArrayLike:
1438+
if nullable:
1439+
return BooleanArray(result.astype(bool, copy=False), result == -1)
1440+
else:
1441+
return result.astype(inference, copy=False)
14311442

14321443
return self._get_cythonized_result(
14331444
"group_any_all",
14341445
aggregate=True,
14351446
numeric_only=False,
1436-
cython_dtype=np.dtype(np.uint8),
1447+
cython_dtype=np.dtype(np.int8),
14371448
needs_values=True,
14381449
needs_mask=True,
1450+
needs_nullable=True,
14391451
pre_processing=objs_to_bool,
14401452
post_processing=result_to_bool,
14411453
val_test=val_test,
@@ -2618,6 +2630,7 @@ def _get_cythonized_result(
26182630
needs_counts: bool = False,
26192631
needs_values: bool = False,
26202632
needs_2d: bool = False,
2633+
needs_nullable: bool = False,
26212634
min_count: int | None = None,
26222635
needs_mask: bool = False,
26232636
needs_ngroups: bool = False,
@@ -2654,6 +2667,9 @@ def _get_cythonized_result(
26542667
signature
26552668
needs_ngroups : bool, default False
26562669
Whether number of groups is part of the Cython call signature
2670+
needs_nullable : bool, default False
2671+
Whether a bool specifying if the input is nullable is part
2672+
of the Cython call signature
26572673
result_is_index : bool, default False
26582674
Whether the result of the Cython operation is an index of
26592675
values to be retrieved, instead of the actual values themselves
@@ -2669,7 +2685,8 @@ def _get_cythonized_result(
26692685
Function to be applied to result of Cython function. Should accept
26702686
an array of values as the first argument and type inferences as its
26712687
second argument, i.e. the signature should be
2672-
(ndarray, Type).
2688+
(ndarray, Type). If `needs_nullable=True`, a third argument should be
2689+
`nullable`, to allow for processing specific to nullable values.
26732690
**kwargs : dict
26742691
Extra arguments to be passed back to Cython funcs
26752692
@@ -2744,6 +2761,12 @@ def _get_cythonized_result(
27442761
if needs_ngroups:
27452762
func = partial(func, ngroups)
27462763

2764+
if needs_nullable:
2765+
is_nullable = isinstance(values, BaseMaskedArray)
2766+
func = partial(func, nullable=is_nullable)
2767+
if post_processing:
2768+
post_processing = partial(post_processing, nullable=is_nullable)
2769+
27472770
func(**kwargs) # Call func to modify indexer values in place
27482771

27492772
if needs_2d:

pandas/tests/groupby/test_any_all.py

+84
Original file line number | Diff line number | Diff line change
@@ -3,9 +3,11 @@
33
import numpy as np
44
import pytest
55

6+
import pandas as pd
67
from pandas import (
78
DataFrame,
89
Index,
10+
Series,
911
isna,
1012
)
1113
import pandas._testing as tm
@@ -68,3 +70,85 @@ def test_bool_aggs_dup_column_labels(bool_agg_func):
6870

6971
expected = df
7072
tm.assert_frame_equal(result, expected)
73+
74+
75+
@pytest.mark.parametrize("bool_agg_func", ["any", "all"])
76+
@pytest.mark.parametrize("skipna", [True, False])
77+
@pytest.mark.parametrize(
78+
"data",
79+
[
80+
[False, False, False],
81+
[True, True, True],
82+
[pd.NA, pd.NA, pd.NA],
83+
[False, pd.NA, False],
84+
[True, pd.NA, True],
85+
[True, pd.NA, False],
86+
],
87+
)
88+
def test_masked_kleene_logic(bool_agg_func, skipna, data):
89+
# GH#37506
90+
ser = Series(data, dtype="boolean")
91+
92+
# The result should match aggregating on the whole series. Correctness
93+
# there is verified in test_reductions.py::test_any_all_boolean_kleene_logic
94+
expected_data = getattr(ser, bool_agg_func)(skipna=skipna)
95+
expected = Series(expected_data, dtype="boolean")
96+
97+
result = ser.groupby([0, 0, 0]).agg(bool_agg_func, skipna=skipna)
98+
tm.assert_series_equal(result, expected)
99+
100+
101+
@pytest.mark.parametrize(
102+
"dtype1,dtype2,exp_col1,exp_col2",
103+
[
104+
(
105+
"float",
106+
"Float64",
107+
np.array([True], dtype=bool),
108+
pd.array([pd.NA], dtype="boolean"),
109+
),
110+
(
111+
"Int64",
112+
"float",
113+
pd.array([pd.NA], dtype="boolean"),
114+
np.array([True], dtype=bool),
115+
),
116+
(
117+
"Int64",
118+
"Int64",
119+
pd.array([pd.NA], dtype="boolean"),
120+
pd.array([pd.NA], dtype="boolean"),
121+
),
122+
(
123+
"Float64",
124+
"boolean",
125+
pd.array([pd.NA], dtype="boolean"),
126+
pd.array([pd.NA], dtype="boolean"),
127+
),
128+
],
129+
)
130+
def test_masked_mixed_types(dtype1, dtype2, exp_col1, exp_col2):
131+
# GH#37506
132+
data = [1.0, np.nan]
133+
df = DataFrame(
134+
{"col1": pd.array(data, dtype=dtype1), "col2": pd.array(data, dtype=dtype2)}
135+
)
136+
result = df.groupby([1, 1]).agg("all", skipna=False)
137+
138+
expected = DataFrame({"col1": exp_col1, "col2": exp_col2}, index=[1])
139+
tm.assert_frame_equal(result, expected)
140+
141+
142+
@pytest.mark.parametrize("bool_agg_func", ["any", "all"])
143+
@pytest.mark.parametrize("dtype", ["Int64", "Float64", "boolean"])
144+
@pytest.mark.parametrize("skipna", [True, False])
145+
def test_masked_bool_aggs_skipna(bool_agg_func, dtype, skipna, frame_or_series):
146+
# GH#40585
147+
obj = frame_or_series([pd.NA, 1], dtype=dtype)
148+
expected_res = True
149+
if not skipna and bool_agg_func == "all":
150+
expected_res = pd.NA
151+
expected = frame_or_series([expected_res], index=[1], dtype="boolean")
152+
153+
result = obj.groupby([1, 1]).agg(bool_agg_func, skipna=skipna)
154+
tm.assert_equal(result, expected)

pandas/tests/reductions/test_reductions.py

+32-18
Original file line number | Diff line number | Diff line change
@@ -941,31 +941,45 @@ def test_all_any_params(self):
941941
with pytest.raises(NotImplementedError, match=msg):
942942
s.all(bool_only=True)
943943

944-
def test_all_any_boolean(self):
945-
# Check skipna, with boolean type
946-
s1 = Series([pd.NA, True], dtype="boolean")
947-
s2 = Series([pd.NA, False], dtype="boolean")
948-
assert s1.all(skipna=False) is pd.NA # NA && True => NA
949-
assert s1.all(skipna=True)
950-
assert s2.any(skipna=False) is pd.NA # NA || False => NA
951-
assert not s2.any(skipna=True)
944+
@pytest.mark.parametrize("bool_agg_func", ["any", "all"])
945+
@pytest.mark.parametrize("skipna", [True, False])
946+
@pytest.mark.parametrize(
947+
# expected_data indexed as [[skipna=False/any, skipna=False/all],
948+
# [skipna=True/any, skipna=True/all]]
949+
"data,expected_data",
950+
[
951+
([False, False, False], [[False, False], [False, False]]),
952+
([True, True, True], [[True, True], [True, True]]),
953+
([pd.NA, pd.NA, pd.NA], [[pd.NA, pd.NA], [False, True]]),
954+
([False, pd.NA, False], [[pd.NA, False], [False, False]]),
955+
([True, pd.NA, True], [[True, pd.NA], [True, True]]),
956+
([True, pd.NA, False], [[True, False], [True, False]]),
957+
],
958+
)
959+
def test_any_all_boolean_kleene_logic(
960+
self, bool_agg_func, skipna, data, expected_data
961+
):
962+
ser = Series(data, dtype="boolean")
963+
expected = expected_data[skipna][bool_agg_func == "all"]
952964

953-
# GH-33253: all True / all False values buggy with skipna=False
954-
s3 = Series([True, True], dtype="boolean")
955-
s4 = Series([False, False], dtype="boolean")
956-
assert s3.all(skipna=False)
957-
assert not s4.any(skipna=False)
965+
result = getattr(ser, bool_agg_func)(skipna=skipna)
966+
assert (result is pd.NA and expected is pd.NA) or result == expected
958967

959-
# Check level TODO(GH-33449) result should also be boolean
960-
s = Series(
968+
@pytest.mark.parametrize(
969+
"bool_agg_func,expected",
970+
[("all", [False, True, False]), ("any", [False, True, True])],
971+
)
972+
def test_any_all_boolean_level(self, bool_agg_func, expected):
973+
# GH#33449
974+
ser = Series(
961975
[False, False, True, True, False, True],
962976
index=[0, 0, 1, 1, 2, 2],
963977
dtype="boolean",
964978
)
965979
with tm.assert_produces_warning(FutureWarning):
966-
tm.assert_series_equal(s.all(level=0), Series([False, True, False]))
967-
with tm.assert_produces_warning(FutureWarning):
968-
tm.assert_series_equal(s.any(level=0), Series([False, True, True]))
980+
result = getattr(ser, bool_agg_func)(level=0)
981+
expected = Series(expected, dtype="boolean")
982+
tm.assert_series_equal(result, expected)
969983

970984
def test_any_axis1_bool_only(self):
971985
# GH#32432

0 commit comments

Comments
 (0)