Skip to content

Commit 2bb3557

Browse files
authored
REF/TST: handle boolean dtypes in base extension tests (#54334)
1 parent b91d7f0 commit 2bb3557

File tree

5 files changed

+91
-214
lines changed

5 files changed

+91
-214
lines changed

pandas/tests/extension/base/groupby.py

+32-3
Original file line numberDiff line numberDiff line change
@@ -30,15 +30,25 @@ def test_grouping_grouper(self, data_for_grouping):
3030
@pytest.mark.parametrize("as_index", [True, False])
3131
def test_groupby_extension_agg(self, as_index, data_for_grouping):
3232
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
33+
34+
is_bool = data_for_grouping.dtype._is_boolean
35+
if is_bool:
36+
# boolean dtypes have only 2 unique non-NA values, so the final entry is a repeat of B
37+
# (C == B, see the data_for_grouping docstring); drop that row
38+
df = df.iloc[:-1]
39+
3340
result = df.groupby("B", as_index=as_index).A.mean()
3441
_, uniques = pd.factorize(data_for_grouping, sort=True)
3542

43+
exp_vals = [3.0, 1.0, 4.0]
44+
if is_bool:
45+
exp_vals = exp_vals[:-1]
3646
if as_index:
3747
index = pd.Index(uniques, name="B")
38-
expected = pd.Series([3.0, 1.0, 4.0], index=index, name="A")
48+
expected = pd.Series(exp_vals, index=index, name="A")
3949
self.assert_series_equal(result, expected)
4050
else:
41-
expected = pd.DataFrame({"B": uniques, "A": [3.0, 1.0, 4.0]})
51+
expected = pd.DataFrame({"B": uniques, "A": exp_vals})
4252
self.assert_frame_equal(result, expected)
4353

4454
def test_groupby_agg_extension(self, data_for_grouping):
@@ -83,19 +93,38 @@ def test_groupby_agg_extension_timedelta_cumsum_with_named_aggregation(self):
8393

8494
def test_groupby_extension_no_sort(self, data_for_grouping):
8595
df = pd.DataFrame({"A": [1, 1, 2, 2, 3, 3, 1, 4], "B": data_for_grouping})
96+
97+
is_bool = data_for_grouping.dtype._is_boolean
98+
if is_bool:
99+
# boolean dtypes have only 2 unique non-NA values, so the final entry is a repeat of B
100+
# (C == B, see the data_for_grouping docstring); drop that row
101+
df = df.iloc[:-1]
102+
86103
result = df.groupby("B", sort=False).A.mean()
87104
_, index = pd.factorize(data_for_grouping, sort=False)
88105

89106
index = pd.Index(index, name="B")
90-
expected = pd.Series([1.0, 3.0, 4.0], index=index, name="A")
107+
exp_vals = [1.0, 3.0, 4.0]
108+
if is_bool:
109+
exp_vals = exp_vals[:-1]
110+
expected = pd.Series(exp_vals, index=index, name="A")
91111
self.assert_series_equal(result, expected)
92112

93113
def test_groupby_extension_transform(self, data_for_grouping):
114+
is_bool = data_for_grouping.dtype._is_boolean
115+
94116
valid = data_for_grouping[~data_for_grouping.isna()]
95117
df = pd.DataFrame({"A": [1, 1, 3, 3, 1, 4], "B": valid})
118+
is_bool = data_for_grouping.dtype._is_boolean
119+
if is_bool:
120+
# boolean dtypes have only 2 unique non-NA values, so the final entry is a repeat of B
121+
# (C == B, see the data_for_grouping docstring); drop that row
122+
df = df.iloc[:-1]
96123

97124
result = df.groupby("B").A.transform(len)
98125
expected = pd.Series([3, 3, 2, 2, 3, 1], name="A")
126+
if is_bool:
127+
expected = expected[:-1]
99128

100129
self.assert_series_equal(result, expected)
101130

pandas/tests/extension/base/methods.py

+48-4
Original file line numberDiff line numberDiff line change
@@ -115,14 +115,22 @@ def test_argsort_missing(self, data_missing_for_sorting):
115115

116116
def test_argmin_argmax(self, data_for_sorting, data_missing_for_sorting, na_value):
117117
# GH 24382
118+
is_bool = data_for_sorting.dtype._is_boolean
119+
120+
exp_argmax = 1
121+
exp_argmax_repeated = 3
122+
if is_bool:
123+
# B == C for boolean dtypes (see data_for_sorting docstring), so the first max is at position 0
124+
exp_argmax = 0
125+
exp_argmax_repeated = 1
118126

119127
# data_for_sorting -> [B, C, A] with A < B < C (B == C for boolean dtypes)
120-
assert data_for_sorting.argmax() == 1
128+
assert data_for_sorting.argmax() == exp_argmax
121129
assert data_for_sorting.argmin() == 2
122130

123131
# with repeated values -> first occurrence
124132
data = data_for_sorting.take([2, 0, 0, 1, 1, 2])
125-
assert data.argmax() == 3
133+
assert data.argmax() == exp_argmax_repeated
126134
assert data.argmin() == 0
127135

128136
# with missing values
@@ -244,8 +252,15 @@ def test_unique(self, data, box, method):
244252

245253
def test_factorize(self, data_for_grouping):
246254
codes, uniques = pd.factorize(data_for_grouping, use_na_sentinel=True)
247-
expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 2], dtype=np.intp)
248-
expected_uniques = data_for_grouping.take([0, 4, 7])
255+
256+
is_bool = data_for_grouping.dtype._is_boolean
257+
if is_bool:
258+
# only 2 unique values, so the final entry (C == B) shares code 0 with B
259+
expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 0], dtype=np.intp)
260+
expected_uniques = data_for_grouping.take([0, 4])
261+
else:
262+
expected_codes = np.array([0, 0, -1, -1, 1, 1, 0, 2], dtype=np.intp)
263+
expected_uniques = data_for_grouping.take([0, 4, 7])
249264

