Skip to content

Commit 428ae90

Browse files
simonjayhawkinsyeshsurya
authored andcommitted
ENH: [ArrowStringArray] Enable the string methods for the arrow-backed StringArray (pandas-dev#40708)
1 parent 447a670 commit 428ae90

File tree

6 files changed

+115
-25
lines changed

6 files changed

+115
-25
lines changed

pandas/conftest.py

+2
Original file line numberDiff line numberDiff line change
@@ -1144,6 +1144,8 @@ def nullable_string_dtype(request):
11441144
* 'string'
11451145
* 'arrow_string'
11461146
"""
1147+
from pandas.core.arrays.string_arrow import ArrowStringDtype # noqa: F401
1148+
11471149
return request.param
11481150

11491151

pandas/core/arrays/string_arrow.py

+79-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@
2424
from pandas.util._validators import validate_fillna_kwargs
2525

2626
from pandas.core.dtypes.base import ExtensionDtype
27+
from pandas.core.dtypes.common import (
28+
is_object_dtype,
29+
is_string_dtype,
30+
)
2731
from pandas.core.dtypes.dtypes import register_extension_dtype
2832
from pandas.core.dtypes.missing import isna
2933

@@ -41,6 +45,7 @@
4145
check_array_indexer,
4246
validate_indices,
4347
)
48+
from pandas.core.strings.object_array import ObjectStringArrayMixin
4449

4550
try:
4651
import pyarrow as pa
@@ -149,7 +154,12 @@ def __eq__(self, other) -> bool:
149154
return False
150155

151156

152-
class ArrowStringArray(OpsMixin, ExtensionArray):
157+
# TODO: Inherit directly from BaseStringArrayMethods. Currently we inherit from
158+
# ObjectStringArrayMixin because we want to have the object-dtype based methods as
159+
# fallback for the ones that pyarrow doesn't yet support
160+
161+
162+
class ArrowStringArray(OpsMixin, ExtensionArray, ObjectStringArrayMixin):
153163
"""
154164
Extension array for string data in a ``pyarrow.ChunkedArray``.
155165
@@ -676,3 +686,71 @@ def value_counts(self, dropna: bool = True) -> Series:
676686
raise NotImplementedError("yo")
677687

678688
return Series(counts, index=index).astype("Int64")
689+
690+
# ------------------------------------------------------------------------
691+
# String methods interface
692+
693+
_str_na_value = ArrowStringDtype.na_value
694+
695+
def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
696+
# TODO: de-duplicate with StringArray method. This method is moreless copy and
697+
# paste.
698+
699+
from pandas.arrays import (
700+
BooleanArray,
701+
IntegerArray,
702+
)
703+
704+
if dtype is None:
705+
dtype = self.dtype
706+
if na_value is None:
707+
na_value = self.dtype.na_value
708+
709+
mask = isna(self)
710+
arr = np.asarray(self)
711+
712+
if is_integer_dtype(dtype) or is_bool_dtype(dtype):
713+
constructor: type[IntegerArray] | type[BooleanArray]
714+
if is_integer_dtype(dtype):
715+
constructor = IntegerArray
716+
else:
717+
constructor = BooleanArray
718+
719+
na_value_is_na = isna(na_value)
720+
if na_value_is_na:
721+
na_value = 1
722+
result = lib.map_infer_mask(
723+
arr,
724+
f,
725+
mask.view("uint8"),
726+
convert=False,
727+
na_value=na_value,
728+
# error: Value of type variable "_DTypeScalar" of "dtype" cannot be
729+
# "object"
730+
# error: Argument 1 to "dtype" has incompatible type
731+
# "Union[ExtensionDtype, str, dtype[Any], Type[object]]"; expected
732+
# "Type[object]"
733+
dtype=np.dtype(dtype), # type: ignore[type-var,arg-type]
734+
)
735+
736+
if not na_value_is_na:
737+
mask[:] = False
738+
739+
# error: Argument 1 to "IntegerArray" has incompatible type
740+
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
741+
# error: Argument 1 to "BooleanArray" has incompatible type
742+
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
743+
return constructor(result, mask) # type: ignore[arg-type]
744+
745+
elif is_string_dtype(dtype) and not is_object_dtype(dtype):
746+
# i.e. StringDtype
747+
result = lib.map_infer_mask(
748+
arr, f, mask.view("uint8"), convert=False, na_value=na_value
749+
)
750+
return self._from_sequence(result)
751+
else:
752+
# This is when the result type is object. We reach this when
753+
# -> We know the result type is truly object (e.g. .encode returns bytes
754+
# or .findall returns a list).
755+
# -> We don't know the result type. E.g. `.get` can return anything.
756+
return lib.map_infer_mask(arr, f, mask.view("uint8"))

pandas/core/strings/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
# - StringArray
2626
# - PandasArray
2727
# - Categorical
28+
# - ArrowStringArray
2829

2930
from pandas.core.strings.accessor import StringMethods
3031
from pandas.core.strings.base import BaseStringArrayMethods

pandas/core/strings/accessor.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -154,10 +154,11 @@ class StringMethods(NoNewAttributesMixin):
154154

155155
def __init__(self, data):
156156
from pandas.core.arrays.string_ import StringDtype
157+
from pandas.core.arrays.string_arrow import ArrowStringDtype
157158

158159
self._inferred_dtype = self._validate(data)
159160
self._is_categorical = is_categorical_dtype(data.dtype)
160-
self._is_string = isinstance(data.dtype, StringDtype)
161+
self._is_string = isinstance(data.dtype, (StringDtype, ArrowStringDtype))
161162
self._data = data
162163

163164
self._index = self._name = None
@@ -316,7 +317,7 @@ def cons_row(x):
316317
# This is a mess.
317318
dtype: Optional[str]
318319
if self._is_string and returns_string:
319-
dtype = "string"
320+
dtype = self._orig.dtype
320321
else:
321322
dtype = None
322323

pandas/tests/arrays/string_/test_string.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -91,17 +91,11 @@ def test_setitem_with_scalar_string(dtype):
9191
@pytest.mark.parametrize(
9292
"input, method",
9393
[
94-
(["a", "b", "c"], operator.methodcaller("capitalize")),
9594
(["a", "b", "c"], operator.methodcaller("capitalize")),
9695
(["a b", "a bc. de"], operator.methodcaller("capitalize")),
9796
],
9897
)
99-
def test_string_methods(input, method, dtype, request):
100-
if dtype == "arrow_string":
101-
reason = "AttributeError: 'ArrowStringDtype' object has no attribute 'base'"
102-
mark = pytest.mark.xfail(reason=reason)
103-
request.node.add_marker(mark)
104-
98+
def test_string_methods(input, method, dtype):
10599
a = pd.Series(input, dtype=dtype)
106100
b = pd.Series(input, dtype="object")
107101
result = method(a.str)

pandas/tests/strings/test_string_array.py

+29-15
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,22 @@
1111
)
1212

