Skip to content

Commit 5b25df2

Browse files
API: Return BoolArray for string ops when backed by StringArray (#30239)
* API: Return BoolArray for string ops
1 parent 53a0dfd commit 5b25df2

File tree

3 files changed

+47
-13
lines changed

3 files changed

+47
-13
lines changed

doc/source/user_guide/text.rst

+8-1
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ These are places where the behavior of ``StringDtype`` objects differ from
7474
1. For ``StringDtype``, :ref:`string accessor methods<api.series.str>`
7575
that return **numeric** output will always return a nullable integer dtype,
7676
rather than either int or float dtype, depending on the presence of NA values.
77+
Methods returning **boolean** output will return a nullable boolean dtype.
7778

7879
.. ipython:: python
7980
@@ -89,7 +90,13 @@ l. For ``StringDtype``, :ref:`string accessor methods<api.series.str>`
8990
s.astype(object).str.count("a")
9091
s.astype(object).dropna().str.count("a")
9192
92-
When NA values are present, the output dtype is float64.
93+
When NA values are present, the output dtype is float64. Similarly for
94+
methods returning boolean values: on ``StringDtype`` they return the nullable boolean dtype.
95+
96+
.. ipython:: python
97+
98+
s.str.isdigit()
99+
s.str.match("a")
93100
94101
2. Some string methods, like :meth:`Series.str.decode` are not available
95102
on ``StringArray`` because ``StringArray`` only holds strings, not

pandas/core/strings.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from functools import wraps
33
import re
44
import textwrap
5-
from typing import TYPE_CHECKING, Any, Callable, Dict, List
5+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Type, Union
66
import warnings
77

88
import numpy as np
@@ -142,7 +142,7 @@ def _map_stringarray(
142142
The value to use for missing values. By default, this is
143143
the original value (NA).
144144
dtype : Dtype
145-
The result dtype to use. Specifying this aviods an intermediate
145+
The result dtype to use. Specifying this avoids an intermediate
146146
object-dtype allocation.
147147
148148
Returns
@@ -152,14 +152,20 @@ def _map_stringarray(
152152
an ndarray.
153153
154154
"""
155-
from pandas.arrays import IntegerArray, StringArray
155+
from pandas.arrays import IntegerArray, StringArray, BooleanArray
156156

157157
mask = isna(arr)
158158

159159
assert isinstance(arr, StringArray)
160160
arr = np.asarray(arr)
161161

162-
if is_integer_dtype(dtype):
162+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
163+
constructor: Union[Type[IntegerArray], Type[BooleanArray]]
164+
if is_integer_dtype(dtype):
165+
constructor = IntegerArray
166+
else:
167+
constructor = BooleanArray
168+
163169
na_value_is_na = isna(na_value)
164170
if na_value_is_na:
165171
na_value = 1
@@ -169,21 +175,20 @@ def _map_stringarray(
169175
mask.view("uint8"),
170176
convert=False,
171177
na_value=na_value,
172-
dtype=np.dtype("int64"),
178+
dtype=np.dtype(dtype),
173179
)
174180

175181
if not na_value_is_na:
176182
mask[:] = False
177183

178-
return IntegerArray(result, mask)
184+
return constructor(result, mask)
179185

180186
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
181187
# i.e. StringDtype
182188
result = lib.map_infer_mask(
183189
arr, func, mask.view("uint8"), convert=False, na_value=na_value
184190
)
185191
return StringArray(result)
186-
# TODO: BooleanArray
187192
else:
188193
# This is when the result type is object. We reach this when
189194
# -> We know the result type is truly object (e.g. .encode returns bytes
@@ -299,7 +304,7 @@ def str_count(arr, pat, flags=0):
299304
"""
300305
regex = re.compile(pat, flags=flags)
301306
f = lambda x: len(regex.findall(x))
302-
return _na_map(f, arr, dtype=int)
307+
return _na_map(f, arr, dtype="int64")
303308

304309

305310
def str_contains(arr, pat, case=True, flags=0, na=np.nan, regex=True):
@@ -1365,7 +1370,7 @@ def str_find(arr, sub, start=0, end=None, side="left"):
13651370
else:
13661371
f = lambda x: getattr(x, method)(sub, start, end)
13671372

1368-
return _na_map(f, arr, dtype=int)
1373+
return _na_map(f, arr, dtype="int64")
13691374

13701375

13711376
def str_index(arr, sub, start=0, end=None, side="left"):
@@ -1385,7 +1390,7 @@ def str_index(arr, sub, start=0, end=None, side="left"):
13851390
else:
13861391
f = lambda x: getattr(x, method)(sub, start, end)
13871392

1388-
return _na_map(f, arr, dtype=int)
1393+
return _na_map(f, arr, dtype="int64")
13891394

13901395

13911396
def str_pad(arr, width, side="left", fillchar=" "):
@@ -3210,7 +3215,7 @@ def rindex(self, sub, start=0, end=None):
32103215
len,
32113216
docstring=_shared_docs["len"],
32123217
forbidden_types=None,
3213-
dtype=int,
3218+
dtype="int64",
32143219
returns_string=False,
32153220
)
32163221

pandas/tests/test_strings.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -1825,7 +1825,7 @@ def test_extractall_same_as_extract_subject_index(self):
18251825

18261826
def test_empty_str_methods(self):
18271827
empty_str = empty = Series(dtype=object)
1828-
empty_int = Series(dtype=int)
1828+
empty_int = Series(dtype="int64")
18291829
empty_bool = Series(dtype=bool)
18301830
empty_bytes = Series(dtype=object)
18311831

@@ -3526,6 +3526,12 @@ def test_string_array(any_string_method):
35263526
assert result.dtype == "string"
35273527
result = result.astype(object)
35283528

3529+
elif expected.dtype == "object" and lib.is_bool_array(
3530+
expected.values, skipna=True
3531+
):
3532+
assert result.dtype == "boolean"
3533+
result = result.astype(object)
3534+
35293535
elif expected.dtype == "float" and expected.isna().any():
35303536
assert result.dtype == "Int64"
35313537
result = result.astype("float")
@@ -3551,3 +3557,19 @@ def test_string_array_numeric_integer_array(method, expected):
35513557
result = getattr(s.str, method)("a")
35523558
expected = Series(expected, dtype="Int64")
35533559
tm.assert_series_equal(result, expected)
3560+
3561+
3562+
@pytest.mark.parametrize(
3563+
"method,expected",
3564+
[
3565+
("isdigit", [False, None, True]),
3566+
("isalpha", [True, None, False]),
3567+
("isalnum", [True, None, True]),
3568+
("isdigit", [False, None, True]),
3569+
],
3570+
)
3571+
def test_string_array_boolean_array(method, expected):
3572+
s = Series(["a", None, "1"], dtype="string")
3573+
result = getattr(s.str, method)()
3574+
expected = Series(expected, dtype="boolean")
3575+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)