Skip to content

Commit 2ad36cc

Browse files
Backport PR #54533 on branch 2.1.x (Implement Arrow String Array that is compatible with NumPy semantics) (#54713)
Backport PR #54533: Implement Arrow String Array that is compatible with NumPy semantics Co-authored-by: Patrick Hoefler <[email protected]>
1 parent 5c9b63c commit 2ad36cc

File tree

14 files changed

+273
-53
lines changed

14 files changed

+273
-53
lines changed

pandas/conftest.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1321,6 +1321,7 @@ def nullable_string_dtype(request):
13211321
params=[
13221322
"python",
13231323
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")),
1324+
pytest.param("pyarrow_numpy", marks=td.skip_if_no("pyarrow")),
13241325
]
13251326
)
13261327
def string_storage(request):
@@ -1329,6 +1330,7 @@ def string_storage(request):
13291330
13301331
* 'python'
13311332
* 'pyarrow'
1333+
* 'pyarrow_numpy'
13321334
"""
13331335
return request.param
13341336

@@ -1380,6 +1382,7 @@ def object_dtype(request):
13801382
"object",
13811383
"string[python]",
13821384
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
1385+
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
13831386
]
13841387
)
13851388
def any_string_dtype(request):
@@ -2000,4 +2003,4 @@ def warsaw(request) -> str:
20002003

20012004
@pytest.fixture()
20022005
def arrow_string_storage():
2003-
return ("pyarrow",)
2006+
return ("pyarrow", "pyarrow_numpy")

pandas/core/arrays/arrow/array.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,10 @@ def __getitem__(self, item: PositionalIndexer):
512512
if isinstance(item, np.ndarray):
513513
if not len(item):
514514
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
515-
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
515+
if self._dtype.name == "string" and self._dtype.storage in (
516+
"pyarrow",
517+
"pyarrow_numpy",
518+
):
516519
pa_dtype = pa.string()
517520
else:
518521
pa_dtype = self._dtype.pyarrow_dtype

pandas/core/arrays/string_.py

+16-5
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ class StringDtype(StorageExtensionDtype):
7676
7777
Parameters
7878
----------
79-
storage : {"python", "pyarrow"}, optional
79+
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
8080
If not given, the value of ``pd.options.mode.string_storage``.
8181
8282
Attributes
@@ -108,11 +108,11 @@ def na_value(self) -> libmissing.NAType:
108108
def __init__(self, storage=None) -> None:
109109
if storage is None:
110110
storage = get_option("mode.string_storage")
111-
if storage not in {"python", "pyarrow"}:
111+
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
112112
raise ValueError(
113113
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
114114
)
115-
if storage == "pyarrow" and pa_version_under7p0:
115+
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under7p0:
116116
raise ImportError(
117117
"pyarrow>=7.0.0 is required for PyArrow backed StringArray."
118118
)
@@ -160,6 +160,8 @@ def construct_from_string(cls, string):
160160
return cls(storage="python")
161161
elif string == "string[pyarrow]":
162162
return cls(storage="pyarrow")
163+
elif string == "string[pyarrow_numpy]":
164+
return cls(storage="pyarrow_numpy")
163165
else:
164166
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")
165167

@@ -176,12 +178,17 @@ def construct_array_type( # type: ignore[override]
176178
-------
177179
type
178180
"""
179-
from pandas.core.arrays.string_arrow import ArrowStringArray
181+
from pandas.core.arrays.string_arrow import (
182+
ArrowStringArray,
183+
ArrowStringArrayNumpySemantics,
184+
)
180185

181186
if self.storage == "python":
182187
return StringArray
183-
else:
188+
elif self.storage == "pyarrow":
184189
return ArrowStringArray
190+
else:
191+
return ArrowStringArrayNumpySemantics
185192

