Skip to content

ENH: Implement str.extract for ArrowDtype #56334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.2.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ Other enhancements
- Allow passing ``read_only``, ``data_only`` and ``keep_links`` arguments to openpyxl using ``engine_kwargs`` of :func:`read_excel` (:issue:`55027`)
- DataFrame.apply now allows the usage of numba (via ``engine="numba"``) to JIT compile the passed function, allowing for potential speedups (:issue:`54666`)
- Implement masked algorithms for :meth:`Series.value_counts` (:issue:`54984`)
- Implemented :meth:`Series.str.extract` for :class:`ArrowDtype` (:issue:`56268`)
- Improved error message that appears in :meth:`DatetimeIndex.to_period` with frequencies which are not supported as period frequencies, such as "BMS" (:issue:`56243`)
- Improved error message when constructing :class:`Period` with invalid offsets such as "QS" (:issue:`55785`)

Expand Down
16 changes: 13 additions & 3 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,9 +2296,19 @@ def _str_encode(self, encoding: str, errors: str = "strict"):
return type(self)(pa.chunked_array(result))

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
raise NotImplementedError(
"str.extract not supported with pd.ArrowDtype(pa.string())."
)
if flags:
raise NotImplementedError("Only flags=0 is implemented.")
groups = re.compile(pat).groupindex.keys()
if len(groups) == 0:
raise ValueError(f"{pat=} must contain a symbolic group name.")
result = pc.extract_regex(self._pa_array, pat)
if expand:
return {
col: type(self)(pc.struct_field(result, [i]))
for col, i in zip(groups, range(result.type.num_fields))
}
else:
return type(self)(pc.struct_field(result, [0]))

def _str_findall(self, pat: str, flags: int = 0):
regex = re.compile(pat, flags=flags)
Expand Down
36 changes: 31 additions & 5 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -2150,14 +2150,40 @@ def test_str_rsplit():
tm.assert_frame_equal(result, expected)


def test_str_unsupported_extract():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
with pytest.raises(
NotImplementedError, match="str.extract not supported with pd.ArrowDtype"
):
def test_str_extract_non_symbolic():
ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string()))
with pytest.raises(ValueError, match="pat=.* must contain a symbolic group name."):
ser.str.extract(r"[ab](\d)")


@pytest.mark.parametrize("expand", [True, False])
def test_str_extract(expand):
ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string()))
result = ser.str.extract(r"(?P<letter>[ab])(?P<digit>\d)", expand=expand)
expected = pd.DataFrame(
{
"letter": ArrowExtensionArray(pa.array(["a", "b", None])),
"digit": ArrowExtensionArray(pa.array(["1", "2", None])),
}
)
tm.assert_frame_equal(result, expected)


def test_str_extract_expand():
ser = pd.Series(["a1", "b2", "c3"], dtype=ArrowDtype(pa.string()))
result = ser.str.extract(r"[ab](?P<digit>\d)", expand=True)
expected = pd.DataFrame(
{
"digit": ArrowExtensionArray(pa.array(["1", "2", None])),
}
)
tm.assert_frame_equal(result, expected)

result = ser.str.extract(r"[ab](?P<digit>\d)", expand=False)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add coverage for when there are multiple groups and expand=False?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I parametrized the above test with expand=

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't the expand=False case only produce a series?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily:

https://pandas.pydata.org/docs/reference/api/pandas.Series.str.extract.html

If True, return DataFrame with one column per capture group. If False, return a Series/Index if there is one capture group or DataFrame if there are multiple capture groups.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK - sorry I misremembered how that works. I was somewhat expecting expand=False here to just return the struct array

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW I'm OK with this now too as it is definitely consistent with our current API. There may be a future where we actuallly want it to return a pyarrow struct but can always come back and do that later

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah agreed. For arrow types, I think in the future expand=False should return a struct typed Series

expected = pd.Series(ArrowExtensionArray(pa.array(["1", "2", None])), name="digit")
tm.assert_series_equal(result, expected)


@pytest.mark.parametrize("unit", ["ns", "us", "ms", "s"])
def test_duration_from_strings_with_nat(unit):
# GH51175
Expand Down