Skip to content

ENH: Allow ArrowDtype(pa.string()) to be compatible with str accessor #51207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Feb 16, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
f8930b4
Remove __init__ import
mroeschke Feb 3, 2023
0c31e1e
Add base class methods
mroeschke Feb 4, 2023
f8aa6e5
Adapt for groupby
mroeschke Feb 4, 2023
092f58f
Start adding test pt 1
mroeschke Feb 4, 2023
10ad912
Add more tests part 2
mroeschke Feb 4, 2023
0ae2e1d
Test pt 3
mroeschke Feb 4, 2023
37d0add
More tests
mroeschke Feb 7, 2023
4e6d24e
finish tests
mroeschke Feb 7, 2023
1444393
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 7, 2023
c74473f
xfail dask test due to moved path
mroeschke Feb 7, 2023
d7a463f
Add whatsnew
mroeschke Feb 7, 2023
bca214b
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 7, 2023
62d2f58
Fix import
mroeschke Feb 7, 2023
6455214
Define len
mroeschke Feb 7, 2023
8808c88
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 8, 2023
b57f142
Address some comments
mroeschke Feb 8, 2023
852920d
address more dask tests
mroeschke Feb 8, 2023
53d2083
Revert groupby change
mroeschke Feb 8, 2023
977b698
fix some tests
mroeschke Feb 8, 2023
2014a06
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 9, 2023
5bb7d2a
Typing
mroeschke Feb 9, 2023
0d15998
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 10, 2023
112e185
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 14, 2023
8b82b1b
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 14, 2023
5d7ff0c
Undo downstream changes
mroeschke Feb 14, 2023
29a5cbb
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 14, 2023
a3f53fd
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 15, 2023
e24a3ea
Merge remote-tracking branch 'upstream/main' into enh/str/arrow
mroeschke Feb 15, 2023
35e35e3
Improve error message
mroeschke Feb 15, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@ Alternatively, copy on write can be enabled locally through:

Other enhancements
^^^^^^^^^^^^^^^^^^
- Added support for ``str`` accessor methods when using ``pd.ArrowDtype(pyarrow.string())`` (:issue:`50325`)
- :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
- :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
Expand Down
285 changes: 284 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

from copy import deepcopy
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
Literal,
Sequence,
TypeVar,
cast,
)
Expand Down Expand Up @@ -52,6 +55,7 @@
unpack_tuple_and_ellipses,
validate_indices,
)
from pandas.core.strings.base import BaseStringArrayMethods

