|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | from copy import deepcopy
|
| 4 | +import re |
4 | 5 | from typing import (
|
5 | 6 | TYPE_CHECKING,
|
6 | 7 | Any,
|
| 8 | + Callable, |
7 | 9 | Literal,
|
| 10 | + Sequence, |
8 | 11 | TypeVar,
|
9 | 12 | cast,
|
10 | 13 | )
|
|
55 | 58 | unpack_tuple_and_ellipses,
|
56 | 59 | validate_indices,
|
57 | 60 | )
|
| 61 | +from pandas.core.strings.base import BaseStringArrayMethods |
58 | 62 |
|
59 | 63 | from pandas.tseries.frequencies import to_offset
|
60 | 64 |
|
@@ -165,7 +169,7 @@ def to_pyarrow_type(
|
165 | 169 | return None
|
166 | 170 |
|
167 | 171 |
|
168 |
| -class ArrowExtensionArray(OpsMixin, ExtensionArray): |
| 172 | +class ArrowExtensionArray(OpsMixin, ExtensionArray, BaseStringArrayMethods): |
169 | 173 | """
|
170 | 174 | Pandas ExtensionArray backed by a PyArrow ChunkedArray.
|
171 | 175 |
|
@@ -1463,6 +1467,317 @@ def _replace_with_mask(
|
1463 | 1467 | result[mask] = replacements
|
1464 | 1468 | return pa.array(result, type=values.type, from_pandas=True)
|
1465 | 1469 |
|
| 1470 | + def _str_count(self, pat: str, flags: int = 0): |
| 1471 | + if flags: |
| 1472 | + raise NotImplementedError(f"count not implemented with {flags=}") |
| 1473 | + return type(self)(pc.count_substring_regex(self._data, pat)) |
| 1474 | + |
| 1475 | + def _str_pad( |
| 1476 | + self, |
| 1477 | + width: int, |
| 1478 | + side: Literal["left", "right", "both"] = "left", |
| 1479 | + fillchar: str = " ", |
| 1480 | + ): |
| 1481 | + if side == "left": |
| 1482 | + pa_pad = pc.utf8_lpad |
| 1483 | + elif side == "right": |
| 1484 | + pa_pad = pc.utf8_rpad |
| 1485 | + elif side == "both": |
| 1486 | + pa_pad = pc.utf8_center |
| 1487 | + else: |
| 1488 | + raise ValueError( |
| 1489 | + f"Invalid side: {side}. Side must be one of 'left', 'right', 'both'" |
| 1490 | + ) |
| 1491 | + return type(self)(pa_pad(self._data, width=width, padding=fillchar)) |
| 1492 | + |
| 1493 | + def _str_contains( |
| 1494 | + self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True |
| 1495 | + ): |
| 1496 | + if flags: |
| 1497 | + raise NotImplementedError(f"contains not implemented with {flags=}") |
| 1498 | + |
| 1499 | + if regex: |
| 1500 | + pa_contains = pc.match_substring_regex |
| 1501 | + else: |
| 1502 | + pa_contains = pc.match_substring |
| 1503 | + result = pa_contains(self._data, pat, ignore_case=not case) |
| 1504 | + if not isna(na): |
| 1505 | + result = result.fill_null(na) |
| 1506 | + return type(self)(result) |
| 1507 | + |
| 1508 | + def _str_startswith(self, pat: str, na=None): |
| 1509 | + result = pc.starts_with(self._data, pattern=pat) |
| 1510 | + if not isna(na): |
| 1511 | + result = result.fill_null(na) |
| 1512 | + return type(self)(result) |
| 1513 | + |
| 1514 | + def _str_endswith(self, pat: str, na=None): |
| 1515 | + result = pc.ends_with(self._data, pattern=pat) |
| 1516 | + if not isna(na): |
| 1517 | + result = result.fill_null(na) |
| 1518 | + return type(self)(result) |
| 1519 | + |
| 1520 | + def _str_replace( |
| 1521 | + self, |
| 1522 | + pat: str | re.Pattern, |
| 1523 | + repl: str | Callable, |
| 1524 | + n: int = -1, |
| 1525 | + case: bool = True, |
| 1526 | + flags: int = 0, |
| 1527 | + regex: bool = True, |
| 1528 | + ): |
| 1529 | + if isinstance(pat, re.Pattern) or callable(repl) or not case or flags: |
| 1530 | + raise NotImplementedError( |
| 1531 | + "replace is not supported with a re.Pattern, callable repl, " |
| 1532 | + "case=False, or flags!=0" |
| 1533 | + ) |
| 1534 | + |
| 1535 | + func = pc.replace_substring_regex if regex else pc.replace_substring |
| 1536 | + result = func(self._data, pattern=pat, replacement=repl, max_replacements=n) |
| 1537 | + return type(self)(result) |
| 1538 | + |
| 1539 | + def _str_repeat(self, repeats: int | Sequence[int]): |
| 1540 | + if not isinstance(repeats, int): |
| 1541 | + raise NotImplementedError( |
| 1542 | + f"repeat is not implemented when repeats is {type(repeats).__name__}" |
| 1543 | + ) |
| 1544 | + elif pa_version_under7p0: |
| 1545 | + raise NotImplementedError("repeat is not implemented for pyarrow < 7") |
| 1546 | + else: |
| 1547 | + return type(self)(pc.binary_repeat(self._data, repeats)) |
| 1548 | + |
| 1549 | + def _str_match( |
| 1550 | + self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None |
| 1551 | + ): |
| 1552 | + if not pat.startswith("^"): |
| 1553 | + pat = f"^{pat}" |
| 1554 | + return self._str_contains(pat, case, flags, na, regex=True) |
| 1555 | + |
| 1556 | + def _str_fullmatch( |
| 1557 | + self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None |
| 1558 | + ): |
| 1559 | + if not pat.endswith("$") or pat.endswith("//$"): |
| 1560 | + pat = f"{pat}$" |
| 1561 | + return self._str_match(pat, case, flags, na) |
| 1562 | + |
| 1563 | + def _str_find(self, sub: str, start: int = 0, end: int | None = None): |
| 1564 | + if start != 0 and end is not None: |
| 1565 | + slices = pc.utf8_slice_codeunits(self._data, start, stop=end) |
| 1566 | + result = pc.find_substring(slices, sub) |
| 1567 | + not_found = pc.equal(result, -1) |
| 1568 | + offset_result = pc.add(result, end - start) |
| 1569 | + result = pc.if_else(not_found, result, offset_result) |
| 1570 | + elif start == 0 and end is None: |
| 1571 | + slices = self._data |
| 1572 | + result = pc.find_substring(slices, sub) |
| 1573 | + else: |
| 1574 | + raise NotImplementedError( |
| 1575 | + f"find not implemented with {sub=}, {start=}, {end=}" |
| 1576 | + ) |
| 1577 | + return type(self)(result) |
| 1578 | + |
| 1579 | + def _str_get(self, i: int): |
| 1580 | + lengths = pc.utf8_length(self._data) |
| 1581 | + if i >= 0: |
| 1582 | + out_of_bounds = pc.greater_equal(i, lengths) |
| 1583 | + start = i |
| 1584 | + stop = i + 1 |
| 1585 | + step = 1 |
| 1586 | + else: |
| 1587 | + out_of_bounds = pc.greater(-i, lengths) |
| 1588 | + start = i |
| 1589 | + stop = i - 1 |
| 1590 | + step = -1 |
| 1591 | + not_out_of_bounds = pc.invert(out_of_bounds.fill_null(True)) |
| 1592 | + selected = pc.utf8_slice_codeunits( |
| 1593 | + self._data, start=start, stop=stop, step=step |
| 1594 | + ) |
| 1595 | + result = pa.array([None] * self._data.length(), type=self._data.type) |
| 1596 | + result = pc.if_else(not_out_of_bounds, selected, result) |
| 1597 | + return type(self)(result) |
| 1598 | + |
| 1599 | + def _str_join(self, sep: str): |
| 1600 | + return type(self)(pc.binary_join(self._data, sep)) |
| 1601 | + |
| 1602 | + def _str_partition(self, sep: str, expand: bool): |
| 1603 | + raise NotImplementedError( |
| 1604 | + "str.partition not supported with pd.ArrowDtype(pa.string())." |
| 1605 | + ) |
| 1606 | + |
| 1607 | + def _str_rpartition(self, sep: str, expand: bool): |
| 1608 | + raise NotImplementedError( |
| 1609 | + "str.rpartition not supported with pd.ArrowDtype(pa.string())." |
| 1610 | + ) |
| 1611 | + |
| 1612 | + def _str_slice( |
| 1613 | + self, start: int | None = None, stop: int | None = None, step: int | None = None |
| 1614 | + ): |
| 1615 | + if start is None: |
| 1616 | + start = 0 |
| 1617 | + if step is None: |
| 1618 | + step = 1 |
| 1619 | + return type(self)( |
| 1620 | + pc.utf8_slice_codeunits(self._data, start=start, stop=stop, step=step) |
| 1621 | + ) |
| 1622 | + |
| 1623 | + def _str_slice_replace( |
| 1624 | + self, start: int | None = None, stop: int | None = None, repl: str | None = None |
| 1625 | + ): |
| 1626 | + if repl is None: |
| 1627 | + repl = "" |
| 1628 | + if start is None: |
| 1629 | + start = 0 |
| 1630 | + return type(self)(pc.utf8_replace_slice(self._data, start, stop, repl)) |
| 1631 | + |
| 1632 | + def _str_isalnum(self): |
| 1633 | + return type(self)(pc.utf8_is_alnum(self._data)) |
| 1634 | + |
| 1635 | + def _str_isalpha(self): |
| 1636 | + return type(self)(pc.utf8_is_alpha(self._data)) |
| 1637 | + |
| 1638 | + def _str_isdecimal(self): |
| 1639 | + return type(self)(pc.utf8_is_decimal(self._data)) |
| 1640 | + |
| 1641 | + def _str_isdigit(self): |
| 1642 | + return type(self)(pc.utf8_is_digit(self._data)) |
| 1643 | + |
| 1644 | + def _str_islower(self): |
| 1645 | + return type(self)(pc.utf8_is_lower(self._data)) |
| 1646 | + |
| 1647 | + def _str_isnumeric(self): |
| 1648 | + return type(self)(pc.utf8_is_numeric(self._data)) |
| 1649 | + |
| 1650 | + def _str_isspace(self): |
| 1651 | + return type(self)(pc.utf8_is_space(self._data)) |
| 1652 | + |
| 1653 | + def _str_istitle(self): |
| 1654 | + return type(self)(pc.utf8_is_title(self._data)) |
| 1655 | + |
| 1656 | + def _str_capitalize(self): |
| 1657 | + return type(self)(pc.utf8_capitalize(self._data)) |
| 1658 | + |
| 1659 | + def _str_title(self): |
| 1660 | + return type(self)(pc.utf8_title(self._data)) |
| 1661 | + |
| 1662 | + def _str_isupper(self): |
| 1663 | + return type(self)(pc.utf8_is_upper(self._data)) |
| 1664 | + |
| 1665 | + def _str_swapcase(self): |
| 1666 | + return type(self)(pc.utf8_swapcase(self._data)) |
| 1667 | + |
| 1668 | + def _str_len(self): |
| 1669 | + return type(self)(pc.utf8_length(self._data)) |
| 1670 | + |
| 1671 | + def _str_lower(self): |
| 1672 | + return type(self)(pc.utf8_lower(self._data)) |
| 1673 | + |
| 1674 | + def _str_upper(self): |
| 1675 | + return type(self)(pc.utf8_upper(self._data)) |
| 1676 | + |
| 1677 | + def _str_strip(self, to_strip=None): |
| 1678 | + if to_strip is None: |
| 1679 | + result = pc.utf8_trim_whitespace(self._data) |
| 1680 | + else: |
| 1681 | + result = pc.utf8_trim(self._data, characters=to_strip) |
| 1682 | + return type(self)(result) |
| 1683 | + |
| 1684 | + def _str_lstrip(self, to_strip=None): |
| 1685 | + if to_strip is None: |
| 1686 | + result = pc.utf8_ltrim_whitespace(self._data) |
| 1687 | + else: |
| 1688 | + result = pc.utf8_ltrim(self._data, characters=to_strip) |
| 1689 | + return type(self)(result) |
| 1690 | + |
| 1691 | + def _str_rstrip(self, to_strip=None): |
| 1692 | + if to_strip is None: |
| 1693 | + result = pc.utf8_rtrim_whitespace(self._data) |
| 1694 | + else: |
| 1695 | + result = pc.utf8_rtrim(self._data, characters=to_strip) |
| 1696 | + return type(self)(result) |
| 1697 | + |
| 1698 | + def _str_removeprefix(self, prefix: str): |
| 1699 | + raise NotImplementedError( |
| 1700 | + "str.removeprefix not supported with pd.ArrowDtype(pa.string())." |
| 1701 | + ) |
| 1702 | + # TODO: Should work once https://github.com/apache/arrow/issues/14991 is fixed |
| 1703 | + # starts_with = pc.starts_with(self._data, pattern=prefix) |
| 1704 | + # removed = pc.utf8_slice_codeunits(self._data, len(prefix)) |
| 1705 | + # result = pc.if_else(starts_with, removed, self._data) |
| 1706 | + # return type(self)(result) |
| 1707 | + |
| 1708 | + def _str_removesuffix(self, suffix: str): |
| 1709 | + ends_with = pc.ends_with(self._data, pattern=suffix) |
| 1710 | + removed = pc.utf8_slice_codeunits(self._data, 0, stop=-len(suffix)) |
| 1711 | + result = pc.if_else(ends_with, removed, self._data) |
| 1712 | + return type(self)(result) |
| 1713 | + |
| 1714 | + def _str_casefold(self): |
| 1715 | + raise NotImplementedError( |
| 1716 | + "str.casefold not supported with pd.ArrowDtype(pa.string())." |
| 1717 | + ) |
| 1718 | + |
| 1719 | + def _str_encode(self, encoding, errors: str = "strict"): |
| 1720 | + raise NotImplementedError( |
| 1721 | + "str.encode not supported with pd.ArrowDtype(pa.string())." |
| 1722 | + ) |
| 1723 | + |
| 1724 | + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): |
| 1725 | + raise NotImplementedError( |
| 1726 | + "str.extract not supported with pd.ArrowDtype(pa.string())." |
| 1727 | + ) |
| 1728 | + |
| 1729 | + def _str_findall(self, pat, flags: int = 0): |
| 1730 | + raise NotImplementedError( |
| 1731 | + "str.findall not supported with pd.ArrowDtype(pa.string())." |
| 1732 | + ) |
| 1733 | + |
| 1734 | + def _str_get_dummies(self, sep: str = "|"): |
| 1735 | + raise NotImplementedError( |
| 1736 | + "str.get_dummies not supported with pd.ArrowDtype(pa.string())." |
| 1737 | + ) |
| 1738 | + |
| 1739 | + def _str_index(self, sub, start: int = 0, end=None): |
| 1740 | + raise NotImplementedError( |
| 1741 | + "str.index not supported with pd.ArrowDtype(pa.string())." |
| 1742 | + ) |
| 1743 | + |
| 1744 | + def _str_rindex(self, sub, start: int = 0, end=None): |
| 1745 | + raise NotImplementedError( |
| 1746 | + "str.rindex not supported with pd.ArrowDtype(pa.string())." |
| 1747 | + ) |
| 1748 | + |
| 1749 | + def _str_normalize(self, form): |
| 1750 | + raise NotImplementedError( |
| 1751 | + "str.normalize not supported with pd.ArrowDtype(pa.string())." |
| 1752 | + ) |
| 1753 | + |
| 1754 | + def _str_rfind(self, sub, start: int = 0, end=None): |
| 1755 | + raise NotImplementedError( |
| 1756 | + "str.rfind not supported with pd.ArrowDtype(pa.string())." |
| 1757 | + ) |
| 1758 | + |
| 1759 | + def _str_split( |
| 1760 | + self, pat=None, n=-1, expand: bool = False, regex: bool | None = None |
| 1761 | + ): |
| 1762 | + raise NotImplementedError( |
| 1763 | + "str.split not supported with pd.ArrowDtype(pa.string())." |
| 1764 | + ) |
| 1765 | + |
| 1766 | + def _str_rsplit(self, pat=None, n=-1): |
| 1767 | + raise NotImplementedError( |
| 1768 | + "str.rsplit not supported with pd.ArrowDtype(pa.string())." |
| 1769 | + ) |
| 1770 | + |
| 1771 | + def _str_translate(self, table): |
| 1772 | + raise NotImplementedError( |
| 1773 | + "str.translate not supported with pd.ArrowDtype(pa.string())." |
| 1774 | + ) |
| 1775 | + |
| 1776 | + def _str_wrap(self, width, **kwargs): |
| 1777 | + raise NotImplementedError( |
| 1778 | + "str.wrap not supported with pd.ArrowDtype(pa.string())." |
| 1779 | + ) |
| 1780 | + |
1466 | 1781 | @property
|
1467 | 1782 | def _dt_day(self):
|
1468 | 1783 | return type(self)(pc.day(self._data))
|
|
0 commit comments