1313

14-
def test_string_array(any_string_method):
14+
def test_string_array(nullable_string_dtype, any_string_method, request):
1515
method_name, args, kwargs = any_string_method
1616
if method_name == "decode":
1717
pytest.skip("decode requires bytes.")
1818

19+
if nullable_string_dtype == "arrow_string" and method_name in {
20+
"extract",
21+
"extractall",
22+
}:
23+
reason = "extract/extractall does not yet dispatch to array"
24+
mark = pytest.mark.xfail(reason=reason)
25+
request.node.add_marker(mark)
26+
1927
data = ["a", "bb", np.nan, "ccc"]
2028
a = Series(data, dtype=object)
21-
b = Series(data, dtype="string")
29+
b = Series(data, dtype=nullable_string_dtype)
2230

2331
expected = getattr(a.str, method_name)(*args, **kwargs)
2432
result = getattr(b.str, method_name)(*args, **kwargs)
@@ -27,7 +35,7 @@ def test_string_array(any_string_method):
2735
if expected.dtype == "object" and lib.is_string_array(
2836
expected.dropna().values,
2937
):
30-
assert result.dtype == "string"
38+
assert result.dtype == nullable_string_dtype
3139
result = result.astype(object)
3240

3341
elif expected.dtype == "object" and lib.is_bool_array(
@@ -46,7 +54,7 @@ def test_string_array(any_string_method):
4654

4755
elif isinstance(expected, DataFrame):
4856
columns = expected.select_dtypes(include="object").columns
49-
assert all(result[columns].dtypes == "string")
57+
assert all(result[columns].dtypes == nullable_string_dtype)
5058
result[columns] = result[columns].astype(object)
5159
tm.assert_equal(result, expected)
5260

@@ -60,8 +68,8 @@ def test_string_array(any_string_method):
6068
("rindex", [2, None]),
6169
],
6270
)
63-
def test_string_array_numeric_integer_array(method, expected):
64-
s = Series(["aba", None], dtype="string")
71+
def test_string_array_numeric_integer_array(nullable_string_dtype, method, expected):
72+
s = Series(["aba", None], dtype=nullable_string_dtype)
6573
result = getattr(s.str, method)("a")
6674
expected = Series(expected, dtype="Int64")
6775
tm.assert_series_equal(result, expected)
@@ -73,33 +81,39 @@ def test_string_array_numeric_integer_array(method, expected):
7381
("isdigit", [False, None, True]),
7482
("isalpha", [True, None, False]),
7583
("isalnum", [True, None, True]),
76-
("isdigit", [False, None, True]),
84+
("isnumeric", [False, None, True]),
7785
],
7886
)
79-
def test_string_array_boolean_array(method, expected):
80-
s = Series(["a", None, "1"], dtype="string")
87+
def test_string_array_boolean_array(nullable_string_dtype, method, expected):
88+
s = Series(["a", None, "1"], dtype=nullable_string_dtype)
8189
result = getattr(s.str, method)()
8290
expected = Series(expected, dtype="boolean")
8391
tm.assert_series_equal(result, expected)
8492