if not pa_version_under7p0:
import pyarrow as pa
Expand Down Expand Up @@ -150,7 +154,7 @@ def to_pyarrow_type(
return None


class ArrowExtensionArray(OpsMixin, ExtensionArray):
class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
"""
Pandas ExtensionArray backed by a PyArrow ChunkedArray.

Expand Down Expand Up @@ -1413,3 +1417,282 @@ def _replace_with_mask(
result = np.array(values, dtype=object)
result[mask] = replacements
return pa.array(result, type=values.type, from_pandas=True)

def _str_count(self, pat: str, flags: int = 0):
if flags:
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._data, pat))

def _str_pad(
self,
width: int,
side: Literal["left", "right", "both"] = "left",
fillchar: str = " ",
):
if side == "left":
pa_pad = pc.utf8_lpad
elif side == "right":
pa_pad = pc.utf8_rpad
elif side == "both":
pa_pad = pc.utf8_center
else:
raise ValueError(
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
)
return type(self)(pa_pad(self._data, width=width, padding=fillchar))

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")

if regex:
pa_contains = pc.match_substring_regex
else:
pa_contains = pc.match_substring
result = pa_contains(self._data, pat, ignore_case=not case)
if not isna(na):
result = result.fill_null(na)
return type(self)(result)

def _str_startswith(self, pat: str, na=None):
    """
    Test whether each element starts with ``pat`` (``pc.starts_with``).

    Nulls in the result are filled with ``na`` unless ``na`` is missing.
    """
    flags = pc.starts_with(self._data, pattern=pat)
    if isna(na):
        return type(self)(flags)
    return type(self)(flags.fill_null(na))

def _str_endswith(self, pat: str, na=None):
    """
    Test whether each element ends with ``pat`` (``pc.ends_with``).

    Nulls in the result are filled with ``na`` unless ``na`` is missing.
    """
    flags = pc.ends_with(self._data, pattern=pat)
    if isna(na):
        return type(self)(flags)
    return type(self)(flags.fill_null(na))

def _str_replace(
self,
pat: str | re.Pattern,
repl: str | Callable,
n: int = -1,
case: bool = True,
flags: int = 0,
regex: bool = True,
):
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
raise NotImplementedError(
"replace is not supported with a re.Pattern, callable repl, "
"case=False, or flags!=0"
)

func = pc.replace_substring_regex if regex else pc.replace_substring
result = func(self._data, pattern=pat, replacement=repl, max_replacements=n)
return type(self)(result)

def _str_repeat(self, repeats: int | Sequence[int]):
if not isinstance(repeats, int):
raise NotImplementedError(
f"repeat is not implemented when repeats is {type(repeats).__name__}"
)
elif pa_version_under7p0:
raise NotImplementedError("repeat is not implemented for pyarrow < 7")
else:
return type(self)(pc.binary_repeat(self._data, repeats))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.endswith("$") or pat.endswith("//$"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i see this is already in the string_arrow code, but why "//$"?

Copy link
Member Author

@mroeschke mroeschke Feb 15, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure, as you mentioned I just copied it over from the string_arrow code

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Best guess is it should be backslashes, meant to exclude escaped dollar signs. OK to consider out of scope.

pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
slices = pc.utf8_slice_codeunits(self._data, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._data
result = pc.find_substring(slices, sub)
else:
raise NotImplementedError(
f"find not implemented with {sub=}, {start=}, {end=}"
)
return type(self)(result)

def _str_get(self, i: int):
    """
    Extract the character at position ``i`` from each element.

    Elements for which ``i`` is out of bounds, and null elements, produce
    null. Negative ``i`` counts from the end of each element.
    """
    lengths = pc.utf8_length(self._data)
    if i >= 0:
        out_of_bounds = pc.greater_equal(i, lengths)
        start, stop, step = i, i + 1, 1
    else:
        out_of_bounds = pc.greater(-i, lengths)
        # Step backwards so the one-character slice starts at ``i``.
        start, stop, step = i, i - 1, -1
    # Null lengths (null elements) count as out of bounds.
    in_bounds = pc.invert(out_of_bounds.fill_null(True))
    selected = pc.utf8_slice_codeunits(
        self._data, start=start, stop=stop, step=step
    )
    # A typed null scalar avoids building a Python list of ``None`` as long
    # as the array just to supply the "else" branch.
    null_value = pa.scalar(None, type=self._data.type)
    result = pc.if_else(in_bounds, selected, null_value)
    return type(self)(result)

def _str_join(self, sep: str):
    """Join the underlying data with ``sep`` via ``pc.binary_join``."""
    joined = pc.binary_join(self._data, sep)
    return type(self)(joined)

def _str_partition(self, sep: str, expand: bool):
raise NotImplementedError("str.partition not supported.")

def _str_rpartition(self, sep: str, expand: bool):
raise NotImplementedError("str.rpartition not supported.")

def _str_slice(
    self, start: int | None = None, stop: int | None = None, step: int | None = None
):
    """
    Slice each element as ``element[start:stop:step]``.

    Parameters
    ----------
    start : int or None
        Slice start. ``None`` means the first character for a positive
        step and the last character for a negative step.
    stop : int or None
        Slice stop, forwarded as given.
    step : int or None
        Slice step; ``None`` means 1.
    """
    if step is None:
        step = 1
    if start is None:
        # For a reversed slice the implicit start is the last character;
        # defaulting to 0 made ``[::-1]`` return empty strings.
        start = -1 if step < 0 else 0
    return type(self)(
        pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step)
    )

def _str_slice_replace(
    self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
    """
    Replace the slice ``[start:stop]`` of each element with ``repl``.

    Parameters
    ----------
    start : int or None
        Slice start; ``None`` means 0.
    stop : int or None
        Slice stop; ``None`` means "to the end of the element".
    repl : str or None
        Replacement text; ``None`` means the empty string.
    """
    if repl is None:
        repl = ""
    if start is None:
        start = 0
    if stop is None:
        # ``pc.utf8_replace_slice`` needs an integer stop; the largest int64
        # stands in for "through the end of the string".
        stop = np.iinfo(np.int64).max
    return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl))

def _str_isalnum(self):
    """Apply ``pc.utf8_is_alnum`` to the data and wrap the result."""
    flags = pc.utf8_is_alnum(self._data)
    return type(self)(flags)

def _str_isalpha(self):
    """Apply ``pc.utf8_is_alpha`` to the data and wrap the result."""
    flags = pc.utf8_is_alpha(self._data)
    return type(self)(flags)

def _str_isdecimal(self):
    """Apply ``pc.utf8_is_decimal`` to the data and wrap the result."""
    flags = pc.utf8_is_decimal(self._data)
    return type(self)(flags)

def _str_isdigit(self):
    """Apply ``pc.utf8_is_digit`` to the data and wrap the result."""
    flags = pc.utf8_is_digit(self._data)
    return type(self)(flags)

def _str_islower(self):
    """Apply ``pc.utf8_is_lower`` to the data and wrap the result."""
    flags = pc.utf8_is_lower(self._data)
    return type(self)(flags)

def _str_isnumeric(self):
    """Apply ``pc.utf8_is_numeric`` to the data and wrap the result."""
    flags = pc.utf8_is_numeric(self._data)
    return type(self)(flags)

def _str_isspace(self):
    """Apply ``pc.utf8_is_space`` to the data and wrap the result."""
    flags = pc.utf8_is_space(self._data)
    return type(self)(flags)

def _str_istitle(self):
    """Apply ``pc.utf8_is_title`` to the data and wrap the result."""
    flags = pc.utf8_is_title(self._data)
    return type(self)(flags)

def _str_capitalize(self):
    """Apply ``pc.utf8_capitalize`` to the data and wrap the result."""
    converted = pc.utf8_capitalize(self._data)
    return type(self)(converted)

def _str_title(self):
    """Apply ``pc.utf8_title`` to the data and wrap the result."""
    converted = pc.utf8_title(self._data)
    return type(self)(converted)

def _str_isupper(self):
    """Apply ``pc.utf8_is_upper`` to the data and wrap the result."""
    flags = pc.utf8_is_upper(self._data)
    return type(self)(flags)

def _str_swapcase(self):
    """Apply ``pc.utf8_swapcase`` to the data and wrap the result."""
    converted = pc.utf8_swapcase(self._data)
    return type(self)(converted)

def _str_len(self):
    """Apply ``pc.utf8_length`` to the data and wrap the result."""
    lengths = pc.utf8_length(self._data)
    return type(self)(lengths)

def _str_lower(self):
    """Apply ``pc.utf8_lower`` to the data and wrap the result."""
    converted = pc.utf8_lower(self._data)
    return type(self)(converted)

def _str_upper(self):
    """Apply ``pc.utf8_upper`` to the data and wrap the result."""
    converted = pc.utf8_upper(self._data)
    return type(self)(converted)

def _str_strip(self, to_strip=None):
    """
    Trim both ends of each element.

    With ``to_strip=None`` this uses ``pc.utf8_trim_whitespace``; otherwise
    ``pc.utf8_trim`` with ``characters=to_strip``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._data, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._data)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None):
    """
    Trim the left end of each element.

    With ``to_strip=None`` this uses ``pc.utf8_ltrim_whitespace``; otherwise
    ``pc.utf8_ltrim`` with ``characters=to_strip``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._data, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._data)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None):
    """
    Trim the right end of each element.

    With ``to_strip=None`` this uses ``pc.utf8_rtrim_whitespace``; otherwise
    ``pc.utf8_rtrim`` with ``characters=to_strip``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._data, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._data)
    return type(self)(trimmed)

