Skip to content

Commit 99bcfe3

Browse files
Backport PR #52499 on branch 2.0.x (ENH: Implement str.r/split for ArrowDtype) (#52603)
1 parent d3902a7 commit 99bcfe3

File tree

4 files changed

+105
-29
lines changed

4 files changed

+105
-29
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line number | Diff line number | Diff line change
@@ -36,6 +36,7 @@ Bug fixes
3636

3737
Other
3838
~~~~~
39+
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
3940
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
4041
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
4142

pandas/core/arrays/arrow/array.py

+16 -8
Original file line number | Diff line number | Diff line change
@@ -1929,16 +1929,24 @@ def _str_rfind(self, sub, start: int = 0, end=None):
19291929
)
19301930

19311931
def _str_split(
1932-
self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
1932+
self,
1933+
pat: str | None = None,
1934+
n: int | None = -1,
1935+
expand: bool = False,
1936+
regex: bool | None = None,
19331937
):
1934-
raise NotImplementedError(
1935-
"str.split not supported with pd.ArrowDtype(pa.string())."
1936-
)
1938+
if n in {-1, 0}:
1939+
n = None
1940+
if regex:
1941+
split_func = pc.split_pattern_regex
1942+
else:
1943+
split_func = pc.split_pattern
1944+
return type(self)(split_func(self._data, pat, max_splits=n))
19371945

1938-
def _str_rsplit(self, pat=None, n=-1):
1939-
raise NotImplementedError(
1940-
"str.rsplit not supported with pd.ArrowDtype(pa.string())."
1941-
)
1946+
def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
1947+
if n in {-1, 0}:
1948+
n = None
1949+
return type(self)(pc.split_pattern(self._data, pat, max_splits=n, reverse=True))
19421950

19431951
def _str_translate(self, table):
19441952
raise NotImplementedError(

pandas/core/strings/accessor.py

+32 -19
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@
4141
)
4242
from pandas.core.dtypes.missing import isna
4343

44+
from pandas.core.arrays.arrow.dtype import ArrowDtype
4445
from pandas.core.base import NoNewAttributesMixin
4546
from pandas.core.construction import extract_array
4647

@@ -267,27 +268,39 @@ def _wrap_result(
267268
# infer from ndim if expand is not specified
268269
expand = result.ndim != 1
269270

270-
elif (
271-
expand is True
272-
and is_object_dtype(result)
273-
and not isinstance(self._orig, ABCIndex)
274-
):
271+
elif expand is True and not isinstance(self._orig, ABCIndex):
275272
# required when expand=True is explicitly specified
276273
# not needed when inferred
277-
278-
def cons_row(x):
279-
if is_list_like(x):
280-
return x
281-
else:
282-
return [x]
283-
284-
result = [cons_row(x) for x in result]
285-
if result and not self._is_string:
286-
# propagate nan values to match longest sequence (GH 18450)
287-
max_len = max(len(x) for x in result)
288-
result = [
289-
x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result
290-
]
274+
if isinstance(result.dtype, ArrowDtype):
275+
import pyarrow as pa
276+
277+
from pandas.core.arrays.arrow.array import ArrowExtensionArray
278+
279+
max_len = pa.compute.max(
280+
result._data.combine_chunks().value_lengths()
281+
).as_py()
282+
if result.isna().any():
283+
result._data = result._data.fill_null([None] * max_len)
284+
result = {
285+
i: ArrowExtensionArray(pa.array(res))
286+
for i, res in enumerate(zip(*result.tolist()))
287+
}
288+
elif is_object_dtype(result):
289+
290+
def cons_row(x):
291+
if is_list_like(x):
292+
return x
293+
else:
294+
return [x]
295+
296+
result = [cons_row(x) for x in result]
297+
if result and not self._is_string:
298+
# propagate nan values to match longest sequence (GH 18450)
299+
max_len = max(len(x) for x in result)
300+
result = [
301+
x * max_len if len(x) == 0 or x[0] is np.nan else x
302+
for x in result
303+
]
291304

292305
if not isinstance(expand, bool):
293306
raise ValueError("expand must be True or False")

pandas/tests/extension/test_arrow.py

+56 -2
Original file line number | Diff line number | Diff line change
@@ -2098,6 +2098,62 @@ def test_str_removesuffix(val):
20982098
tm.assert_series_equal(result, expected)
20992099

21002100

2101+
def test_str_split():
2102+
# GH 52401
2103+
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
2104+
result = ser.str.split("c")
2105+
expected = pd.Series(
2106+
ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
2107+
)
2108+
tm.assert_series_equal(result, expected)
2109+
2110+
result = ser.str.split("c", n=1)
2111+
expected = pd.Series(
2112+
ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None]))
2113+
)
2114+
tm.assert_series_equal(result, expected)
2115+
2116+
result = ser.str.split("[1-2]", regex=True)
2117+
expected = pd.Series(
2118+
ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None]))
2119+
)
2120+
tm.assert_series_equal(result, expected)
2121+
2122+
result = ser.str.split("[1-2]", regex=True, expand=True)
2123+
expected = pd.DataFrame(
2124+
{
2125+
0: ArrowExtensionArray(pa.array(["a", "a", None])),
2126+
1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])),
2127+
}
2128+
)
2129+
tm.assert_frame_equal(result, expected)
2130+
2131+
2132+
def test_str_rsplit():
2133+
# GH 52401
2134+
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
2135+
result = ser.str.rsplit("c")
2136+
expected = pd.Series(
2137+
ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
2138+
)
2139+
tm.assert_series_equal(result, expected)
2140+
2141+
result = ser.str.rsplit("c", n=1)
2142+
expected = pd.Series(
2143+
ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None]))
2144+
)
2145+
tm.assert_series_equal(result, expected)
2146+
2147+
result = ser.str.rsplit("c", n=1, expand=True)
2148+
expected = pd.DataFrame(
2149+
{
2150+
0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])),
2151+
1: ArrowExtensionArray(pa.array(["b", "b", None])),
2152+
}
2153+
)
2154+
tm.assert_frame_equal(result, expected)
2155+
2156+
21012157
@pytest.mark.parametrize(
21022158
"method, args",
21032159
[
@@ -2113,8 +2169,6 @@ def test_str_removesuffix(val):
21132169
["rindex", ("abc",)],
21142170
["normalize", ("abc",)],
21152171
["rfind", ("abc",)],
2116-
["split", ()],
2117-
["rsplit", ()],
21182172
["translate", ("abc",)],
21192173
["wrap", ("abc",)],
21202174
],

0 commit comments

Comments
 (0)