
Commit 07d8ee3

simonjayhawkins and yeshsurya authored and committed
ENH: [ArrowStringArray] Enable the string methods for the arrow-backed StringArray (pandas-dev#40708)
1 parent 12c9eff commit 07d8ee3
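As a quick, hedged illustration of what this commit is about (not part of the commit itself), the sketch below calls vectorized string methods on a Series backed by the arrow string dtype. It assumes pyarrow is installed and uses the "arrow_string" dtype name that this commit's tests use; newer pandas versions spell the dtype differently.

import pandas as pd

# Assumption: pyarrow is installed and the "arrow_string" extension dtype is
# registered, as exercised by the tests in this commit.
s = pd.Series(["ab", None, "Cd"], dtype="arrow_string")

# The .str accessor now dispatches through the arrow-backed array.
print(s.str.capitalize())  # string-dtype result, missing value preserved
print(s.str.isalpha())     # nullable boolean result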

File tree

3 files changed (+67 −221 lines)


pandas/core/arrays/string_arrow.py

+23 −179
@@ -1,14 +1,12 @@
 from __future__ import annotations
 
 from distutils.version import LooseVersion
-import re
 from typing import (
     TYPE_CHECKING,
     Any,
     Sequence,
     cast,
 )
-import warnings
 
 import numpy as np
 
@@ -27,21 +25,22 @@
 
 from pandas.core.dtypes.base import ExtensionDtype
 from pandas.core.dtypes.common import (
-    is_array_like,
-    is_bool_dtype,
-    is_integer,
-    is_integer_dtype,
     is_object_dtype,
-    is_scalar,
     is_string_dtype,
 )
 from pandas.core.dtypes.dtypes import register_extension_dtype
 from pandas.core.dtypes.missing import isna
 
+from pandas.api.types import (
+    is_array_like,
+    is_bool_dtype,
+    is_integer,
+    is_integer_dtype,
+    is_scalar,
+)
 from pandas.core import missing
 from pandas.core.arraylike import OpsMixin
 from pandas.core.arrays.base import ExtensionArray
-from pandas.core.arrays.boolean import BooleanDtype
 from pandas.core.indexers import (
     check_array_indexer,
     validate_indices,
@@ -230,21 +229,10 @@ def _chk_pyarrow_available(cls) -> None:
 
     @classmethod
     def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False):
-        from pandas.core.arrays.masked import BaseMaskedArray
-
         cls._chk_pyarrow_available()
-
-        if isinstance(scalars, BaseMaskedArray):
-            # avoid costly conversion to object dtype in ensure_string_array and
-            # numerical issues with Float32Dtype
-            na_values = scalars._mask
-            result = scalars._data
-            result = lib.ensure_string_array(result, copy=copy, convert_na_value=False)
-            return cls(pa.array(result, mask=na_values, type=pa.string()))
-
-        # convert non-na-likes to str
-        result = lib.ensure_string_array(scalars, copy=copy)
-        return cls(pa.array(result, type=pa.string(), from_pandas=True))
+        # convert non-na-likes to str, and nan-likes to ArrowStringDtype.na_value
+        scalars = lib.ensure_string_array(scalars, copy=False)
+        return cls(pa.array(scalars, type=pa.string(), from_pandas=True))
 
     @classmethod
     def _from_sequence_of_strings(
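For context on the simplified _from_sequence above: lib.ensure_string_array is a pandas internal that coerces non-NA values to str, and pa.array(..., from_pandas=True) then turns pandas NA sentinels into Arrow nulls. A minimal sketch of the pyarrow side only, assuming nothing beyond pyarrow and numpy being installed:

import numpy as np
import pyarrow as pa

# from_pandas=True makes pyarrow treat None/NaN as nulls while building
# the string array, which is what the simplified _from_sequence relies on.
arr = pa.array(np.array(["a", None, "b"], dtype=object), type=pa.string(), from_pandas=True)
print(arr.null_count)   # 1
print(arr.to_pylist())  # ['a', None, 'b']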
@@ -433,8 +421,10 @@ def fillna(self, value=None, method=None, limit=None):
         if mask.any():
             if method is not None:
                 func = missing.get_fill_func(method)
+                # error: Argument 1 to "to_numpy" of "ArrowStringArray" has incompatible
+                # type "Type[object]"; expected "Union[str, dtype[Any], None]"
                 new_values, _ = func(
-                    self.to_numpy("object"),
+                    self.to_numpy(object),  # type: ignore[arg-type]
                     limit=limit,
                     mask=mask,
                 )
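A small usage sketch of the fillna path annotated above (not from the commit; it assumes pyarrow is installed and the "arrow_string" dtype name used by the tests in this commit). Method-based filling converts to an object ndarray, fills it, and rebuilds the arrow-backed array:

import pandas as pd

s = pd.Series(["a", None, "c"], dtype="arrow_string")
# method="pad" goes through missing.get_fill_func and the to_numpy(object)
# call shown above before the result is wrapped back into the array type.
print(s.fillna(method="pad"))  # a, a, c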
@@ -687,18 +677,13 @@ def value_counts(self, dropna: bool = True) -> Series:
 
         vc = self._data.value_counts()
 
-        values = vc.field(0)
-        counts = vc.field(1)
-        if dropna and self._data.null_count > 0:
-            mask = values.is_valid()
-            values = values.filter(mask)
-            counts = counts.filter(mask)
-
+        # Index cannot hold ExtensionArrays yet
+        index = Index(type(self)(vc.field(0)).astype(object))
         # No missing values so we can adhere to the interface and return a numpy array.
-        counts = np.array(counts)
+        counts = np.array(vc.field(1))
 
-        # Index cannot hold ExtensionArrays yet
-        index = Index(type(self)(values)).astype(object)
+        if dropna and self._data.null_count > 0:
+            raise NotImplementedError("yo")
 
         return Series(counts, index=index).astype("Int64")
 
@@ -751,7 +736,11 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
             if not na_value_is_na:
                 mask[:] = False
 
-            return constructor(result, mask)
+            # error: Argument 1 to "IntegerArray" has incompatible type
+            # "Union[ExtensionArray, ndarray]"; expected "ndarray"
+            # error: Argument 1 to "BooleanArray" has incompatible type
+            # "Union[ExtensionArray, ndarray]"; expected "ndarray"
+            return constructor(result, mask)  # type: ignore[arg-type]
 
         elif is_string_dtype(dtype) and not is_object_dtype(dtype):
             # i.e. StringDtype
@@ -765,148 +754,3 @@ def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
             # or .findall returns a list).
             # -> We don't know the result type. E.g. `.get` can return anything.
             return lib.map_infer_mask(arr, f, mask.view("uint8"))
-
-    def _str_contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
-        if flags:
-            return super()._str_contains(pat, case, flags, na, regex)
-
-        if regex:
-            # match_substring_regex added in pyarrow 4.0.0
-            if hasattr(pc, "match_substring_regex") and case:
-                if re.compile(pat).groups:
-                    warnings.warn(
-                        "This pattern has match groups. To actually get the "
-                        "groups, use str.extract.",
-                        UserWarning,
-                        stacklevel=3,
-                    )
-                result = pc.match_substring_regex(self._data, pat)
-            else:
-                return super()._str_contains(pat, case, flags, na, regex)
-        else:
-            if case:
-                result = pc.match_substring(self._data, pat)
-            else:
-                result = pc.match_substring(pc.utf8_upper(self._data), pat.upper())
-        result = BooleanDtype().__from_arrow__(result)
-        if not isna(na):
-            result[isna(result)] = bool(na)
-        return result
-
-    def _str_startswith(self, pat, na=None):
-        # match_substring_regex added in pyarrow 4.0.0
-        if hasattr(pc, "match_substring_regex"):
-            result = pc.match_substring_regex(self._data, "^" + re.escape(pat))
-            result = BooleanDtype().__from_arrow__(result)
-            if not isna(na):
-                result[isna(result)] = bool(na)
-            return result
-        else:
-            return super()._str_startswith(pat, na)
-
-    def _str_endswith(self, pat, na=None):
-        # match_substring_regex added in pyarrow 4.0.0
-        if hasattr(pc, "match_substring_regex"):
-            result = pc.match_substring_regex(self._data, re.escape(pat) + "$")
-            result = BooleanDtype().__from_arrow__(result)
-            if not isna(na):
-                result[isna(result)] = bool(na)
-            return result
-        else:
-            return super()._str_endswith(pat, na)
-
-    def _str_isalnum(self):
-        if hasattr(pc, "utf8_is_alnum"):
-            result = pc.utf8_is_alnum(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isalnum()
-
-    def _str_isalpha(self):
-        if hasattr(pc, "utf8_is_alpha"):
-            result = pc.utf8_is_alpha(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isalpha()
-
-    def _str_isdecimal(self):
-        if hasattr(pc, "utf8_is_decimal"):
-            result = pc.utf8_is_decimal(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isdecimal()
-
-    def _str_isdigit(self):
-        if hasattr(pc, "utf8_is_digit"):
-            result = pc.utf8_is_digit(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isdigit()
-
-    def _str_islower(self):
-        if hasattr(pc, "utf8_is_lower"):
-            result = pc.utf8_is_lower(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_islower()
-
-    def _str_isnumeric(self):
-        if hasattr(pc, "utf8_is_numeric"):
-            result = pc.utf8_is_numeric(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isnumeric()
-
-    def _str_isspace(self):
-        if hasattr(pc, "utf8_is_space"):
-            result = pc.utf8_is_space(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isspace()
-
-    def _str_istitle(self):
-        if hasattr(pc, "utf8_is_title"):
-            result = pc.utf8_is_title(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_istitle()
-
-    def _str_isupper(self):
-        if hasattr(pc, "utf8_is_upper"):
-            result = pc.utf8_is_upper(self._data)
-            return BooleanDtype().__from_arrow__(result)
-        else:
-            return super()._str_isupper()
-
-    def _str_lower(self):
-        return type(self)(pc.utf8_lower(self._data))
-
-    def _str_upper(self):
-        return type(self)(pc.utf8_upper(self._data))
-
-    def _str_strip(self, to_strip=None):
-        if to_strip is None:
-            if hasattr(pc, "utf8_trim_whitespace"):
-                return type(self)(pc.utf8_trim_whitespace(self._data))
-        else:
-            if hasattr(pc, "utf8_trim"):
-                return type(self)(pc.utf8_trim(self._data, characters=to_strip))
-        return super()._str_strip(to_strip)
-
-    def _str_lstrip(self, to_strip=None):
-        if to_strip is None:
-            if hasattr(pc, "utf8_ltrim_whitespace"):
-                return type(self)(pc.utf8_ltrim_whitespace(self._data))
-        else:
-            if hasattr(pc, "utf8_ltrim"):
-                return type(self)(pc.utf8_ltrim(self._data, characters=to_strip))
-        return super()._str_lstrip(to_strip)
-
-    def _str_rstrip(self, to_strip=None):
-        if to_strip is None:
-            if hasattr(pc, "utf8_rtrim_whitespace"):
-                return type(self)(pc.utf8_rtrim_whitespace(self._data))
-        else:
-            if hasattr(pc, "utf8_rtrim"):
-                return type(self)(pc.utf8_rtrim(self._data, characters=to_strip))
-        return super()._str_rstrip(to_strip)
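The removed _str_* methods above all follow one pattern: feature-detect the pyarrow compute kernel (older pyarrow releases lack some of them), apply it to the underlying ChunkedArray, and wrap boolean results back into a nullable pandas array. A hedged sketch of that pattern using only public pyarrow, not the commit's code:

import pyarrow as pa
import pyarrow.compute as pc

data = pa.chunked_array([pa.array(["abc", "ab1", None], type=pa.string())])

# Feature-detect the kernel before using it, mirroring the hasattr checks above.
if hasattr(pc, "utf8_is_alnum"):
    result = pc.utf8_is_alnum(data)
    print(result.to_pylist())  # [True, True, None]
else:
    print("pyarrow too old; a Python-level fallback would be used instead")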

pandas/tests/arrays/string_/test_string.py

+29 −21
@@ -1,7 +1,4 @@
-"""
-This module tests the functionality of StringArray and ArrowStringArray.
-Tests for the str accessors are in pandas/tests/strings/test_string_array.py
-"""
+import operator
 
 import numpy as np
 import pytest
@@ -91,6 +88,23 @@ def test_setitem_with_scalar_string(dtype):
     tm.assert_extension_array_equal(arr, expected)
 
 
+@pytest.mark.parametrize(
+    "input, method",
+    [
+        (["a", "b", "c"], operator.methodcaller("capitalize")),
+        (["a b", "a bc. de"], operator.methodcaller("capitalize")),
+    ],
+)
+def test_string_methods(input, method, dtype):
+    a = pd.Series(input, dtype=dtype)
+    b = pd.Series(input, dtype="object")
+    result = method(a.str)
+    expected = method(b.str)
+
+    assert result.dtype.name == dtype
+    tm.assert_series_equal(result.astype(object), expected)
+
+
 def test_astype_roundtrip(dtype, request):
     if dtype == "arrow_string":
         reason = "ValueError: Could not convert object to NumPy datetime"
@@ -476,23 +490,12 @@ def test_arrow_roundtrip(dtype, dtype_object):
     assert result.loc[2, "a"] is pd.NA
 
 
-@td.skip_if_no("pyarrow", min_version="0.15.1.dev")
-def test_arrow_load_from_zero_chunks(dtype, dtype_object):
-    # GH-41040
-    import pyarrow as pa
-
-    data = pd.array([], dtype=dtype)
-    df = pd.DataFrame({"a": data})
-    table = pa.table(df)
-    assert table.field("a").type == "string"
-    # Instantiate the same table with no chunks at all
-    table = pa.table([pa.chunked_array([], type=pa.string())], schema=table.schema)
-    result = table.to_pandas()
-    assert isinstance(result["a"].dtype, dtype_object)
-    tm.assert_frame_equal(result, df)
-
+def test_value_counts_na(dtype, request):
+    if dtype == "arrow_string":
+        reason = "TypeError: boolean value of NA is ambiguous"
+        mark = pytest.mark.xfail(reason=reason)
+        request.node.add_marker(mark)
 
-def test_value_counts_na(dtype):
     arr = pd.array(["a", "b", "a", pd.NA], dtype=dtype)
     result = arr.value_counts(dropna=False)
     expected = pd.Series([2, 1, 1], index=["a", "b", pd.NA], dtype="Int64")
@@ -503,7 +506,12 @@ def test_value_counts_na(dtype):
     tm.assert_series_equal(result, expected)
 
 
-def test_value_counts_with_normalize(dtype):
+def test_value_counts_with_normalize(dtype, request):
+    if dtype == "arrow_string":
+        reason = "TypeError: boolean value of NA is ambiguous"
+        mark = pytest.mark.xfail(reason=reason)
+        request.node.add_marker(mark)
+
     s = pd.Series(["a", "b", "a", pd.NA], dtype=dtype)
     result = s.value_counts(normalize=True)
     expected = pd.Series([2, 1], index=["a", "b"], dtype="Float64") / 3
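The xfail pattern used in the two value_counts tests above marks a test as expected to fail for one parametrization only, at runtime. A generic sketch of the same idiom (a hypothetical test, not taken from the commit):

import pytest


@pytest.mark.parametrize("backend", ["python", "arrow_string"])
def test_example(backend, request):
    if backend == "arrow_string":
        # Known failure for this backend only; the rest of the matrix still runs.
        request.node.add_marker(pytest.mark.xfail(reason="not implemented yet"))
    assert backend == "python"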
