Skip to content

ENH: Allow ArrowDtype(pa.string()) to be compatible with str accessor #51207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f8930b4
Remove __init__ import
mroeschke Feb 3, 2023
0c31e1e
Add base class methods
mroeschke Feb 4, 2023
f8aa6e5
Adapt for groupby
mroeschke Feb 4, 2023
092f58f
Start adding test pt 1
mroeschke Feb 4, 2023
10ad912
Add more tests part 2
mroeschke Feb 4, 2023
0ae2e1d
Test pt 3
mroeschke Feb 4, 2023
37d0add
More tests
mroeschke Feb 7, 2023
4e6d24e
finish tests
mroeschke Feb 7, 2023
1444393
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 7, 2023
c74473f
xfail dask test due to moved path
mroeschke Feb 7, 2023
d7a463f
Add whatsnew
mroeschke Feb 7, 2023
bca214b
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 7, 2023
62d2f58
Fix import
mroeschke Feb 7, 2023
6455214
Define len
mroeschke Feb 7, 2023
8808c88
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 8, 2023
b57f142
Address some comments
mroeschke Feb 8, 2023
852920d
address more dask tests
mroeschke Feb 8, 2023
53d2083
Revert groupby change
mroeschke Feb 8, 2023
977b698
fix some tests
mroeschke Feb 8, 2023
2014a06
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 9, 2023
5bb7d2a
Typing
mroeschke Feb 9, 2023
0d15998
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 10, 2023
112e185
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 14, 2023
8b82b1b
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 14, 2023
5d7ff0c
Undo downstream changes
mroeschke Feb 14, 2023
29a5cbb
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 14, 2023
a3f53fd
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 15, 2023
e24a3ea
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 15, 2023
35e35e3
Improve error message
mroeschke Feb 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ Alternatively, copy on write can be enabled locally through:

Other enhancements
^^^^^^^^^^^^^^^^^^
- Added support for ``str`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.string`` type (:issue:`50325`)
- Added support for ``dt`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.timestamp`` type (:issue:`50954`)
- :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
- :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
Expand Down
317 changes: 316 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from copy import deepcopy
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Sequence,
TypeVar,
cast,
)
Expand Down Expand Up @@ -55,6 +58,7 @@
unpack_tuple_and_ellipses,
validate_indices,
)
from pandas.core.strings.base import BaseStringArrayMethods

from pandas.tseries.frequencies import to_offset

