Skip to content

Commit 656e4fd

Browse files
authored
ENH: Implement str.r/split for ArrowDtype (#52499)
* Add str split for ArrowDtype(pa.string()) * Add tests * Fix tests and add whats * Typing * More whatsnew note * undo whatsnew
1 parent 658ac5b commit 656e4fd

File tree

4 files changed

+106
-28
lines changed

4 files changed

+106
-28
lines changed

doc/source/whatsnew/v2.0.1.rst

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ Bug fixes
3636

3737
Other
3838
~~~~~
39+
- Implemented :meth:`Series.str.split` and :meth:`Series.str.rsplit` for :class:`ArrowDtype` with ``pyarrow.string`` (:issue:`52401`)
3940
- :class:`DataFrame` created from empty dicts had :attr:`~DataFrame.columns` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
4041
- :class:`Series` created from empty dicts had :attr:`~Series.index` of dtype ``object``. It is now a :class:`RangeIndex` (:issue:`52404`)
4142

pandas/core/arrays/arrow/array.py

+17-7
Original file line numberDiff line numberDiff line change
@@ -2025,15 +2025,25 @@ def _str_rfind(self, sub, start: int = 0, end=None):
20252025
)
20262026

20272027
def _str_split(
2028-
self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
2028+
self,
2029+
pat: str | None = None,
2030+
n: int | None = -1,
2031+
expand: bool = False,
2032+
regex: bool | None = None,
20292033
):
2030-
raise NotImplementedError(
2031-
"str.split not supported with pd.ArrowDtype(pa.string())."
2032-
)
2034+
if n in {-1, 0}:
2035+
n = None
2036+
if regex:
2037+
split_func = pc.split_pattern_regex
2038+
else:
2039+
split_func = pc.split_pattern
2040+
return type(self)(split_func(self._pa_array, pat, max_splits=n))
20332041

2034-
def _str_rsplit(self, pat=None, n=-1):
2035-
raise NotImplementedError(
2036-
"str.rsplit not supported with pd.ArrowDtype(pa.string())."
2042+
def _str_rsplit(self, pat: str | None = None, n: int | None = -1):
2043+
if n in {-1, 0}:
2044+
n = None
2045+
return type(self)(
2046+
pc.split_pattern(self._pa_array, pat, max_splits=n, reverse=True)
20372047
)
20382048

20392049
def _str_translate(self, table):

pandas/core/strings/accessor.py

+32-19
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
)
4242
from pandas.core.dtypes.missing import isna
4343

44+
from pandas.core.arrays.arrow.dtype import ArrowDtype
4445
from pandas.core.base import NoNewAttributesMixin
4546
from pandas.core.construction import extract_array
4647

