Skip to content

Commit 0796d9d

Browse files
authored
ENH: Allow ArrowDtype(pa.string()) to be compatible with str accessor (#51207)
* Remove __init__ import * Add base class methods * Adapt for groupby * Start adding test pt 1 * Add more tests part 2 * Test pt 3 * More tests * finish tests * xfail dask test due to moved path * Add whatsnew * Fix import * Define len * Address some comments * address more dask tests * Revert groupby change * fix some tests * Typing * Undo downstream changes * Improve error message
1 parent 07667f3 commit 0796d9d

File tree

11 files changed

+678
-21
lines changed

11 files changed

+678
-21
lines changed

doc/source/whatsnew/v2.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,7 @@ Alternatively, copy on write can be enabled locally through:
285285

286286
Other enhancements
287287
^^^^^^^^^^^^^^^^^^
288+
- Added support for ``str`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.string`` type (:issue:`50325`)
288289
- Added support for ``dt`` accessor methods when using :class:`ArrowDtype` with a ``pyarrow.timestamp`` type (:issue:`50954`)
289290
- :func:`read_sas` now supports using ``encoding='infer'`` to correctly read and use the encoding specified by the sas file. (:issue:`48048`)
290291
- :meth:`.DataFrameGroupBy.quantile`, :meth:`.SeriesGroupBy.quantile` and :meth:`.DataFrameGroupBy.std` now preserve nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)

pandas/core/arrays/arrow/array.py

+316-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
from __future__ import annotations
22

33
from copy import deepcopy
4+
import re
45
from typing import (
56
TYPE_CHECKING,
67
Any,
8+
Callable,
79
Literal,
10+
Sequence,
811
TypeVar,
912
cast,
1013
)
@@ -55,6 +58,7 @@
5558
unpack_tuple_and_ellipses,
5659
validate_indices,
5760
)
61+
from pandas.core.strings.base import BaseStringArrayMethods
5862

5963
from pandas.tseries.frequencies import to_offset
6064

@@ -165,7 +169,7 @@ def to_pyarrow_type(
165169
return None
166170

167171

168-
class ArrowExtensionArray(OpsMixin, ExtensionArray):
172+
class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods):
169173
"""
170174
Pandas ExtensionArray backed by a PyArrow ChunkedArray.
171175
@@ -1463,6 +1467,317 @@ def _replace_with_mask(
14631467
result[mask] = replacements
14641468
return pa.array(result, type=values.type, from_pandas=True)
14651469

1470+
def _str_count(self, pat: str, flags: int = 0):
1471+
if flags:
1472+
raise NotImplementedError(f"count not implemented with {flags=}")
1473+
return type(self)(pc.count_substring_regex(self._data, pat))
1474+
1475+
def _str_pad(
1476+
self,
1477+
width: int,
1478+
side: Literal["left", "right", "both"] = "left",
1479+
fillchar: str = " ",
1480+
):
1481+
if side == "left":
1482+
pa_pad = pc.utf8_lpad
1483+
elif side == "right":
1484+
pa_pad = pc.utf8_rpad
1485+
elif side == "both":
1486+
pa_pad = pc.utf8_center
1487+
else:
1488+
raise ValueError(
1489+
f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'"
1490+
)
1491+
return type(self)(pa_pad(self._data, width=width, padding=fillchar))
1492+
1493+
def _str_contains(
1494+
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
1495+
):
1496+
if flags:
1497+
raise NotImplementedError(f"contains not implemented with {flags=}")
1498+
1499+
if regex:
1500+
pa_contains = pc.match_substring_regex
1501+
else:
1502+
pa_contains = pc.match_substring
1503+
result = pa_contains(self._data, pat, ignore_case=not case)
1504+
if not isna(na):
1505+
result = result.fill_null(na)
1506+
return type(self)(result)
1507+
1508+
def _str_startswith(self, pat: str, na=None):
    """
    Apply ``pc.starts_with`` with ``pat``; fill nulls with ``na`` if it is not missing.
    """
    flags_arr = pc.starts_with(self._data, pattern=pat)
    if isna(na):
        return type(self)(flags_arr)
    return type(self)(flags_arr.fill_null(na))
1513+
1514+
def _str_endswith(self, pat: str, na=None):
    """
    Apply ``pc.ends_with`` with ``pat``; fill nulls with ``na`` if it is not missing.
    """
    flags_arr = pc.ends_with(self._data, pattern=pat)
    if isna(na):
        return type(self)(flags_arr)
    return type(self)(flags_arr.fill_null(na))
1519+
1520+
def _str_replace(
1521+
self,
1522+
pat: str | re.Pattern,
1523+
repl: str | Callable,
1524+
n: int = -1,
1525+
case: bool = True,
1526+
flags: int = 0,
1527+
regex: bool = True,
1528+
):
1529+
if isinstance(pat, re.Pattern) or callable(repl) or not case or flags:
1530+
raise NotImplementedError(
1531+
"replace is not supported with a re.Pattern, callable repl, "
1532+
"case=False, or flags!=0"
1533+
)
1534+
1535+
func = pc.replace_substring_regex if regex else pc.replace_substring
1536+
result = func(self._data, pattern=pat, replacement=repl, max_replacements=n)
1537+
return type(self)(result)
1538+
1539+
def _str_repeat(self, repeats: int | Sequence[int]):
1540+
if not isinstance(repeats, int):
1541+
raise NotImplementedError(
1542+
f"repeat is not implemented when repeats is {type(repeats).__name__}"
1543+
)
1544+
elif pa_version_under7p0:
1545+
raise NotImplementedError("repeat is not implemented for pyarrow < 7")
1546+
else:
1547+
return type(self)(pc.binary_repeat(self._data, repeats))
1548+
1549+
def _str_match(
1550+
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
1551+
):
1552+
if not pat.startswith("^"):
1553+
pat = f"^{pat}"
1554+
return self._str_contains(pat, case, flags, na, regex=True)
1555+
1556+
def _str_fullmatch(
1557+
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
1558+
):
1559+
if not pat.endswith("$") or pat.endswith("//$"):
1560+
pat = f"{pat}$"
1561+
return self._str_match(pat, case, flags, na)
1562+
1563+
def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Return the position of the first occurrence of ``sub`` in each element.

    Only two argument combinations are handled: the defaults (``start == 0``
    and ``end is None``), and a non-zero ``start`` together with an explicit
    ``end``. Elements where ``sub`` is not found yield ``-1``.

    Raises
    ------
    NotImplementedError
        For any other combination of ``start`` and ``end``.
    """
    if start != 0 and end is not None:
        slices = pc.utf8_slice_codeunits(self._data, start, stop=end)
        result = pc.find_substring(slices, sub)
        not_found = pc.equal(result, -1)
        # Positions are relative to the slice, so shift them back by where the
        # slice began (not by the slice width).
        if start < 0:
            # A negative start counts from the end of each string and is
            # clamped to the beginning of the string.
            start_offset = pc.add(pc.utf8_length(self._data), start)
            start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
        else:
            start_offset = start
        offset_result = pc.add(result, start_offset)
        result = pc.if_else(not_found, result, offset_result)
    elif start == 0 and end is None:
        result = pc.find_substring(self._data, sub)
    else:
        raise NotImplementedError(
            f"find not implemented with {sub=}, {start=}, {end=}"
        )
    return type(self)(result)
1578+
1579+
def _str_get(self, i: int):
    """
    Extract the single code unit at position ``i`` from each element.

    Elements where ``i`` is out of bounds (or whose length is null) yield
    null in the result.
    """
    data = self._data
    lengths = pc.utf8_length(data)
    if i >= 0:
        out_of_bounds = pc.greater_equal(i, lengths)
        start, stop, step = i, i + 1, 1
    else:
        out_of_bounds = pc.greater(-i, lengths)
        # Walk backwards one position when indexing from the end.
        start, stop, step = i, i - 1, -1
    in_bounds = pc.invert(out_of_bounds.fill_null(True))
    picked = pc.utf8_slice_codeunits(data, start=start, stop=stop, step=step)
    all_null = pa.array([None] * data.length(), type=data.type)
    return type(self)(pc.if_else(in_bounds, picked, all_null))
1598+
1599+
def _str_join(self, sep: str):
    """Wrap the result of ``pc.binary_join(self._data, sep)`` in this array type."""
    joined = pc.binary_join(self._data, sep)
    return type(self)(joined)
1601+
1602+
def _str_partition(self, sep: str, expand: bool):
1603+
raise NotImplementedError(
1604+
"str.partition not supported with pd.ArrowDtype(pa.string())."
1605+
)
1606+
1607+
def _str_rpartition(self, sep: str, expand: bool):
1608+
raise NotImplementedError(
1609+
"str.rpartition not supported with pd.ArrowDtype(pa.string())."
1610+
)
1611+
1612+
def _str_slice(
    self, start: int | None = None, stop: int | None = None, step: int | None = None
):
    """
    Slice each element with ``pc.utf8_slice_codeunits``.

    A missing ``step`` defaults to 1. A missing ``start`` defaults to 0 for a
    forward slice and to -1 (the last position) for a reversed slice.
    """
    if step is None:
        step = 1
    if start is None:
        # With a negative step the slice must begin at the end of the string
        # to mirror Python's ``s[::-1]``; starting at 0 would keep only the
        # first character.
        start = -1 if step < 0 else 0
    return type(self)(
        pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step)
    )
1622+
1623+
def _str_slice_replace(
    self, start: int | None = None, stop: int | None = None, repl: str | None = None
):
    """
    Replace the ``[start, stop)`` slice of each element with ``repl``.

    A missing ``repl`` defaults to the empty string, a missing ``start`` to 0
    and a missing ``stop`` to the largest 64-bit integer (i.e. "to the end").
    """
    if repl is None:
        repl = ""
    if start is None:
        start = 0
    if stop is None:
        # pc.utf8_replace_slice needs an integer stop; use the int64 maximum
        # to mean "through the end of the string".
        stop = 2**63 - 1
    return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl))
1631+
1632+
def _str_isalnum(self):
    """Wrap the result of ``pc.utf8_is_alnum`` on the backing data."""
    checked = pc.utf8_is_alnum(self._data)
    return type(self)(checked)
1634+
1635+
def _str_isalpha(self):
    """Wrap the result of ``pc.utf8_is_alpha`` on the backing data."""
    checked = pc.utf8_is_alpha(self._data)
    return type(self)(checked)
1637+
1638+
def _str_isdecimal(self):
    """Wrap the result of ``pc.utf8_is_decimal`` on the backing data."""
    checked = pc.utf8_is_decimal(self._data)
    return type(self)(checked)
1640+
1641+
def _str_isdigit(self):
    """Wrap the result of ``pc.utf8_is_digit`` on the backing data."""
    checked = pc.utf8_is_digit(self._data)
    return type(self)(checked)
1643+
1644+
def _str_islower(self):
    """Wrap the result of ``pc.utf8_is_lower`` on the backing data."""
    checked = pc.utf8_is_lower(self._data)
    return type(self)(checked)
1646+
1647+
def _str_isnumeric(self):
    """Wrap the result of ``pc.utf8_is_numeric`` on the backing data."""
    checked = pc.utf8_is_numeric(self._data)
    return type(self)(checked)
1649+
1650+
def _str_isspace(self):
    """Wrap the result of ``pc.utf8_is_space`` on the backing data."""
    checked = pc.utf8_is_space(self._data)
    return type(self)(checked)
1652+
1653+
def _str_istitle(self):
    """Wrap the result of ``pc.utf8_is_title`` on the backing data."""
    checked = pc.utf8_is_title(self._data)
    return type(self)(checked)
1655+
1656+
def _str_capitalize(self):
    """Wrap the result of ``pc.utf8_capitalize`` on the backing data."""
    converted = pc.utf8_capitalize(self._data)
    return type(self)(converted)
1658+
1659+
def _str_title(self):
    """Wrap the result of ``pc.utf8_title`` on the backing data."""
    converted = pc.utf8_title(self._data)
    return type(self)(converted)
1661+
1662+
def _str_isupper(self):
    """Wrap the result of ``pc.utf8_is_upper`` on the backing data."""
    checked = pc.utf8_is_upper(self._data)
    return type(self)(checked)
1664+
1665+
def _str_swapcase(self):
    """Wrap the result of ``pc.utf8_swapcase`` on the backing data."""
    converted = pc.utf8_swapcase(self._data)
    return type(self)(converted)
1667+
1668+
def _str_len(self):
    """Wrap the result of ``pc.utf8_length`` on the backing data."""
    lengths = pc.utf8_length(self._data)
    return type(self)(lengths)
1670+
1671+
def _str_lower(self):
    """Wrap the result of ``pc.utf8_lower`` on the backing data."""
    converted = pc.utf8_lower(self._data)
    return type(self)(converted)
1673+
1674+
def _str_upper(self):
    """Wrap the result of ``pc.utf8_upper`` on the backing data."""
    converted = pc.utf8_upper(self._data)
    return type(self)(converted)
1676+
1677+
def _str_strip(self, to_strip=None):
    """
    Trim both ends of each element.

    Uses ``pc.utf8_trim`` with ``to_strip`` as the characters when given,
    otherwise ``pc.utf8_trim_whitespace``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._data, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._data)
    return type(self)(trimmed)
1683+
1684+
def _str_lstrip(self, to_strip=None):
    """
    Trim the left end of each element.

    Uses ``pc.utf8_ltrim`` with ``to_strip`` as the characters when given,
    otherwise ``pc.utf8_ltrim_whitespace``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._data, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._data)
    return type(self)(trimmed)
1690+
1691+
def _str_rstrip(self, to_strip=None):
    """
    Trim the right end of each element.

    Uses ``pc.utf8_rtrim`` with ``to_strip`` as the characters when given,
    otherwise ``pc.utf8_rtrim_whitespace``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._data, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._data)
    return type(self)(trimmed)
1697+
1698+
def _str_removeprefix(self, prefix: str):
1699+
raise NotImplementedError(
1700+
"str.removeprefix not supported with pd.ArrowDtype(pa.string())."
1701+
)
1702+
# TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed
1703+
# starts_with = pc.starts_with(self._data, pattern=prefix)
1704+
# removed = pc.utf8_slice_codeunits(self._data, len(prefix))
1705+
# result = pc.if_else(starts_with, removed, self._data)
1706+
# return type(self)(result)
1707+
1708+
def _str_removesuffix(self, suffix: str):
1709+
ends_with = pc.ends_with(self._data, pattern=suffix)
1710+
removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix))
1711+
result = pc.if_else(ends_with, removed, self._data)
1712+
return type(self)(result)
1713+
1714+
def _str_casefold(self):
1715+
raise NotImplementedError(
1716+
"str.casefold not supported with pd.ArrowDtype(pa.string())."
1717+
)
1718+
1719+
def _str_encode(self, encoding, errors: str = "strict"):
1720+
raise NotImplementedError(
1721+
"str.encode not supported with pd.ArrowDtype(pa.string())."
1722+
)
1723+
1724+
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
1725+
raise NotImplementedError(
1726+
"str.extract not supported with pd.ArrowDtype(pa.string())."
1727+
)
1728+
1729+
def _str_findall(self, pat, flags: int = 0):
1730+
raise NotImplementedError(
1731+
"str.findall not supported with pd.ArrowDtype(pa.string())."
1732+
)
1733+
1734+
def _str_get_dummies(self, sep: str = "|"):
1735+
raise NotImplementedError(
1736+
"str.get_dummies not supported with pd.ArrowDtype(pa.string())."
1737+
)
1738+
1739+
def _str_index(self, sub, start: int = 0, end=None):
1740+
raise NotImplementedError(
1741+
"str.index not supported with pd.ArrowDtype(pa.string())."
1742+
)
1743+
1744+
def _str_rindex(self, sub, start: int = 0, end=None):
1745+
raise NotImplementedError(
1746+
"str.rindex not supported with pd.ArrowDtype(pa.string())."
1747+
)
1748+
1749+
def _str_normalize(self, form):
1750+
raise NotImplementedError(
1751+
"str.normalize not supported with pd.ArrowDtype(pa.string())."
1752+
)
1753+
1754+
def _str_rfind(self, sub, start: int = 0, end=None):
1755+
raise NotImplementedError(
1756+
"str.rfind not supported with pd.ArrowDtype(pa.string())."
1757+
)
1758+
1759+
def _str_split(
1760+
self, pat=None, n=-1, expand: bool = False, regex: bool | None = None
1761+
):
1762+
raise NotImplementedError(
1763+
"str.split not supported with pd.ArrowDtype(pa.string())."
1764+
)
1765+
1766+
def _str_rsplit(self, pat=None, n=-1):
1767+
raise NotImplementedError(
1768+
"str.rsplit not supported with pd.ArrowDtype(pa.string())."
1769+
)
1770+
1771+
def _str_translate(self, table):
1772+
raise NotImplementedError(
1773+
"str.translate not supported with pd.ArrowDtype(pa.string())."
1774+
)
1775+
1776+
def _str_wrap(self, width, **kwargs):
1777+
raise NotImplementedError(
1778+
"str.wrap not supported with pd.ArrowDtype(pa.string())."
1779+
)
1780+
14661781
@property
def _dt_day(self):
    """Wrap the result of ``pc.day`` on the backing data in this array type."""
    days = pc.day(self._data)
    return type(self)(days)

pandas/core/arrays/string_arrow.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def _chk_pyarrow_available() -> None:
6060
# fallback for the ones that pyarrow doesn't yet support
6161

6262

63-
class ArrowStringArray(ArrowExtensionArray, BaseStringArray, ObjectStringArrayMixin):
63+
class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringArray):
6464
"""
6565
Extension array for string data in a ``pyarrow.ChunkedArray``.
6666
@@ -117,6 +117,16 @@ def __init__(self, values) -> None:
117117
"ArrowStringArray requires a PyArrow (chunked) array of string type"
118118
)
119119

120+
def __len__(self) -> int:
121+
"""
122+
Length of this array.
123+
124+
Returns
125+
-------
126+
length : int
127+
"""
128+
return len(self._data)
129+
120130
@classmethod
121131
def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
122132
from pandas.core.arrays.masked import BaseMaskedArray

pandas/core/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@
169169
get_group_index_sorter,
170170
nargsort,
171171
)
172-
from pandas.core.strings import StringMethods
172+
from pandas.core.strings.accessor import StringMethods
173173

174174
from pandas.io.formats.printing import (
175175
PrettyDict,

0 commit comments

Comments
 (0)