Expand Down Expand Up @@ -165,7 +169,7 @@ def to_pyarrow_type(
return None


class ArrowExtensionArray(OpsMixin, ExtensionArray):
class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
"""
Pandas ExtensionArray backed by a PyArrow ChunkedArray.

Expand Down Expand Up @@ -1463,6 +1467,317 @@ def _replace_with_mask(
result[mask] = replacements
return pa.array(result, type=values.type, from_pandas=True)

def _str_count(self, pat: str, flags: int = 0):
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._data, pat))

def _str_pad(
self,
width: int,
side: Literal["left", "right", "both"] = "left",
fillchar: str = " ",
):
if side == "left":
pa_pad = pc.utf8_lpad
elif side == "right":
pa_pad = pc.utf8_rpad
elif side == "both":
pa_pad = pc.utf8_center
else:
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._data, width=width, padding=fillchar))

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")

if regex:
pa_contains = pc.match_substring_regex
else:
pa_contains = pc.match_substring
result = pa_contains(self._data, pat, ignore_case=not case)
if not isna(na):
result = result.fill_null(na)
return type(self)(result)

def _str_startswith(self, pat: str, na=None):
    """
    Test whether each element starts with ``pat`` via ``pc.starts_with``.

    Nulls in the result are filled with ``na`` unless ``isna(na)`` is true.
    """
    matched = pc.starts_with(self._data, pattern=pat)
    filled = matched if isna(na) else matched.fill_null(na)
    return type(self)(filled)

def _str_endswith(self, pat: str, na=None):
    """
    Test whether each element ends with ``pat`` via ``pc.ends_with``.

    Nulls in the result are filled with ``na`` unless ``isna(na)`` is true.
    """
    matched = pc.ends_with(self._data, pattern=pat)
    filled = matched if isna(na) else matched.fill_null(na)
    return type(self)(filled)

def _str_replace(
self,
pat: str | re.Pattern,
repl: str | Callable,
n: int = -1,
case: bool = True,
flags: int = 0,
regex: bool = True,
):
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
raise NotImplementedError(
"replace is not supported with a re.Pattern, callable repl, "
"case=False, or flags!=0"
)

func = pc.replace_substring_regex if regex else pc.replace_substring
result = func(self._data, pattern=pat, replacement=repl, max_replacements=n)
return type(self)(result)

def _str_repeat(self, repeats: int | Sequence[int]):
if not isinstance(repeats, int):
raise NotImplementedError(
f"repeat is not implemented when repeats is {type(repeats).__name__}"
)
elif pa_version_under7p0:
raise NotImplementedError("repeat is not implemented for pyarrow < 7")
else:
return type(self)(pc.binary_repeat(self._data, repeats))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.endswith("$") or pat.endswith("//$"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see this is already in the string_arrow code, but why "//$"?

Copy link
Member Author

@mroeschke mroeschke Feb 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, as you mentioned I just copied it over from the string_arrow code

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Best guess is it should be backslashes, meant to exclude escaped dollar signs. OK to consider out of scope.

pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Return the lowest index of ``sub`` in each element, or -1 if not found.

    Supported argument combinations are ``start == 0`` with ``end is None``
    (search the whole string) and a non-zero ``start`` together with an
    explicit ``end`` (search the slice ``[start:end]``).

    Raises
    ------
    NotImplementedError
        For any other combination of ``start`` and ``end``.
    """
    if start != 0 and end is not None:
        slices = pc.utf8_slice_codeunits(self._data, start, stop=end)
        result = pc.find_substring(slices, sub)
        # find_substring reports positions relative to the slice; shift them
        # back by where the slice began in the original string. (The shift
        # used to be ``end - start``, the slice width, which is wrong.)
        if start < 0:
            # A negative start counts from the end of each string and is
            # clamped at the beginning, as Python slicing does.
            start_offset = pc.add(pc.utf8_length(self._data), start)
            start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
        else:
            start_offset = start
        not_found = pc.equal(result, -1)
        offset_result = pc.add(result, start_offset)
        # Keep -1 untouched for elements where the substring is absent.
        result = pc.if_else(not_found, result, offset_result)
    elif start == 0 and end is None:
        slices = self._data
        result = pc.find_substring(slices, sub)
    else:
        raise NotImplementedError(
            f"find not implemented with {sub=}, {start=}, {end=}"
        )
    return type(self)(result)

def _str_get(self, i: int):
    """
    Extract the code unit at position ``i`` from each element.

    Elements for which ``i`` is out of bounds, and null elements, yield null.
    Negative ``i`` counts from the end of each string.
    """
    lengths = pc.utf8_length(self._data)
    if i >= 0:
        out_of_bounds = pc.greater_equal(i, lengths)
        start, stop, step = i, i + 1, 1
    else:
        out_of_bounds = pc.greater(-i, lengths)
        # Slice backwards by one so that i == -1 does not need stop == 0.
        start, stop, step = i, i - 1, -1
    # Null lengths count as out of bounds, so null inputs give null outputs.
    not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True))
    selected = pc.utf8_slice_codeunits(
        self._data, start=start, stop=stop, step=step
    )
    # A typed null scalar broadcasts in if_else; no need to materialise a
    # Python list of None as long as the data.
    null_value = pa.scalar(None, type=self._data.type)
    result = pc.if_else(not_out_of_bounds, selected, null_value)
    return type(self)(result)

def _str_join(self, sep: str):
    """Apply ``pc.binary_join`` to the underlying data with separator ``sep``."""
    joined = pc.binary_join(self._data, sep)
    return type(self)(joined)

def _str_partition(self, sep: str, expand: bool):
raise NotImplementedError(
"str.partition not supported with pd.ArrowDtype(pa.string())."
)

