From bb5b2b9ec839adcfe5a1a0737b66cca063321a5f Mon Sep 17 00:00:00 2001 From: Simon Hawkins Date: Tue, 18 May 2021 16:08:27 +0100 Subject: [PATCH] [ArrowStringArray] REF: str.extract - move code from function to accessor method --- pandas/core/strings/accessor.py | 77 ++++++++++++++++----------------- 1 file changed, 37 insertions(+), 40 deletions(-) diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 1461c52d5cb65..600c07f63d895 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -2373,6 +2373,11 @@ def extract( 2 NaN dtype: object """ + from pandas import ( + DataFrame, + array as pd_array, + ) + if not isinstance(expand, bool): raise ValueError("expand must be True or False") @@ -2384,7 +2389,38 @@ def extract( raise ValueError("only one regex group is supported with Index") # TODO: dispatch - return str_extract(self, pat, flags, expand=expand) + + obj = self._data + result_dtype = _result_dtype(obj) + + returns_df = regex.groups > 1 or expand + + if returns_df: + name = None + columns = _get_group_names(regex) + + if obj.array.size == 0: + result = DataFrame(columns=columns, dtype=result_dtype) + + else: + result_list = _str_extract_expand(obj.array, pat, flags=flags) + + result_index: Optional["Index"] + if isinstance(obj, ABCSeries): + result_index = obj.index + else: + result_index = None + + result = DataFrame( + result_list, columns=columns, index=result_index, dtype=result_dtype + ) + + else: + name = _get_single_group_name(regex) + result_arr = _str_extract_noexpand(obj.array, pat, flags=flags) + # not dispatching, so we have to reconstruct here. + result = pd_array(result_arr, dtype=result_dtype) + return self._wrap_result(result, name=name) @forbid_nonstring_types(["bytes"]) def extractall(self, pat, flags=0): @@ -3124,45 +3160,6 @@ def _str_extract_expand(arr: ArrayLike, pat: str, flags: int = 0) -> List[List]: return [groups_or_na(val) for val in np.asarray(arr)] -def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool = True): - from pandas import ( - DataFrame, - array as pd_array, - ) - - obj = accessor._data - result_dtype = _result_dtype(obj) - regex = re.compile(pat, flags=flags) - returns_df = regex.groups > 1 or expand - - if returns_df: - name = None - columns = _get_group_names(regex) - - if obj.array.size == 0: - result = DataFrame(columns=columns, dtype=result_dtype) - - else: - result_list = _str_extract_expand(obj.array, pat, flags=flags) - - result_index: Optional["Index"] - if isinstance(obj, ABCSeries): - result_index = obj.index - else: - result_index = None - - result = DataFrame( - result_list, columns=columns, index=result_index, dtype=result_dtype - ) - - else: - name = _get_single_group_name(regex) - result_arr = _str_extract_noexpand(obj.array, pat, flags=flags) - # not dispatching, so we have to reconstruct here. - result = pd_array(result_arr, dtype=result_dtype) - return accessor._wrap_result(result, name=name) - - def str_extractall(arr, pat, flags=0): regex = re.compile(pat, flags=flags) # the regex must contain capture groups.