250265
tm.assert_numpy_array_equal(codes, expected_codes)
251266
self.assert_extension_array_equal(uniques, expected_uniques)
@@ -457,6 +472,9 @@ def test_hash_pandas_object_works(self, data, as_frame):
457472
self.assert_equal(a, b)
458473

459474
def test_searchsorted(self, data_for_sorting, as_series):
475+
if data_for_sorting.dtype._is_boolean:
476+
return self._test_searchsorted_bool_dtypes(data_for_sorting, as_series)
477+
460478
b, c, a = data_for_sorting
461479
arr = data_for_sorting.take([2, 0, 1]) # to get [a, b, c]
462480

@@ -480,6 +498,32 @@ def test_searchsorted(self, data_for_sorting, as_series):
480498
sorter = np.array([1, 2, 0])
481499
assert data_for_sorting.searchsorted(a, sorter=sorter) == 0
482500

501+
def _test_searchsorted_bool_dtypes(self, data_for_sorting, as_series):
502+
# We call this from test_searchsorted in cases where we have a
503+
# boolean-like dtype. The non-bool test assumes we have more than 2
504+
# unique values.
505+
dtype = data_for_sorting.dtype
506+
data_for_sorting = pd.array([True, False], dtype=dtype)
507+
b, a = data_for_sorting
508+
arr = type(data_for_sorting)._from_sequence([a, b])
509+
510+
if as_series:
511+
arr = pd.Series(arr)
512+
assert arr.searchsorted(a) == 0
513+
assert arr.searchsorted(a, side="right") == 1
514+
515+
assert arr.searchsorted(b) == 1
516+
assert arr.searchsorted(b, side="right") == 2
517+
518+
result = arr.searchsorted(arr.take([0, 1]))
519+
expected = np.array([0, 1], dtype=np.intp)
520+
521+
tm.assert_numpy_array_equal(result, expected)
522+
523+
# sorter
524+
sorter = np.array([1, 0])
525+
assert data_for_sorting.searchsorted(a, sorter=sorter) == 0
526+
483527
def test_where_series(self, data, na_value, as_frame):
484528
assert data[0] != data[1]
485529
cls = type(data)