186193
def __from_arrow__(
187194
self, array: pyarrow.Array | pyarrow.ChunkedArray
@@ -193,6 +200,10 @@ def __from_arrow__(
193200
from pandas.core.arrays.string_arrow import ArrowStringArray
194201

195202
return ArrowStringArray(array)
203+
elif self.storage == "pyarrow_numpy":
204+
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics
205+
206+
return ArrowStringArrayNumpySemantics(array)
196207
else:
197208
import pyarrow
198209

pandas/core/arrays/string_arrow.py

+135-14
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from functools import partial
34
import re
45
from typing import (
56
TYPE_CHECKING,
@@ -27,6 +28,7 @@
2728
)
2829
from pandas.core.dtypes.missing import isna
2930

31+
from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
3032
from pandas.core.arrays.arrow import ArrowExtensionArray
3133
from pandas.core.arrays.boolean import BooleanDtype
3234
from pandas.core.arrays.integer import Int64Dtype
@@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
113115
# error: Incompatible types in assignment (expression has type "StringDtype",
114116
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
115117
_dtype: StringDtype # type: ignore[assignment]
118+
_storage = "pyarrow"
116119

117120
def __init__(self, values) -> None:
118121
super().__init__(values)
119-
self._dtype = StringDtype(storage="pyarrow")
122+
self._dtype = StringDtype(storage=self._storage)
120123

121124
if not pa.types.is_string(self._pa_array.type) and not (
122125
pa.types.is_dictionary(self._pa_array.type)
@@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)
144147

145148
if dtype and not (isinstance(dtype, str) and dtype == "string"):
146149
dtype = pandas_dtype(dtype)
147-
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
150+
assert isinstance(dtype, StringDtype) and dtype.storage in (
151+
"pyarrow",
152+
"pyarrow_numpy",
153+
)
148154

149155
if isinstance(scalars, BaseMaskedArray):
150156
# avoid costly conversion to object dtype in ensure_string_array and
@@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray:
178184
raise TypeError("Scalar must be NA or str")
179185
return super().insert(loc, item)
180186

187+
@classmethod
def _result_converter(cls, values, na=None):
    """
    Wrap an Arrow result of a boolean string method in a pandas array.

    Parameters
    ----------
    values : pyarrow array-like
        Result to convert; handed unchanged to
        ``BooleanDtype.__from_arrow__``.
    na : scalar, optional
        Accepted so that every ``_result_converter`` shares one signature;
        this implementation does not use it.

    Returns
    -------
    The object returned by ``BooleanDtype().__from_arrow__(values)``.
    """
    boolean_dtype = BooleanDtype()
    return boolean_dtype.__from_arrow__(values)
190+
181191
def _maybe_convert_setitem_value(self, value):
182192
"""Maybe convert value to be pyarrow compatible."""
183193
if is_scalar(value):
@@ -313,7 +323,7 @@ def _str_contains(
313323
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
314324
else:
315325
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
316-
result = BooleanDtype().__from_arrow__(result)
326+
result = self._result_converter(result, na=na)
317327
if not isna(na):
318328
result[isna(result)] = bool(na)
319329
return result
@@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None):
322332
result = pc.starts_with(self._pa_array, pattern=pat)
323333
if not isna(na):
324334
result = result.fill_null(na)
325-
result = BooleanDtype().__from_arrow__(result)
335+
result = self._result_converter(result)
326336
if not isna(na):
327337
result[isna(result)] = bool(na)
328338
return result
@@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None):
331341
result = pc.ends_with(self._pa_array, pattern=pat)
332342
if not isna(na):
333343
result = result.fill_null(na)
334-
result = BooleanDtype().__from_arrow__(result)
344+
result = self._result_converter(result)
335345
if not isna(na):
336346
result[isna(result)] = bool(na)
337347
return result
@@ -369,39 +379,39 @@ def _str_fullmatch(
369379

370380
def _str_isalnum(self):
371381
result = pc.utf8_is_alnum(self._pa_array)
372-
return BooleanDtype().__from_arrow__(result)
382+
return self._result_converter(result)
373383

374384
def _str_isalpha(self):
375385
result = pc.utf8_is_alpha(self._pa_array)
376-
return BooleanDtype().__from_arrow__(result)
386+
return self._result_converter(result)
377387

378388
def _str_isdecimal(self):
379389
result = pc.utf8_is_decimal(self._pa_array)
380-
return BooleanDtype().__from_arrow__(result)
390+
return self._result_converter(result)
381391

382392
def _str_isdigit(self):
383393
result = pc.utf8_is_digit(self._pa_array)
384-
return BooleanDtype().__from_arrow__(result)
394+
return self._result_converter(result)
385395

386396
def _str_islower(self):
387397
result = pc.utf8_is_lower(self._pa_array)
388-
return BooleanDtype().__from_arrow__(result)
398+
return self._result_converter(result)
389399

390400
def _str_isnumeric(self):
391401
result = pc.utf8_is_numeric(self._pa_array)
392-
return BooleanDtype().__from_arrow__(result)
402+
return self._result_converter(result)
393403

394404
def _str_isspace(self):
395405
result = pc.utf8_is_space(self._pa_array)
396-
return BooleanDtype().__from_arrow__(result)
406+
return self._result_converter(result)
397407

398408
def _str_istitle(self):
399409
result = pc.utf8_is_title(self._pa_array)
400-
return BooleanDtype().__from_arrow__(result)
410+
return self._result_converter(result)
401411

402412
def _str_isupper(self):
403413
result = pc.utf8_is_upper(self._pa_array)
404-
return BooleanDtype().__from_arrow__(result)
414+
return self._result_converter(result)
405415

406416
def _str_len(self):
407417
result = pc.utf8_length(self._pa_array)
@@ -433,3 +443,114 @@ def _str_rstrip(self, to_strip=None):
433443
else:
434444
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
435445
return type(self)(result)
446+
447+
448+
class ArrowStringArrayNumpySemantics(ArrowStringArray):
    """
    ``ArrowStringArray`` variant used for ``storage="pyarrow_numpy"``.

    The overrides below convert the results of string methods, comparisons
    and ``value_counts`` to NumPy arrays (``np.nan`` / ``False`` standing in
    for missing values) instead of returning pandas masked arrays.
    """

    _storage = "pyarrow_numpy"

    @classmethod
    def _result_converter(cls, values, na=None):
        """
        Convert an Arrow result to a NumPy array.

        Parameters
        ----------
        values : pyarrow array-like
            Result of a pyarrow compute call.
        na : scalar, optional
            When not NA, nulls in ``values`` are filled with ``bool(na)``
            before conversion; otherwise nulls become ``np.nan``.
        """
        if not isna(na):
            values = values.fill_null(bool(na))
        return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

    def __getattribute__(self, item):
        """
        Route names defined on ``ArrowStringArrayMixin`` to the mixin.

        Every other name goes through the normal lookup.
        """
        # ArrowStringArray and we both inherit from ArrowExtensionArray, which
        # creates inheritance problems (Diamond inheritance)
        #
        # Dunder names are excluded: every class ``__dict__`` carries entries
        # such as ``__dict__``, ``__doc__`` and ``__module__`` that are not
        # callable, and wrapping them in ``partial`` raises ``TypeError`` on
        # plain attribute access.
        if (
            item in ArrowStringArrayMixin.__dict__
            and item != "_pa_array"
            and not item.startswith("__")
        ):
            return partial(getattr(ArrowStringArrayMixin, item), self)
        return super().__getattribute__(item)

    def _str_map(
        self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
    ):
        """
        Apply ``f`` to every non-missing element.

        Parameters
        ----------
        f : callable
            Function applied to each non-missing value.
        na_value : scalar, optional
            Value used at missing positions. Defaults to
            ``self.dtype.na_value``; it is replaced by ``np.nan`` for an
            integer ``dtype`` and by ``False`` for a boolean ``dtype``.
        dtype : Dtype, optional
            Requested result dtype; defaults to ``self.dtype``.
        convert : bool, default True
            In the integer/boolean fallback path, whether an object result is
            passed through ``lib.maybe_convert_objects``.

        Returns
        -------
        A NumPy array, or a new instance of ``type(self)`` when ``dtype`` is a
        non-object string dtype.
        """
        if dtype is None:
            dtype = self.dtype
        if na_value is None:
            na_value = self.dtype.na_value

        mask = isna(self)
        arr = np.asarray(self)

        if is_integer_dtype(dtype) or is_bool_dtype(dtype):
            if is_integer_dtype(dtype):
                na_value = np.nan
            else:
                na_value = False
            try:
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                    dtype=np.dtype(dtype),  # type: ignore[arg-type]
                )
                return result

            except ValueError:
                # Retry without forcing the dtype, then optionally let
                # maybe_convert_objects pick a dtype for an object result.
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                )
                if convert and result.dtype == object:
                    result = lib.maybe_convert_objects(result)
                return result

        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # i.e. StringDtype
            result = lib.map_infer_mask(
                arr, f, mask.view("uint8"), convert=False, na_value=na_value
            )
            result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
            return type(self)(result)
        else:
            # This is when the result type is object. We reach this when
            # -> We know the result type is truly object (e.g. .encode returns bytes
            #    or .findall returns a list).
            # -> We don't know the result type.  E.g. `.get` can return anything.
            return lib.map_infer_mask(arr, f, mask.view("uint8"))

    def _convert_int_dtype(self, result):
        """Upcast an ``int32`` NumPy result to ``int64``; return others as-is."""
        if result.dtype == np.int32:
            result = result.astype(np.int64)
        return result

    def _str_count(self, pat: str, flags: int = 0):
        """
        Count regex matches of ``pat`` per element.

        Defers to the parent implementation when ``flags`` is set.
        """
        if flags:
            return super()._str_count(pat, flags)
        result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
        return self._convert_int_dtype(result)

    def _str_len(self):
        """Return ``pc.utf8_length`` of every element as a NumPy array."""
        result = pc.utf8_length(self._pa_array).to_numpy()
        return self._convert_int_dtype(result)

    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
        """
        Return the position of the first occurrence of ``sub`` per element.

        Parameters
        ----------
        sub : str
            Substring to look for.
        start : int, default 0
            Left edge of the searched slice.
        end : int, optional
            Right edge of the searched slice.

        Returns
        -------
        NumPy array of positions, ``-1`` where ``sub`` is not found.

        Notes
        -----
        Only ``start != 0 and end is not None`` and the plain
        ``start == 0 and end is None`` case are handled with pyarrow compute
        here; any other combination defers to the parent implementation.
        """
        if start != 0 and end is not None:
            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
            result = pc.find_substring(slices, sub)
            not_found = pc.equal(result, -1)
            # A hit is relative to the slice, so shift it back by the position
            # where the slice begins (not by the slice length ``end - start``).
            if start < 0:
                # A negative start counts from the end of each string and is
                # clamped at the beginning, mirroring ``str.find``.
                start_offset = pc.add(pc.utf8_length(self._pa_array), start)
                start_offset = pc.if_else(
                    pc.less(start_offset, 0), 0, start_offset
                )
            else:
                start_offset = start
            offset_result = pc.add(result, start_offset)
            result = pc.if_else(not_found, result, offset_result)
        elif start == 0 and end is None:
            slices = self._pa_array
            result = pc.find_substring(slices, sub)
        else:
            return super()._str_find(sub, start, end)
        return self._convert_int_dtype(result.to_numpy())

    def _cmp_method(self, other, op):
        """
        Compare with ``other`` and return a NumPy boolean array.

        The parent result is converted with ``na_value=False``, so positions
        that are missing in the parent result compare as ``False``.
        """
        result = super()._cmp_method(other, op)
        return result.to_numpy(np.bool_, na_value=False)

    def value_counts(self, dropna: bool = True):
        """
        Return the parent ``value_counts`` result with NumPy-backed values.

        The index and name of the parent result are reused unchanged.
        """
        from pandas import Series

        result = super().value_counts(dropna)
        return Series(
            result._values.to_numpy(), index=result.index, name=result.name, copy=False
        )

pandas/core/config_init.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -500,7 +500,7 @@ def use_inf_as_na_cb(key) -> None:
500500
"string_storage",
501501
"python",
502502
string_storage_doc,
503-
validator=is_one_of_factory(["python", "pyarrow"]),
503+
validator=is_one_of_factory(["python", "pyarrow", "pyarrow_numpy"]),
504504
)
505505

506506

pandas/core/strings/accessor.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,9 @@ def _map_and_wrap(name: str | None, docstring: str | None):
145145
@forbid_nonstring_types(["bytes"], name=name)
146146
def wrapper(self):
147147
result = getattr(self._data.array, f"_str_{name}")()
148-
return self._wrap_result(result)
148+
return self._wrap_result(
149+
result, returns_string=name not in ("isnumeric", "isdecimal")
150+
)
149151

150152
wrapper.__doc__ = docstring
151153
return wrapper

0 commit comments

Comments
 (0)