Skip to content

[ArrowStringArray] REF: str.extract dispatch to array #41372

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def map_infer_mask(
convert: bool = ...,
na_value: Any = ...,
dtype: np.dtype = ...,
out: np.ndarray = ...,
) -> np.ndarray: ...

def indices_fast(
Expand Down
10 changes: 8 additions & 2 deletions pandas/_libs/lib.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2536,7 +2536,8 @@ no_default = NoDefault.no_default # Sentinel indicating the default value.
@cython.boundscheck(False)
@cython.wraparound(False)
def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=True,
object na_value=no_default, cnp.dtype dtype=np.dtype(object)
object na_value=no_default, cnp.dtype dtype=np.dtype(object),
ndarray out=None
) -> np.ndarray:
"""
Substitute for np.vectorize with pandas-friendly dtype inference.
Expand All @@ -2554,6 +2555,8 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr
input value is used
dtype : numpy.dtype
The numpy dtype to use for the result ndarray.
out : ndarray
The result.

Returns
-------
Expand All @@ -2565,7 +2568,10 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr
object val

n = len(arr)
result = np.empty(n, dtype=dtype)
if out is not None:
result = out
else:
result = np.empty(n, dtype=dtype)
for i in range(n):
if mask[i]:
if na_value is no_default:
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -2453,7 +2453,9 @@ def replace(self, to_replace, value, inplace: bool = False):

# ------------------------------------------------------------------------
# String methods interface
def _str_map(self, f, na_value=np.nan, dtype=np.dtype("object")):
def _str_map(
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
):
# Optimization to apply the callable `f` to the categories once
# and rebuild the result by `take`ing from the result with the codes.
# Returns the same type as the object-dtype implementation though.
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,9 @@ def _cmp_method(self, other, op):
# String methods interface
_str_na_value = StringDtype.na_value

def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
from pandas.arrays import BooleanArray

if dtype is None:
Expand Down
4 changes: 3 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,9 @@ def value_counts(self, dropna: bool = True) -> Series:

_str_na_value = ArrowStringDtype.na_value

def _str_map(self, f, na_value=None, dtype: Dtype | None = None):
def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
# TODO: de-duplicate with StringArray method. This method is moreless copy and
# paste.

Expand Down
106 changes: 32 additions & 74 deletions pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
import numpy as np

import pandas._libs.lib as lib
from pandas._typing import (
ArrayLike,
FrameOrSeriesUnion,
)
from pandas._typing import FrameOrSeriesUnion
from pandas.util._decorators import Appender

from pandas.core.dtypes.common import (
Expand Down Expand Up @@ -162,7 +159,6 @@ class StringMethods(NoNewAttributesMixin):
# TODO: Dispatch all the methods
# Currently the following are not dispatched to the array
# * cat
# * extract
# * extractall

def __init__(self, data):
Expand Down Expand Up @@ -245,7 +241,7 @@ def _wrap_result(
self,
result,
name=None,
expand=None,
expand: Optional[bool] = None,
fill_value=np.nan,
returns_string=True,
):
Expand Down Expand Up @@ -2383,8 +2379,36 @@ def extract(
if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex):
raise ValueError("only one regex group is supported with Index")

# TODO: dispatch
return str_extract(self, pat, flags, expand=expand)
result = self._data.array._str_extract(pat, flags, expand)
returns_df = regex.groups > 1 or expand

name = _get_group_names(regex) if returns_df else _get_single_group_name(regex)

# extract is inconsistent for Indexes when expand is True. To avoid special
# casing _wrap_result we handle that case here
if expand and isinstance(self._data, ABCIndex):
from pandas import DataFrame

# if expand is True, name is a list of column names
assert isinstance(name, list) # for mypy
return DataFrame(result, columns=name, dtype=object)

# bypass padding code in _wrap_result
expand_kwarg: Optional[bool]
if returns_df:
if is_object_dtype(result):
if regex.groups == 1:
result = result.reshape(1, -1).T
if result.size == 0:
expand_kwarg = True
else:
expand_kwarg = None
else:
expand_kwarg = True
else:
expand_kwarg = False

return self._wrap_result(result, name=name, expand=expand_kwarg)

@forbid_nonstring_types(["bytes"])
def extractall(self, pat, flags=0):
Expand Down Expand Up @@ -3071,72 +3095,6 @@ def _get_group_names(regex: Pattern) -> List[Hashable]:
return [names.get(1 + i, i) for i in range(regex.groups)]


def _str_extract(arr: ArrayLike, pat: str, flags=0, expand: bool = True):
"""
Find groups in each string in the array using passed regular expression.