@@ -267,27 +268,39 @@ def _wrap_result(
267268
# infer from ndim if expand is not specified
268269
expand = result.ndim != 1
269270

270-
elif (
271-
expand is True
272-
and is_object_dtype(result)
273-
and not isinstance(self._orig, ABCIndex)
274-
):
271+
elif expand is True and not isinstance(self._orig, ABCIndex):
275272
# required when expand=True is explicitly specified
276273
# not needed when inferred
277-
278-
def cons_row(x):
279-
if is_list_like(x):
280-
return x
281-
else:
282-
return [x]
283-
284-
result = [cons_row(x) for x in result]
285-
if result and not self._is_string:
286-
# propagate nan values to match longest sequence (GH 18450)
287-
max_len = max(len(x) for x in result)
288-
result = [
289-
x * max_len if len(x) == 0 or x[0] is np.nan else x for x in result
290-
]
274+
if isinstance(result.dtype, ArrowDtype):
275+
import pyarrow as pa
276+
277+
from pandas.core.arrays.arrow.array import ArrowExtensionArray
278+
279+
max_len = pa.compute.max(
280+
result._pa_array.combine_chunks().value_lengths()
281+
).as_py()
282+
if result.isna().any():
283+
result._pa_array = result._pa_array.fill_null([None] * max_len)
284+
result = {
285+
i: ArrowExtensionArray(pa.array(res))
286+
for i, res in enumerate(zip(*result.tolist()))
287+
}
288+
elif is_object_dtype(result):
289+
290+
def cons_row(x):
291+
if is_list_like(x):
292+
return x
293+
else:
294+
return [x]
295+
296+
result = [cons_row(x) for x in result]
297+
if result and not self._is_string:
298+
# propagate nan values to match longest sequence (GH 18450)
299+
max_len = max(len(x) for x in result)
300+
result = [
301+
x * max_len if len(x) == 0 or x[0] is np.nan else x
302+
for x in result
303+
]
291304

292305
if not isinstance(expand, bool):
293306
raise ValueError("expand must be True or False")

pandas/tests/extension/test_arrow.py

+56-2
Original file line numberDiff line numberDiff line change
@@ -2089,6 +2089,62 @@ def test_str_removesuffix(val):
20892089
tm.assert_series_equal(result, expected)
20902090

20912091

2092+
def test_str_split():
2093+
# GH 52401
2094+
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
2095+
result = ser.str.split("c")
2096+
expected = pd.Series(
2097+
ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
2098+
)
2099+
tm.assert_series_equal(result, expected)
2100+
2101+
result = ser.str.split("c", n=1)
2102+
expected = pd.Series(
2103+
ArrowExtensionArray(pa.array([["a1", "bcb"], ["a2", "bcb"], None]))
2104+
)
2105+
tm.assert_series_equal(result, expected)
2106+
2107+
result = ser.str.split("[1-2]", regex=True)
2108+
expected = pd.Series(
2109+
ArrowExtensionArray(pa.array([["a", "cbcb"], ["a", "cbcb"], None]))
2110+
)
2111+
tm.assert_series_equal(result, expected)
2112+
2113+
result = ser.str.split("[1-2]", regex=True, expand=True)
2114+
expected = pd.DataFrame(
2115+
{
2116+
0: ArrowExtensionArray(pa.array(["a", "a", None])),
2117+
1: ArrowExtensionArray(pa.array(["cbcb", "cbcb", None])),
2118+
}
2119+
)
2120+
tm.assert_frame_equal(result, expected)
2121+
2122+
2123+
def test_str_rsplit():
2124+
# GH 52401
2125+
ser = pd.Series(["a1cbcb", "a2cbcb", None], dtype=ArrowDtype(pa.string()))
2126+
result = ser.str.rsplit("c")
2127+
expected = pd.Series(
2128+
ArrowExtensionArray(pa.array([["a1", "b", "b"], ["a2", "b", "b"], None]))
2129+
)
2130+
tm.assert_series_equal(result, expected)
2131+
2132+
result = ser.str.rsplit("c", n=1)
2133+
expected = pd.Series(
2134+
ArrowExtensionArray(pa.array([["a1cb", "b"], ["a2cb", "b"], None]))
2135+
)
2136+
tm.assert_series_equal(result, expected)
2137+
2138+
result = ser.str.rsplit("c", n=1, expand=True)
2139+
expected = pd.DataFrame(
2140+
{
2141+
0: ArrowExtensionArray(pa.array(["a1cb", "a2cb", None])),
2142+
1: ArrowExtensionArray(pa.array(["b", "b", None])),
2143+
}
2144+
)
2145+
tm.assert_frame_equal(result, expected)
2146+
2147+
20922148
@pytest.mark.parametrize(
20932149
"method, args",
20942150
[
@@ -2104,8 +2160,6 @@ def test_str_removesuffix(val):
21042160
["rindex", ("abc",)],
21052161
["normalize", ("abc",)],
21062162
["rfind", ("abc",)],
2107-
["split", ()],
2108-
["rsplit", ()],
21092163
["translate", ("abc",)],
21102164
["wrap", ("abc",)],
21112165
],

0 commit comments

Comments
 (0)