Skip to content

Implement any and all for pyarrow numpy strings #54591

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
b24afc9
Start new string array
phofl Aug 13, 2023
b306c6f
Add missing methods
phofl Aug 13, 2023
2dbcfb0
Implement Arrow String Array that is compatible with NumPy semantics
phofl Aug 13, 2023
d9e61e5
Move methods
phofl Aug 13, 2023
3188c25
Refactor
phofl Aug 13, 2023
df231f0
Merge remote-tracking branch 'upstream/main' into string_array_numpy_…
phofl Aug 14, 2023
cd19bfb
Refactor
phofl Aug 14, 2023
c73c6b0
Remove
phofl Aug 14, 2023
6b26309
Fix
phofl Aug 14, 2023
da6d67c
Update
phofl Aug 14, 2023
d862eca
Na return value
phofl Aug 16, 2023
4be0ee8
Merge remote-tracking branch 'upstream/main' into string_array_numpy_…
phofl Aug 16, 2023
6cf2639
Fix
phofl Aug 16, 2023
333dcae
Merge remote-tracking branch 'origin/string_array_numpy_semantics' in…
phofl Aug 16, 2023
48bd626
Update
phofl Aug 16, 2023
4f9387a
Implement any and all for pyarrow numpy strings
phofl Aug 16, 2023
606cd71
Fix typing
phofl Aug 21, 2023
a44b042
Merge remote-tracking branch 'upstream/main' into string_array_numpy_…
phofl Aug 21, 2023
6414501
Update
phofl Aug 21, 2023
27b5057
Merge remote-tracking branch 'upstream/main' into any_all
phofl Aug 21, 2023
3fec6d3
Merge remote-tracking branch 'upstream/main' into string_array_numpy_…
phofl Aug 21, 2023
68acc32
Fix
phofl Aug 21, 2023
fbab6fb
Fix
phofl Aug 21, 2023
dd0f9a8
Merge branch 'string_array_numpy_semantics_na_val' into any_all
phofl Aug 21, 2023
68e5f8f
Fix
phofl Aug 21, 2023
0322006
Move test
phofl Aug 21, 2023
8bb52f4
Skip test when no pa
phofl Aug 21, 2023
cc8e6f7
Fix typing
phofl Aug 22, 2023
a3ef88d
Merge remote-tracking branch 'upstream/main' into any_all
phofl Aug 23, 2023
33355a7
Fix tests
phofl Aug 23, 2023
f337cd9
Merge remote-tracking branch 'upstream/main' into any_all
phofl Aug 26, 2023
1facb79
move + rename test
jorisvandenbossche Aug 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,7 @@ def nullable_string_dtype(request):
params=[
"python",
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")),
pytest.param("pyarrow_numpy", marks=td.skip_if_no("pyarrow")),
]
)
def string_storage(request):
Expand Down Expand Up @@ -1380,6 +1381,7 @@ def object_dtype(request):
"object",
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
]
)
def any_string_dtype(request):
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,10 @@ def __getitem__(self, item: PositionalIndexer):
if isinstance(item, np.ndarray):
if not len(item):
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
if self._dtype.name == "string" and self._dtype.storage in (
"pyarrow",
"pyarrow_numpy",
):
pa_dtype = pa.string()
else:
pa_dtype = self._dtype.pyarrow_dtype
Expand Down
21 changes: 16 additions & 5 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class StringDtype(StorageExtensionDtype):

Parameters
----------
storage : {"python", "pyarrow"}, optional
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
If not given, the value of ``pd.options.mode.string_storage``.

Attributes
Expand Down Expand Up @@ -108,11 +108,11 @@ def na_value(self) -> libmissing.NAType:
def __init__(self, storage=None) -> None:
if storage is None:
storage = get_option("mode.string_storage")
if storage not in {"python", "pyarrow"}:
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
raise ValueError(
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
)
if storage == "pyarrow" and pa_version_under7p0:
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under7p0:
raise ImportError(
"pyarrow>=7.0.0 is required for PyArrow backed StringArray."
)
Expand Down Expand Up @@ -160,6 +160,8 @@ def construct_from_string(cls, string):
return cls(storage="python")
elif string == "string[pyarrow]":
return cls(storage="pyarrow")
elif string == "string[pyarrow_numpy]":
return cls(storage="pyarrow_numpy")
else:
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