def _str_removeprefix(self, prefix: str) -> Series:
raise NotImplementedError("str.removeprefix not supported.")
# TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
# starts_with = pc.starts_with(self._data, pattern=prefix)
# removed = pc.utf8_slice_codeunits(self._data, len(prefix))
# result = pc.if_else(starts_with, removed, self._data)
# return type(self)(result)

def _str_removesuffix(self, suffix: str) -> Series:
ends_with = pc.ends_with(self._data, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._data)
return type(self)(result)

def _str_casefold(self):
raise NotImplementedError("str.casefold not supported.")

def _str_encode(self, encoding, errors: str = "strict"):
raise NotImplementedError("str.encode not supported.")

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
raise NotImplementedError("str.extract not supported.")

def _str_findall(self, pat, flags: int = 0):
raise NotImplementedError("str.findall not supported.")

def _str_get_dummies(self, sep: str = "|"):
raise NotImplementedError("str.get_dummies not supported.")

def _str_index(self, sub, start: int = 0, end=None):
raise NotImplementedError("str.index not supported.")

def _str_rindex(self, sub, start: int = 0, end=None):
raise NotImplementedError("str.rindex not supported.")

def _str_normalize(self, form):
raise NotImplementedError("str.normalize not supported.")

def _str_rfind(self, sub, start: int = 0, end=None):
raise NotImplementedError("str.rfind not supported.")

