diff --git a/pandas/_libs/lib.pyi b/pandas/_libs/lib.pyi index 9dbc47f1d40f7..902d1b7245b46 100644 --- a/pandas/_libs/lib.pyi +++ b/pandas/_libs/lib.pyi @@ -196,6 +196,7 @@ def map_infer_mask( convert: bool = ..., na_value: Any = ..., dtype: np.dtype = ..., + out: np.ndarray = ..., ) -> np.ndarray: ... def indices_fast( diff --git a/pandas/_libs/lib.pyx b/pandas/_libs/lib.pyx index e1cb744c7033c..58bd6cdc54b7d 100644 --- a/pandas/_libs/lib.pyx +++ b/pandas/_libs/lib.pyx @@ -2536,7 +2536,8 @@ no_default = NoDefault.no_default # Sentinel indicating the default value. @cython.boundscheck(False) @cython.wraparound(False) def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=True, - object na_value=no_default, cnp.dtype dtype=np.dtype(object) + object na_value=no_default, cnp.dtype dtype=np.dtype(object), + ndarray out=None ) -> np.ndarray: """ Substitute for np.vectorize with pandas-friendly dtype inference. @@ -2554,6 +2555,8 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr input value is used dtype : numpy.dtype The numpy dtype to use for the result ndarray. + out : ndarray, optional + A preallocated ndarray to write results into; if provided, it is used instead of allocating a new result array. 
Returns ------- @@ -2565,7 +2568,10 @@ def map_infer_mask(ndarray arr, object f, const uint8_t[:] mask, bint convert=Tr object val n = len(arr) - result = np.empty(n, dtype=dtype) + if out is not None: + result = out + else: + result = np.empty(n, dtype=dtype) for i in range(n): if mask[i]: if na_value is no_default: diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index cb8a08f5668ac..95d9409b265ce 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -2453,7 +2453,9 @@ def replace(self, to_replace, value, inplace: bool = False): # ------------------------------------------------------------------------ # String methods interface - def _str_map(self, f, na_value=np.nan, dtype=np.dtype("object")): + def _str_map( + self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True + ): # Optimization to apply the callable `f` to the categories once # and rebuild the result by `take`ing from the result with the codes. # Returns the same type as the object-dtype implementation though. 
diff --git a/pandas/core/arrays/string_.py b/pandas/core/arrays/string_.py index 74ca5130ca322..ab1dadf4d2dfa 100644 --- a/pandas/core/arrays/string_.py +++ b/pandas/core/arrays/string_.py @@ -410,7 +410,9 @@ def _cmp_method(self, other, op): # String methods interface _str_na_value = StringDtype.na_value - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): from pandas.arrays import BooleanArray if dtype is None: diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index d5ee28eb7017e..1e2f382b2976f 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -741,7 +741,9 @@ def value_counts(self, dropna: bool = True) -> Series: _str_na_value = ArrowStringDtype.na_value - def _str_map(self, f, na_value=None, dtype: Dtype | None = None): + def _str_map( + self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True + ): # TODO: de-duplicate with StringArray method. This method is moreless copy and # paste. 
diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 43df34a7ecbb2..a56c62074cc2a 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -15,10 +15,7 @@ import numpy as np import pandas._libs.lib as lib -from pandas._typing import ( - ArrayLike, - FrameOrSeriesUnion, -) +from pandas._typing import FrameOrSeriesUnion from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -162,7 +159,6 @@ class StringMethods(NoNewAttributesMixin): # TODO: Dispatch all the methods # Currently the following are not dispatched to the array # * cat - # * extract # * extractall def __init__(self, data): @@ -245,7 +241,7 @@ def _wrap_result( self, result, name=None, - expand=None, + expand: Optional[bool] = None, fill_value=np.nan, returns_string=True, ): @@ -2383,8 +2379,36 @@ def extract( if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex): raise ValueError("only one regex group is supported with Index") - # TODO: dispatch - return str_extract(self, pat, flags, expand=expand) + result = self._data.array._str_extract(pat, flags, expand) + returns_df = regex.groups > 1 or expand + + name = _get_group_names(regex) if returns_df else _get_single_group_name(regex) + + # extract is inconsistent for Indexes when expand is True. 
To avoid special + # casing _wrap_result we handle that case here + if expand and isinstance(self._data, ABCIndex): + from pandas import DataFrame + + # if expand is True, name is a list of column names + assert isinstance(name, list) # for mypy + return DataFrame(result, columns=name, dtype=object) + + # bypass padding code in _wrap_result + expand_kwarg: Optional[bool] + if returns_df: + if is_object_dtype(result): + if regex.groups == 1: + result = result.reshape(1, -1).T + if result.size == 0: + expand_kwarg = True + else: + expand_kwarg = None + else: + expand_kwarg = True + else: + expand_kwarg = False + + return self._wrap_result(result, name=name, expand=expand_kwarg) @forbid_nonstring_types(["bytes"]) def extractall(self, pat, flags=0): @@ -3071,72 +3095,6 @@ def _get_group_names(regex: Pattern) -> List[Hashable]: return [names.get(1 + i, i) for i in range(regex.groups)] -def _str_extract(arr: ArrayLike, pat: str, flags=0, expand: bool = True): - """ - Find groups in each string in the array using passed regular expression. 
- - Returns - ------- - np.ndarray or list of lists is expand is True - """ - regex = re.compile(pat, flags=flags) - - empty_row = [np.nan] * regex.groups - - def f(x): - if not isinstance(x, str): - return empty_row - m = regex.search(x) - if m: - return [np.nan if item is None else item for item in m.groups()] - else: - return empty_row - - if expand: - return [f(val) for val in np.asarray(arr)] - - return np.array([f(val)[0] for val in np.asarray(arr)], dtype=object) - - -def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool = True): - from pandas import ( - DataFrame, - array as pd_array, - ) - - obj = accessor._data - result_dtype = _result_dtype(obj) - regex = re.compile(pat, flags=flags) - returns_df = regex.groups > 1 or expand - - if returns_df: - name = None - columns = _get_group_names(regex) - - if obj.array.size == 0: - result = DataFrame(columns=columns, dtype=result_dtype) - - else: - result_list = _str_extract(obj.array, pat, flags=flags, expand=returns_df) - - result_index: Optional["Index"] - if isinstance(obj, ABCSeries): - result_index = obj.index - else: - result_index = None - - result = DataFrame( - result_list, columns=columns, index=result_index, dtype=result_dtype - ) - - else: - name = _get_single_group_name(regex) - result_arr = _str_extract(obj.array, pat, flags=flags, expand=returns_df) - # not dispatching, so we have to reconstruct here. - result = pd_array(result_arr, dtype=result_dtype) - return accessor._wrap_result(result, name=name) - - def str_extractall(arr, pat, flags=0): regex = re.compile(pat, flags=flags) # the regex must contain capture groups. 
diff --git a/pandas/core/strings/base.py b/pandas/core/strings/base.py index a77f8861a7c02..9cd0f25c1055a 100644 --- a/pandas/core/strings/base.py +++ b/pandas/core/strings/base.py @@ -222,3 +222,7 @@ def _str_split(self, pat=None, n=-1, expand=False): @abc.abstractmethod def _str_rsplit(self, pat=None, n=-1): pass + + @abc.abstractmethod + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + pass diff --git a/pandas/core/strings/object_array.py b/pandas/core/strings/object_array.py index 869eabc76b555..0f54d5f869973 100644 --- a/pandas/core/strings/object_array.py +++ b/pandas/core/strings/object_array.py @@ -19,6 +19,7 @@ ) from pandas.core.dtypes.common import ( + is_object_dtype, is_re, is_scalar, ) @@ -38,7 +39,9 @@ def __len__(self): # For typing, _str_map relies on the object being sized. raise NotImplementedError - def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None): + def _str_map( + self, f, na_value=None, dtype: Optional[Dtype] = None, convert: bool = True + ): """ Map a callable over valid element of the array. @@ -53,6 +56,8 @@ def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None): for object-dtype and Categorical and ``pd.NA`` for StringArray. dtype : Dtype, optional The dtype of the result array. + convert : bool, default True + Whether to call `maybe_convert_objects` on the resulting ndarray """ if dtype is None: dtype = np.dtype("object") @@ -66,9 +71,9 @@ def _str_map(self, f, na_value=None, dtype: Optional[Dtype] = None): arr = np.asarray(self, dtype=object) mask = isna(arr) - convert = not np.all(mask) + map_convert = convert and not np.all(mask) try: - result = lib.map_infer_mask(arr, f, mask.view(np.uint8), convert) + result = lib.map_infer_mask(arr, f, mask.view(np.uint8), map_convert) except (TypeError, AttributeError) as e: # Reraise the exception if callable `f` got wrong number of args. 
# The user may want to be warned by this, instead of getting NaN @@ -94,7 +99,7 @@ def g(x): return result if na_value is not np.nan: np.putmask(result, mask, na_value) - if result.dtype == object: + if convert and result.dtype == object: result = lib.maybe_convert_objects(result) return result @@ -408,3 +413,51 @@ def _str_lstrip(self, to_strip=None): def _str_rstrip(self, to_strip=None): return self._str_map(lambda x: x.rstrip(to_strip)) + + def _str_extract(self, pat: str, flags: int = 0, expand: bool = True): + regex = re.compile(pat, flags=flags) + na_value = self._str_na_value + + if regex.groups == 1: + + def f(x): + m = regex.search(x) + return m.groups()[0] if m else na_value + + return self._str_map(f, convert=False) + else: + out = np.empty((len(self), regex.groups), dtype=object) + + if is_object_dtype(self): + + def f(x): + if not isinstance(x, str): + return na_value + m = regex.search(x) + if m: + return [ + na_value if item is None else item for item in m.groups() + ] + else: + return na_value + + else: + + def f(x): + m = regex.search(x) + if m: + return [ + na_value if item is None else item for item in m.groups() + ] + else: + return na_value + + result = lib.map_infer_mask( + np.asarray(self), + f, + mask=isna(self).view("uint8"), + convert=False, + na_value=na_value, + out=out, + ) + return result