def _str_rpartition(self, sep: str, expand: bool):
raise NotImplementedError(
"str.rpartition not supported with pd.ArrowDtype(pa.string())."
)

def _str_slice(
    self, start: int | None = None, stop: int | None = None, step: int | None = None
):
    """
    Slice each element with ``pc.utf8_slice_codeunits``.

    A ``start`` of ``None`` is treated as 0 and a ``step`` of ``None`` as 1;
    ``stop`` is passed through unchanged.
    """
    begin = 0 if start is None else start
    stride = 1 if step is None else step
    sliced = pc.utf8_slice_codeunits(
        self._data, start=begin, stop=stop, step=stride
    )
    return type(self)(sliced)

def _str_slice_replace(
    self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
    """
    Replace the slice ``[start:stop]`` of each element with ``repl``.

    A ``repl`` of ``None`` is treated as the empty string and a ``start`` of
    ``None`` as 0; the work is done by ``pc.utf8_replace_slice``.
    """
    replacement = "" if repl is None else repl
    begin = 0 if start is None else start
    replaced = pc.utf8_replace_slice(self._data, begin, stop, replacement)
    return type(self)(replaced)

def _str_isalnum(self):
    """Apply ``pc.utf8_is_alnum`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_alnum(self._data)
    return type(self)(flags)

def _str_isalpha(self):
    """Apply ``pc.utf8_is_alpha`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_alpha(self._data)
    return type(self)(flags)

def _str_isdecimal(self):
    """Apply ``pc.utf8_is_decimal`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_decimal(self._data)
    return type(self)(flags)

def _str_isdigit(self):
    """Apply ``pc.utf8_is_digit`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_digit(self._data)
    return type(self)(flags)

def _str_islower(self):
    """Apply ``pc.utf8_is_lower`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_lower(self._data)
    return type(self)(flags)

def _str_isnumeric(self):
    """Apply ``pc.utf8_is_numeric`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_numeric(self._data)
    return type(self)(flags)

def _str_isspace(self):
    """Apply ``pc.utf8_is_space`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_space(self._data)
    return type(self)(flags)

def _str_istitle(self):
    """Apply ``pc.utf8_is_title`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_title(self._data)
    return type(self)(flags)

def _str_capitalize(self):
    """Apply ``pc.utf8_capitalize`` to the underlying data and rewrap the result."""
    converted = pc.utf8_capitalize(self._data)
    return type(self)(converted)

def _str_title(self):
    """Apply ``pc.utf8_title`` to the underlying data and rewrap the result."""
    converted = pc.utf8_title(self._data)
    return type(self)(converted)

def _str_isupper(self):
    """Apply ``pc.utf8_is_upper`` to the underlying data and rewrap the result."""
    flags = pc.utf8_is_upper(self._data)
    return type(self)(flags)

def _str_swapcase(self):
    """Apply ``pc.utf8_swapcase`` to the underlying data and rewrap the result."""
    converted = pc.utf8_swapcase(self._data)
    return type(self)(converted)

def _str_len(self):
    """Apply ``pc.utf8_length`` to the underlying data and rewrap the result."""
    lengths = pc.utf8_length(self._data)
    return type(self)(lengths)

def _str_lower(self):
    """Apply ``pc.utf8_lower`` to the underlying data and rewrap the result."""
    converted = pc.utf8_lower(self._data)
    return type(self)(converted)

def _str_upper(self):
    """Apply ``pc.utf8_upper`` to the underlying data and rewrap the result."""
    converted = pc.utf8_upper(self._data)
    return type(self)(converted)

def _str_strip(self, to_strip=None):
    """
    Trim both ends of each element.

    Uses ``pc.utf8_trim_whitespace`` when ``to_strip`` is ``None`` and
    ``pc.utf8_trim`` with ``characters=to_strip`` otherwise.
    """
    if to_strip is not None:
        return type(self)(pc.utf8_trim(self._data, characters=to_strip))
    return type(self)(pc.utf8_trim_whitespace(self._data))

def _str_lstrip(self, to_strip=None):
    """
    Trim the left end of each element.

    Uses ``pc.utf8_ltrim_whitespace`` when ``to_strip`` is ``None`` and
    ``pc.utf8_ltrim`` with ``characters=to_strip`` otherwise.
    """
    if to_strip is not None:
        return type(self)(pc.utf8_ltrim(self._data, characters=to_strip))
    return type(self)(pc.utf8_ltrim_whitespace(self._data))

def _str_rstrip(self, to_strip=None):
    """
    Trim the right end of each element.

    Uses ``pc.utf8_rtrim_whitespace`` when ``to_strip`` is ``None`` and
    ``pc.utf8_rtrim`` with ``characters=to_strip`` otherwise.
    """
    if to_strip is not None:
        return type(self)(pc.utf8_rtrim(self._data, characters=to_strip))
    return type(self)(pc.utf8_rtrim_whitespace(self._data))

def _str_removeprefix(self, prefix: str):
raise NotImplementedError(
"str.removeprefix not supported with pd.ArrowDtype(pa.string())."
)
# TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
# starts_with = pc.starts_with(self._data, pattern=prefix)
# removed = pc.utf8_slice_codeunits(self._data, len(prefix))
# result = pc.if_else(starts_with, removed, self._data)
# return type(self)(result)

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._data, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._data)
return type(self)(result)

def _str_casefold(self):
raise NotImplementedError(
"str.casefold not supported with pd.ArrowDtype(pa.string())."
)

def _str_encode(self, encoding, errors: str = "strict"):
raise NotImplementedError(
"str.encode not supported with pd.ArrowDtype(pa.string())."
)

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
raise NotImplementedError(
"str.extract not supported with pd.ArrowDtype(pa.string())."
)

def _str_findall(self, pat, flags: int = 0):
raise NotImplementedError(
"str.findall not supported with pd.ArrowDtype(pa.string())."
)

def _str_get_dummies(self, sep: str = "|"):
raise NotImplementedError(
"str.get_dummies not supported with pd.ArrowDtype(pa.string())."
)

def _str_index(self, sub, start: int = 0, end=None):
raise NotImplementedError(
"str.index not supported with pd.ArrowDtype(pa.string())."
)

def _str_rindex(self, sub, start: int = 0, end=None):
raise NotImplementedError(
"str.rindex not supported with pd.ArrowDtype(pa.string())."
)

def _str_normalize(self, form):
raise NotImplementedError(
"str.normalize not supported with pd.ArrowDtype(pa.string())."
)

def _str_rfind(self, sub, start: int = 0, end=None):
raise NotImplementedError(
"str.rfind not supported with pd.ArrowDtype(pa.string())."
)

def _str_split(
self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
):
raise NotImplementedError(
"str.split not supported with pd.ArrowDtype(pa.string())."
)

def _str_rsplit(self, pat=None, n=-1):
raise NotImplementedError(
"str.rsplit not supported with pd.ArrowDtype(pa.string())."
)

def _str_translate(self, table):
raise NotImplementedError(
"str.translate not supported with pd.ArrowDtype(pa.string())."
)

def _str_wrap(self, width, **kwargs):
raise NotImplementedError(
"str.wrap not supported with pd.ArrowDtype(pa.string())."
)

@property
def _dt_day(self):
    """Apply ``pc.day`` to the underlying data and rewrap the result."""
    days = pc.day(self._data)
    return type(self)(days)
Expand Down
12 changes: 11 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _chk_pyarrow_available() -> None:
# fallback for the ones that pyarrow doesn't yet support


class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
"""
Extension array for string data in a ``pyarrow.ChunkedArray``.

Expand Down Expand Up @@ -117,6 +117,16 @@ def __init__(self, values) -> None:
"ArrowStringArray requires a PyArrow (chunked) array of string type"
)

def __len__(self) -> int:
"""
Length of this array.

Returns
-------
length : int
"""
return len(self._data)

@classmethod
def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
from pandas.core.arrays.masked import BaseMaskedArray
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@
get_group_index_sorter,
nargsort,
)
from pandas.core.strings import StringMethods
from pandas.core.strings.accessor import StringMethods

from pandas.io.formats.printing import (
PrettyDict,
Expand Down
Loading