Skip to content

Commit ac69522

Browse files
API/TST: expand tests for string any/all reduction + fix pyarrow-based implementation (#59414)
1 parent fe8a947 commit ac69522

File tree

2 files changed

+47
-10
lines changed

2 files changed

+47
-10
lines changed

pandas/core/arrays/string_arrow.py

+3-3
Original file line number | Diff line number | Diff line change
@@ -697,9 +697,9 @@ def _reduce(
697697
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
698698
):
699699
if name in ["any", "all"]:
700-
if not skipna and name == "all":
701-
nas = pc.invert(pc.is_null(self._pa_array))
702-
arr = pc.and_kleene(nas, pc.not_equal(self._pa_array, ""))
700+
if not skipna:
701+
nas = pc.is_null(self._pa_array)
702+
arr = pc.or_kleene(nas, pc.not_equal(self._pa_array, ""))
703703
else:
704704
arr = pc.not_equal(self._pa_array, "")
705705
return ArrowExtensionArray(arr)._reduce(

pandas/tests/reductions/test_reductions.py

+44-7
Original file line number | Diff line number | Diff line change
@@ -1062,25 +1062,62 @@ def test_any_all_datetimelike(self):
10621062
assert df.any().all()
10631063
assert not df.all().any()
10641064

1065-
def test_any_all_pyarrow_string(self):
1065+
def test_any_all_string_dtype(self, any_string_dtype):
10661066
# GH#54591
1067-
pytest.importorskip("pyarrow")
1068-
ser = Series(["", "a"], dtype="string[pyarrow_numpy]")
1067+
if (
1068+
isinstance(any_string_dtype, pd.StringDtype)
1069+
and any_string_dtype.na_value is pd.NA
1070+
):
1071+
# the nullable string dtype currently still raises an error
1072+
# https://github.com/pandas-dev/pandas/issues/51939
1073+
ser = Series(["a", "b"], dtype=any_string_dtype)
1074+
with pytest.raises(TypeError):
1075+
ser.any()
1076+
with pytest.raises(TypeError):
1077+
ser.all()
1078+
return
1079+
1080+
ser = Series(["", "a"], dtype=any_string_dtype)
10691081
assert ser.any()
10701082
assert not ser.all()
1083+
assert ser.any(skipna=False)
1084+
assert not ser.all(skipna=False)
10711085

1072-
ser = Series([None, "a"], dtype="string[pyarrow_numpy]")
1086+
ser = Series([np.nan, "a"], dtype=any_string_dtype)
10731087
assert ser.any()
10741088
assert ser.all()
1075-
assert not ser.all(skipna=False)
1089+
assert ser.any(skipna=False)
1090+
assert ser.all(skipna=False) # NaN is considered truthy
10761091

1077-
ser = Series([None, ""], dtype="string[pyarrow_numpy]")
1092+
ser = Series([np.nan, ""], dtype=any_string_dtype)
10781093
assert not ser.any()
10791094
assert not ser.all()
1095+
assert ser.any(skipna=False) # NaN is considered truthy
1096+
assert not ser.all(skipna=False)
10801097

1081-
ser = Series(["a", "b"], dtype="string[pyarrow_numpy]")
1098+
ser = Series(["a", "b"], dtype=any_string_dtype)
10821099
assert ser.any()
10831100
assert ser.all()
1101+
assert ser.any(skipna=False)
1102+
assert ser.all(skipna=False)
1103+
1104+
ser = Series([], dtype=any_string_dtype)
1105+
assert not ser.any()
1106+
assert ser.all()
1107+
assert not ser.any(skipna=False)
1108+
assert ser.all(skipna=False)
1109+
1110+
ser = Series([""], dtype=any_string_dtype)
1111+
assert not ser.any()
1112+
assert not ser.all()
1113+
assert not ser.any(skipna=False)
1114+
assert not ser.all(skipna=False)
1115+
1116+
ser = Series([np.nan], dtype=any_string_dtype)
1117+
assert not ser.any()
1118+
assert ser.all()
1119+
assert ser.any(skipna=False) # NaN is considered truthy
1120+
assert ser.all(skipna=False) # NaN is considered truthy
10841121

10851122
def test_timedelta64_analytics(self):
10861123
# index min/max

0 commit comments

Comments
 (0)