Returns
-------
np.ndarray or list of lists is expand is True
"""
regex = re.compile(pat, flags=flags)

empty_row = [np.nan] * regex.groups

def f(x):
if not isinstance(x, str):
return empty_row
m = regex.search(x)
if m:
return [np.nan if item is None else item for item in m.groups()]
else:
return empty_row

if expand:
return [f(val) for val in np.asarray(arr)]

return np.array([f(val)[0] for val in np.asarray(arr)], dtype=object)


def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool = True):
from pandas import (
DataFrame,
array as pd_array,
)

obj = accessor._data
result_dtype = _result_dtype(obj)
regex = re.compile(pat, flags=flags)
returns_df = regex.groups > 1 or expand

if returns_df:
name = None
columns = _get_group_names(regex)

if obj.array.size == 0:
result = DataFrame(columns=columns, dtype=result_dtype)

else:
result_list = _str_extract(obj.array, pat, flags=flags, expand=returns_df)

result_index: Optional["Index"]
if isinstance(obj, ABCSeries):
result_index = obj.index
else:
result_index = None

result = DataFrame(
result_list, columns=columns, index=result_index, dtype=result_dtype
)

else:
name = _get_single_group_name(regex)
result_arr = _str_extract(obj.array, pat, flags=flags, expand=returns_df)
# not dispatching, so we have to reconstruct here.
result = pd_array(result_arr, dtype=result_dtype)
return accessor._wrap_result(result, name=name)


def str_extractall(arr, pat, flags=0):
regex = re.compile(pat, flags=flags)
# the regex must contain capture groups.
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/strings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,7 @@ def _str_split(self, pat=None, n=-1, expand=False):
@abc.abstractmethod
def _str_rsplit(self, pat=None, n=-1):
pass

@abc.abstractmethod
def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
pass
61 changes: 57 additions & 4 deletions pandas/core/strings/object_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
)

from pandas.core.dtypes.common import (
is_object_dtype,
is_re,
is_scalar,
)
Expand All @@ -38,7 +39,9 @@ def __len__(self):
# For typing, _str_map relies on the object being sized.
raise NotImplementedError

def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None):
def _str_map(
self, f, na_value=None, dtype: Optional[Dtype] = None, convert: bool = True
):
"""
Map a callable over valid element of the array.

Expand All @@ -53,6 +56,8 @@ def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None):
for object-dtype and Categorical and ``pd.NA`` for StringArray.
dtype : Dtype, optional
The dtype of the result array.
convert : bool, default True
Whether to call `maybe_convert_objects` on the resulting ndarray
"""
if dtype is None:
dtype = np.dtype("object")
Expand All @@ -66,9 +71,9 @@ def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None):

arr = np.asarray(self, dtype=object)
mask = isna(arr)
convert = not np.all(mask)
map_convert = convert and not np.all(mask)
try:
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert)
result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert)
except (TypeError, AttributeError) as e:
# Reraise the exception if callable `f` got wrong number of args.
# The user may want to be warned by this, instead of getting NaN
Expand All @@ -94,7 +99,7 @@ def g(x):
return result
if na_value is not np.nan:
np.putmask(result, mask, na_value)
if result.dtype == object:
if convert and result.dtype == object:
result = lib.maybe_convert_objects(result)
return result

Expand Down Expand Up @@ -408,3 +413,51 @@ def _str_lstrip(self, to_strip=None):

def _str_rstrip(self, to_strip=None):
return self._str_map(lambda x: x.rstrip(to_strip))

def _str_extract(self, pat: str, flags: int = 0, expand: bool = True):
regex = re.compile(pat, flags=flags)
na_value = self._str_na_value

if regex.groups == 1:

def f(x):
m = regex.search(x)
return m.groups()[0] if m else na_value

return self._str_map(f, convert=False)
else:
out = np.empty((len(self), regex.groups), dtype=object)

if is_object_dtype(self):

def f(x):
if not isinstance(x, str):
return na_value
m = regex.search(x)
if m:
return [
na_value if item is None else item for item in m.groups()
]
else:
return na_value

else:

def f(x):
m = regex.search(x)
if m:
return [
na_value if item is None else item for item in m.groups()
]
else:
return na_value

result = lib.map_infer_mask(
np.asarray(self),
f,
mask=isna(self).view("uint8"),
convert=False,
na_value=na_value,
out=out,
)
return result