pandas/tests/extension/conftest.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ def data_for_sorting():
7676
7777
This should be three items [B, C, A] with
7878
A < B < C
79+
80+
For boolean dtypes (which have only 2 non-NA values available),
81+
set B = C = True, so that A < B == C.
7982
"""
8083
raise NotImplementedError
8184

@@ -117,7 +120,10 @@ def data_for_grouping():
117120
118121
Expected to be like [B, B, NA, NA, A, A, B, C]
119122
120-
Where A < B < C and NA is missing
123+
Where A < B < C and NA is missing.
124+
125+
If a dtype has ``_is_boolean = True``, i.e. only 2 unique non-NA values are possible,
126+
then set C = B, giving [B, B, NA, NA, A, A, B, B].
121127
"""
122128
raise NotImplementedError
123129

pandas/tests/extension/test_arrow.py

+1-59
Original file line numberDiff line numberDiff line change
@@ -587,38 +587,6 @@ def test_reduce_series(
587587

588588

589589
class TestBaseGroupby(base.BaseGroupbyTests):
590-
def test_groupby_extension_no_sort(self, data_for_grouping, request):
591-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
592-
if pa.types.is_boolean(pa_dtype):
593-
request.node.add_marker(
594-
pytest.mark.xfail(
595-
reason=f"{pa_dtype} only has 2 unique possible values",
596-
)
597-
)
598-
super().test_groupby_extension_no_sort(data_for_grouping)
599-
600-
def test_groupby_extension_transform(self, data_for_grouping, request):
601-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
602-
if pa.types.is_boolean(pa_dtype):
603-
request.node.add_marker(
604-
pytest.mark.xfail(
605-
reason=f"{pa_dtype} only has 2 unique possible values",
606-
)
607-
)
608-
super().test_groupby_extension_transform(data_for_grouping)
609-
610-
@pytest.mark.parametrize("as_index", [True, False])
611-
def test_groupby_extension_agg(self, as_index, data_for_grouping, request):
612-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
613-
if pa.types.is_boolean(pa_dtype):
614-
request.node.add_marker(
615-
pytest.mark.xfail(
616-
raises=ValueError,
617-
reason=f"{pa_dtype} only has 2 unique possible values",
618-
)
619-
)
620-
super().test_groupby_extension_agg(as_index, data_for_grouping)
621-
622590
def test_in_numeric_groupby(self, data_for_grouping):
623591
dtype = data_for_grouping.dtype
624592
if is_string_dtype(dtype):
@@ -845,13 +813,7 @@ def test_argmin_argmax(
845813
self, data_for_sorting, data_missing_for_sorting, na_value, request
846814
):
847815
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
848-
if pa.types.is_boolean(pa_dtype):
849-
request.node.add_marker(
850-
pytest.mark.xfail(
851-
reason=f"{pa_dtype} only has 2 unique possible values",
852-
)
853-
)
854-
elif pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
816+
if pa.types.is_decimal(pa_dtype) and pa_version_under7p0:
855817
request.node.add_marker(
856818
pytest.mark.xfail(
857819
reason=f"No pyarrow kernel for {pa_dtype}",
@@ -888,16 +850,6 @@ def test_argreduce_series(
888850
data_missing_for_sorting, op_name, skipna, expected
889851
)
890852

891-
def test_factorize(self, data_for_grouping, request):
892-
pa_dtype = data_for_grouping.dtype.pyarrow_dtype
893-
if pa.types.is_boolean(pa_dtype):
894-
request.node.add_marker(
895-
pytest.mark.xfail(
896-
reason=f"{pa_dtype} only has 2 unique possible values",
897-
)
898-
)
899-
super().test_factorize(data_for_grouping)
900-
901853
_combine_le_expected_dtype = "bool[pyarrow]"
902854

903855
def test_combine_add(self, data_repeated, request):
@@ -913,16 +865,6 @@ def test_combine_add(self, data_repeated, request):
913865
else:
914866
super().test_combine_add(data_repeated)
915867

916-
def test_searchsorted(self, data_for_sorting, as_series, request):
917-
pa_dtype = data_for_sorting.dtype.pyarrow_dtype
918-
if pa.types.is_boolean(pa_dtype):
919-
request.node.add_marker(
920-
pytest.mark.xfail(
921-
reason=f"{pa_dtype} only has 2 unique possible values",
922-
)
923-
)
924-
super().test_searchsorted(data_for_sorting, as_series)
925-
926868
def test_basic_equals(self, data):
927869
# https://github.com/pandas-dev/pandas/issues/34660
928870
assert pd.Series(data).equals(pd.Series(data))

0 commit comments

Comments
 (0)