def _str_split(
self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
):
raise NotImplementedError("str.split not supported.")

def _str_rsplit(self, pat=None, n=-1):
raise NotImplementedError("str.rsplit not supported.")

def _str_translate(self, table):
raise NotImplementedError("str.translate not supported.")

def _str_wrap(self, width, **kwargs):
raise NotImplementedError("str.wrap not supported.")
2 changes: 1 addition & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _chk_pyarrow_available() -> None:
# fallback for the ones that pyarrow doesn't yet support


class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
"""
Extension array for string data in a ``pyarrow.ChunkedArray``.

Expand Down
7 changes: 4 additions & 3 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
is_numeric_dtype,
is_period_dtype,
is_sparse,
is_string_dtype,
is_timedelta64_dtype,
needs_i8_conversion,
)
Expand All @@ -76,7 +77,6 @@
BaseMaskedArray,
BaseMaskedDtype,
)
from pandas.core.arrays.string_ import StringDtype
from pandas.core.frame import DataFrame
from pandas.core.groupby import grouper
from pandas.core.indexes.api import (
Expand All @@ -96,6 +96,7 @@
)

if TYPE_CHECKING:
from pandas.core.arrays.string_ import StringDtype
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the annotation that uses this may no longer be accurate?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. Added `| ArrowDtype` to the annotation where this is used.

from pandas.core.generic import NDFrame


Expand Down Expand Up @@ -387,7 +388,7 @@ def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
npvalues = values._ndarray.view("M8[ns]")
elif isinstance(values.dtype, StringDtype):
elif is_string_dtype(values.dtype):
# StringArray
npvalues = values.to_numpy(object, na_value=np.nan)
else:
Expand All @@ -405,7 +406,7 @@ def _reconstruct_ea_result(
"""
dtype: BaseMaskedDtype | StringDtype

if isinstance(values.dtype, StringDtype):
if is_string_dtype(values.dtype):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't make a difference here because we know we have an EA, but we should have a function like is_string_dtype that excludes object dtype.

dtype = values.dtype
string_array_cls = dtype.construct_array_type()
return string_array_cls._from_sequence(res_values, dtype=dtype)
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@
get_group_index_sorter,
nargsort,
)
from pandas.core.strings import StringMethods
from pandas.core.strings.accessor import StringMethods

from pandas.io.formats.printing import (
PrettyDict,
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@
ensure_key_mapped,
nargsort,
)
from pandas.core.strings import StringMethods
from pandas.core.strings.accessor import StringMethods
from pandas.core.tools.datetimes import to_datetime

import pandas.io.formats.format as fmt
Expand Down
5 changes: 0 additions & 5 deletions pandas/core/strings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,3 @@
# - PandasArray
# - Categorical
# - ArrowStringArray

from pandas.core.strings.accessor import StringMethods
from pandas.core.strings.base import BaseStringArrayMethods

__all__ = ["StringMethods", "BaseStringArrayMethods"]
Loading