diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 1461c52d5cb65..43df34a7ecbb2 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -3034,22 +3034,6 @@ def cat_core(list_of_columns: List, sep: str): return np.sum(arr_with_sep, axis=0) -def _groups_or_na_fun(regex): - """Used in both extract_noexpand and extract_frame""" - empty_row = [np.nan] * regex.groups - - def f(x): - if not isinstance(x, str): - return empty_row - m = regex.search(x) - if m: - return [np.nan if item is None else item for item in m.groups()] - else: - return empty_row - - return f - - def _result_dtype(arr): # workaround #27953 # ideally we just pass `dtype=arr.dtype` unconditionally, but this fails @@ -3087,41 +3071,31 @@ def _get_group_names(regex: Pattern) -> List[Hashable]: return [names.get(1 + i, i) for i in range(regex.groups)] -def _str_extract_noexpand(arr: ArrayLike, pat: str, flags=0): +def _str_extract(arr: ArrayLike, pat: str, flags=0, expand: bool = True): """ Find groups in each string in the array using passed regular expression. - This function is called from str_extract(expand=False) when there is a single group - in the regex. - Returns ------- - np.ndarray + np.ndarray or list of lists is expand is True """ regex = re.compile(pat, flags=flags) - groups_or_na = _groups_or_na_fun(regex) - - result = np.array([groups_or_na(val)[0] for val in np.asarray(arr)], dtype=object) - return result - -def _str_extract_expand(arr: ArrayLike, pat: str, flags: int = 0) -> List[List]: - """ - Find groups in each string in the array using passed regular expression. - - For each subject string in the array, extract groups from the first match of - regular expression pat. This function is called from str_extract(expand=True) or - str_extract(expand=False) when there is more than one group in the regex. + empty_row = [np.nan] * regex.groups - Returns - ------- - list of lists + def f(x): + if not isinstance(x, str): + return empty_row + m = regex.search(x) + if m: + return [np.nan if item is None else item for item in m.groups()] + else: + return empty_row - """ - regex = re.compile(pat, flags=flags) - groups_or_na = _groups_or_na_fun(regex) + if expand: + return [f(val) for val in np.asarray(arr)] - return [groups_or_na(val) for val in np.asarray(arr)] + return np.array([f(val)[0] for val in np.asarray(arr)], dtype=object) def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool = True): @@ -3143,7 +3117,7 @@ def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool result = DataFrame(columns=columns, dtype=result_dtype) else: - result_list = _str_extract_expand(obj.array, pat, flags=flags) + result_list = _str_extract(obj.array, pat, flags=flags, expand=returns_df) result_index: Optional["Index"] if isinstance(obj, ABCSeries): @@ -3157,7 +3131,7 @@ def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool else: name = _get_single_group_name(regex) - result_arr = _str_extract_noexpand(obj.array, pat, flags=flags) + result_arr = _str_extract(obj.array, pat, flags=flags, expand=returns_df) # not dispatching, so we have to reconstruct here. result = pd_array(result_arr, dtype=result_dtype) return accessor._wrap_result(result, name=name)