Skip to content

Commit 5c25c14

Browse files
committed
checking dtype of input in pre-processing function and updating tests
1 parent 3def19b commit 5c25c14

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

pandas/core/groupby/groupby.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ class providing the base-class of operations.
3030
)
3131

3232
import numpy as np
33+
from pandas.core.arrays.boolean import BooleanDtype
3334

3435
from pandas._config.config import option_context
3536

@@ -1077,12 +1078,18 @@ def _bool_agg(self, val_test, skipna):
10771078
"""
10781079

10791080
def objs_to_bool(vals: np.ndarray) -> Tuple[np.ndarray, Type]:
1081+
result_dtype = None
1082+
if type(vals.dtype) == BooleanDtype:
1083+
result_dtype = BooleanDtype
1084+
else:
1085+
result_dtype = np.bool
1086+
10801087
if is_object_dtype(vals):
10811088
vals = np.array([bool(x) for x in vals])
10821089
else:
10831090
vals = vals.astype(np.bool)
10841091

1085-
return vals.view(np.uint8), np.bool
1092+
return vals.view(np.uint8), result_dtype
10861093

10871094
def result_to_bool(result: np.ndarray, inference: Type) -> np.ndarray:
10881095
return result.astype(inference, copy=False)
@@ -1867,7 +1874,6 @@ def pre_processor(vals: np.ndarray) -> Tuple[np.ndarray, Optional[Type]]:
18671874
raise TypeError(
18681875
"'quantile' cannot be performed against 'object' dtypes!"
18691876
)
1870-
18711877
inference = None
18721878
if is_integer_dtype(vals.dtype):
18731879
if is_extension_array_dtype(vals.dtype):

pandas/tests/reductions/test_reductions.py

+18-12
Original file line numberDiff line numberDiff line change
@@ -906,19 +906,25 @@ def test_all_any_boolean(self):
906906
assert s3.all(skipna=False)
907907
assert not s4.any(skipna=False)
908908

909-
# Check level TODO(GH-33449) result should also be boolean
910-
s = pd.Series(
911-
[False, False, True, True, False, True],
912-
index=[0, 0, 1, 1, 2, 2],
913-
dtype="boolean",
914-
)
915-
result = s.all(level=0)
916-
expected = Series([False, True, False], dtype="boolean")
917-
tm.assert_series_equal(result, expected)
909+
def test_any_all_bool_dtypes(self):
910+
911+
def test_any_all_on_bool_dtypes_utility(bool_dtype):
912+
# Check level TODO(GH-33449) result should also be boolean
913+
s = pd.Series(
914+
[False, False, True, True, False, True],
915+
index=[0, 0, 1, 1, 2, 2],
916+
dtype=bool_dtype,
917+
)
918+
result = s.all(level=0)
919+
expected = Series([False, True, False], dtype=bool_dtype)
920+
tm.assert_series_equal(result, expected)
918921

919-
result = s.any(level=0)
920-
expected = Series([False, True, True], dtype="boolean")
921-
tm.assert_series_equal(result, expected)
922+
result = s.any(level=0)
923+
expected = Series([False, True, True], dtype=bool_dtype)
924+
tm.assert_series_equal(result, expected)
925+
926+
test_any_all_on_bool_dtypes_utility("boolean")
927+
test_any_all_on_bool_dtypes_utility("bool")
922928

923929
def test_timedelta64_analytics(self):
924930

0 commit comments

Comments
 (0)