Expand All @@ -176,12 +178,17 @@ def construct_array_type( # type: ignore[override]
-------
type
"""
from pandas.core.arrays.string_arrow import ArrowStringArray
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
ArrowStringArrayNumpySemantics,
)

if self.storage == "python":
return StringArray
else:
elif self.storage == "pyarrow":
return ArrowStringArray
else:
return ArrowStringArrayNumpySemantics

def __from_arrow__(
self, array: pyarrow.Array | pyarrow.ChunkedArray
Expand All @@ -193,6 +200,10 @@ def __from_arrow__(
from pandas.core.arrays.string_arrow import ArrowStringArray

return ArrowStringArray(array)
elif self.storage == "pyarrow_numpy":
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics

return ArrowStringArrayNumpySemantics(array)
else:
import pyarrow

Expand Down
161 changes: 147 additions & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
import re
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -27,6 +28,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
Expand Down Expand Up @@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
# error: Incompatible types in assignment (expression has type "StringDtype",
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
_dtype: StringDtype # type: ignore[assignment]
_storage = "pyarrow"

def __init__(self, values) -> None:
super().__init__(values)
self._dtype = StringDtype(storage="pyarrow")
self._dtype = StringDtype(storage=self._storage)

if not pa.types.is_string(self._pa_array.type) and not (
pa.types.is_dictionary(self._pa_array.type)
Expand Down Expand Up @@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)

if dtype and not (isinstance(dtype, str) and dtype == "string"):
dtype = pandas_dtype(dtype)
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
assert isinstance(dtype, StringDtype) and dtype.storage in (
"pyarrow",
"pyarrow_numpy",
)

if isinstance(scalars, BaseMaskedArray):
# avoid costly conversion to object dtype in ensure_string_array and
Expand Down Expand Up @@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

@staticmethod
def _result_converter(values, **kwargs):
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
"""Maybe convert value to be pyarrow compatible."""
if is_scalar(value):
Expand Down Expand Up @@ -313,7 +323,7 @@ def _str_contains(
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
else:
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result, na=na)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None):
result = pc.starts_with(self._pa_array, pattern=pat)
if not isna(na):
result = result.fill_null(na)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None):
result = pc.ends_with(self._pa_array, pattern=pat)
if not isna(na):
result = result.fill_null(na)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand Down Expand Up @@ -369,39 +379,39 @@ def _str_fullmatch(

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
Expand Down Expand Up @@ -433,3 +443,126 @@ def _str_rstrip(self, to_strip=None):
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)


class ArrowStringArrayNumpySemantics(ArrowStringArray):
    """
    ``ArrowStringArray`` variant used for ``StringDtype("pyarrow_numpy")``.

    The data is still held in a pyarrow array, but the methods overridden
    here hand back plain NumPy arrays (missing entries filled with ``np.nan``
    or ``False``) instead of pandas masked arrays.
    """

    _storage = "pyarrow_numpy"

    @staticmethod
    def _result_converter(values, na=None):
        """
        Convert a pyarrow boolean result into a NumPy array.

        Parameters
        ----------
        values : pyarrow array
            Result of a pyarrow compute kernel.
        na : scalar, optional
            When not missing, nulls in ``values`` are replaced by ``bool(na)``
            before conversion.

        Returns
        -------
        numpy.ndarray
            Remaining nulls are represented as ``np.nan``.
        """
        if not isna(na):
            values = values.fill_null(bool(na))
        return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

    def __getattribute__(self, item):
        # Any attribute defined directly on ArrowStringArrayMixin (except
        # ``_pa_array``) is served from the mixin, bound to this instance,
        # regardless of where it would otherwise resolve in the MRO.
        if item in ArrowStringArrayMixin.__dict__ and item != "_pa_array":
            return partial(getattr(ArrowStringArrayMixin, item), self)
        return super().__getattribute__(item)

    def _str_map(
        self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
    ):
        """
        Apply ``f`` element-wise, skipping missing values.

        Parameters
        ----------
        f : callable
            Function applied to every non-missing element.
        na_value : scalar, optional
            Value used for missing positions. Defaults to the dtype's NA
            value; overridden with ``np.nan`` / ``False`` for integer / bool
            result dtypes.
        dtype : Dtype, optional
            Requested result dtype; defaults to this array's dtype.
        convert : bool, default True
            Whether to try converting an object result of the integer/bool
            fallback path to a more specific NumPy dtype.

        Returns
        -------
        numpy.ndarray or ArrowStringArrayNumpySemantics
        """
        if dtype is None:
            dtype = self.dtype
        if na_value is None:
            na_value = self.dtype.na_value

        mask = isna(self)
        arr = np.asarray(self)

        if is_integer_dtype(dtype) or is_bool_dtype(dtype):
            if is_integer_dtype(dtype):
                na_value = np.nan
            else:
                na_value = False
            try:
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                    dtype=np.dtype(dtype),
                )
                return result

            except ValueError:
                # The requested dtype could not hold the result (presumably
                # because of the NA fill value); retry without forcing a dtype.
                result = lib.map_infer_mask(
                    arr,
                    f,
                    mask.view("uint8"),
                    convert=False,
                    na_value=na_value,
                )
                if convert and result.dtype == object:
                    result = lib.maybe_convert_objects(result)
                return result

        elif is_string_dtype(dtype) and not is_object_dtype(dtype):
            # i.e. StringDtype
            result = lib.map_infer_mask(
                arr, f, mask.view("uint8"), convert=False, na_value=na_value
            )
            result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
            return type(self)(result)
        else:
            # This is when the result type is object. We reach this when
            # -> We know the result type is truly object (e.g. .encode returns bytes
            #    or .findall returns a list).
            # -> We don't know the result type.  E.g. `.get` can return anything.
            return lib.map_infer_mask(arr, f, mask.view("uint8"))

    def _convert_int_dtype(self, result):
        """Upcast an ``int32`` NumPy result to ``int64``; return others as-is."""
        if result.dtype == np.int32:
            result = result.astype(np.int64)
        return result

    def _str_count(self, pat: str, flags: int = 0):
        """Count regex matches of ``pat``; regex ``flags`` use the parent path."""
        if flags:
            return super()._str_count(pat, flags)
        result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
        return self._convert_int_dtype(result)

    def _str_len(self):
        """Return the length of each string as a NumPy array."""
        result = pc.utf8_length(self._pa_array).to_numpy()
        return self._convert_int_dtype(result)

    def _str_find(self, sub: str, start: int = 0, end: int | None = None):
        """
        Return the lowest index of ``sub`` within ``[start:end]`` of each string.

        Parameters
        ----------
        sub : str
            Substring to search for.
        start : int, default 0
            Left edge of the search window.
        end : int, optional
            Right edge of the search window.

        Returns
        -------
        numpy.ndarray
            Position of the first match in the original string, ``-1`` when
            there is no match.
        """
        if start > 0 and end is not None and sub:
            # Fast path only for a positive ``start`` and non-empty ``sub``:
            # a negative ``start`` maps to a per-element offset, and an empty
            # pattern past the end of the string must yield -1; neither can be
            # expressed by adding a constant to the in-slice position.
            slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
            result = pc.find_substring(slices, sub)
            not_found = pc.equal(result, -1)
            # The match index is relative to the slice, which begins at
            # ``start`` in the original string.
            offset_result = pc.add(result, start)
            result = pc.if_else(not_found, result, offset_result)
        elif start == 0 and end is None:
            result = pc.find_substring(self._pa_array, sub)
        else:
            return super()._str_find(sub, start, end)
        return self._convert_int_dtype(result.to_numpy())

    def _cmp_method(self, other, op):
        """
        Compare element-wise and return a NumPy array.

        Missing results are filled with ``True`` for ``!=`` and ``False`` for
        every other operator, matching how NaN compares.
        """
        import operator

        result = super()._cmp_method(other, op)
        if op is operator.ne:
            return result.to_numpy(na_value=True)
        return result.to_numpy(na_value=False)

    def value_counts(self, dropna: bool = True):
        """Return value counts as a Series whose values are a NumPy array."""
        from pandas import Series

        result = super().value_counts(dropna)
        return Series(
            result._values.to_numpy(), index=result.index, name=result.name, copy=False
        )

    def _reduce(
        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
    ):
        """
        Reduce the array; ``any``/``all`` operate on string truthiness.

        For ``any`` and ``all`` a boolean array is built from "is not null"
        AND "is not the empty string" and reduced through
        ``ArrowExtensionArray``. Every other reduction is forwarded to the
        parent class.
        """
        if name in ["any", "all"]:
            arr = pc.and_kleene(
                pc.invert(pc.is_null(self._pa_array)), pc.not_equal(self._pa_array, "")
            )
            return ArrowExtensionArray(arr)._reduce(
                name, skipna=skipna, keepdims=keepdims, **kwargs
            )
        else:
            return super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
2 changes: 1 addition & 1 deletion pandas/core/config_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def use_inf_as_na_cb(key) -> None:
"string_storage",
"python",
string_storage_doc,
validator=is_one_of_factory(["python", "pyarrow"]),
validator=is_one_of_factory(["python", "pyarrow", "pyarrow_numpy"]),
)


Expand Down
4 changes: 3 additions & 1 deletion pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def _map_and_wrap(name: str | None, docstring: str | None):
@forbid_nonstring_types(["bytes"], name=name)
def wrapper(self):
result = getattr(self._data.array, f"_str_{name}")()
return self._wrap_result(result)
return self._wrap_result(
result, returns_string=name not in ("isnumeric", "isdecimal")
)

wrapper.__doc__ = docstring
return wrapper
Expand Down
Loading