8593

86-
def test_string_array_extract():
94+
def test_string_array_extract(nullable_string_dtype, request):
8795
# https://github.com/pandas-dev/pandas/issues/30969
8896
# Only expand=False & multiple groups was failing
89-
a = Series(["a1", "b2", "cc"], dtype="string")
97+
98+
if nullable_string_dtype == "arrow_string":
99+
reason = "extract does not yet dispatch to array"
100+
mark = pytest.mark.xfail(reason=reason)
101+
request.node.add_marker(mark)
102+
103+
a = Series(["a1", "b2", "cc"], dtype=nullable_string_dtype)
90104
b = Series(["a1", "b2", "cc"], dtype="object")
91105
pat = r"(\w)(\d)"
92106

93107
result = a.str.extract(pat, expand=False)
94108
expected = b.str.extract(pat, expand=False)
95-
assert all(result.dtypes == "string")
109+
assert all(result.dtypes == nullable_string_dtype)
96110

97111
result = result.astype(object)
98112
tm.assert_equal(result, expected)
99113

100114

101-
def test_str_get_stringarray_multiple_nans():
102-
s = Series(pd.array(["a", "ab", pd.NA, "abc"]))
115+
def test_str_get_stringarray_multiple_nans(nullable_string_dtype):
116+
s = Series(pd.array(["a", "ab", pd.NA, "abc"], dtype=nullable_string_dtype))
103117
result = s.str.get(2)
104-
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"]))
118+
expected = Series(pd.array([pd.NA, pd.NA, pd.NA, "c"], dtype=nullable_string_dtype))
105119
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)