Skip to content

Commit 539fbf0

Browse files
[ArrowStringArray] REF: dedup df creation in str.extract (#41502)
1 parent fea99c3 commit 539fbf0

File tree

1 file changed

+34
-42
lines changed

1 file changed

+34
-42
lines changed

pandas/core/strings/accessor.py

+34-42
Original file line number | Diff line number | Diff line change
@@ -3086,52 +3086,38 @@ def _get_group_names(regex: Pattern) -> List[Hashable]:
30863086

30873087
def _str_extract_noexpand(arr, pat, flags=0):
30883088
"""
3089-
Find groups in each string in the Series using passed regular
3090-
expression. This function is called from
3091-
str_extract(expand=False), and can return Series, DataFrame, or
3092-
Index.
3089+
Find groups in each string in the Series/Index using passed regular expression.
30933090
3091+
This function is called from str_extract(expand=False) when there is a single group
3092+
in the regex.
3093+
3094+
Returns
3095+
-------
3096+
np.ndarray
30943097
"""
3095-
from pandas import (
3096-
DataFrame,
3097-
array as pd_array,
3098-
)
3098+
from pandas import array as pd_array
30993099

31003100
regex = re.compile(pat, flags=flags)
31013101
groups_or_na = _groups_or_na_fun(regex)
31023102
result_dtype = _result_dtype(arr)
31033103

3104-
if regex.groups == 1:
3105-
result = np.array([groups_or_na(val)[0] for val in arr], dtype=object)
3106-
name = _get_single_group_name(regex)
3107-
# not dispatching, so we have to reconstruct here.
3108-
result = pd_array(result, dtype=result_dtype)
3109-
else:
3110-
name = None
3111-
columns = _get_group_names(regex)
3112-
if arr.size == 0:
3113-
# error: Incompatible types in assignment (expression has type
3114-
# "DataFrame", variable has type "ndarray")
3115-
result = DataFrame( # type: ignore[assignment]
3116-
columns=columns, dtype=result_dtype
3117-
)
3118-
else:
3119-
# error: Incompatible types in assignment (expression has type
3120-
# "DataFrame", variable has type "ndarray")
3121-
result = DataFrame( # type:ignore[assignment]
3122-
[groups_or_na(val) for val in arr],
3123-
columns=columns,
3124-
index=arr.index,
3125-
dtype=result_dtype,
3126-
)
3127-
return result, name
3104+
result = np.array([groups_or_na(val)[0] for val in arr], dtype=object)
3105+
# not dispatching, so we have to reconstruct here.
3106+
result = pd_array(result, dtype=result_dtype)
3107+
return result
31283108

31293109

31303110
def _str_extract_frame(arr, pat, flags=0):
31313111
"""
3132-
For each subject string in the Series, extract groups from the
3133-
first match of regular expression pat. This function is called from
3134-
str_extract(expand=True), and always returns a DataFrame.
3112+
Find groups in each string in the Series/Index using passed regular expression.
3113+
3114+
For each subject string in the Series/Index, extract groups from the first match of
3115+
regular expression pat. This function is called from str_extract(expand=True) or
3116+
str_extract(expand=False) when there is more than one group in the regex.
3117+
3118+
Returns
3119+
-------
3120+
DataFrame
31353121
31363122
"""
31373123
from pandas import DataFrame
@@ -3141,11 +3127,13 @@ def _str_extract_frame(arr, pat, flags=0):
31413127
columns = _get_group_names(regex)
31423128
result_dtype = _result_dtype(arr)
31433129

3144-
if len(arr) == 0:
3130+
if arr.size == 0:
31453131
return DataFrame(columns=columns, dtype=result_dtype)
3146-
try:
3132+
3133+
result_index: Optional["Index"]
3134+
if isinstance(arr, ABCSeries):
31473135
result_index = arr.index
3148-
except AttributeError:
3136+
else:
31493137
result_index = None
31503138
return DataFrame(
31513139
[groups_or_na(val) for val in arr],
@@ -3156,12 +3144,16 @@ def _str_extract_frame(arr, pat, flags=0):
31563144

31573145

31583146
def str_extract(arr, pat, flags=0, expand=True):
3159-
if expand:
3147+
regex = re.compile(pat, flags=flags)
3148+
returns_df = regex.groups > 1 or expand
3149+
3150+
if returns_df:
3151+
name = None
31603152
result = _str_extract_frame(arr._orig, pat, flags=flags)
3161-
return result.__finalize__(arr._orig, method="str_extract")
31623153
else:
3163-
result, name = _str_extract_noexpand(arr._orig, pat, flags=flags)
3164-
return arr._wrap_result(result, name=name, expand=expand)
3154+
name = _get_single_group_name(regex)
3155+
result = _str_extract_noexpand(arr._orig, pat, flags=flags)
3156+
return arr._wrap_result(result, name=name)
31653157

31663158

31673159
def str_extractall(arr, pat, flags=0):

0 commit comments

Comments
 (0)