Skip to content

Commit 6be9e3e

Browse files
simonjayhawkinsTLouf
authored andcommitted
[ArrowStringArray] REF: str.extract accessor method (pandas-dev#41535)
1 parent 04fd235 commit 6be9e3e

File tree

1 file changed

+40
-33
lines changed

1 file changed

+40
-33
lines changed

pandas/core/strings/accessor.py

+40-33
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
import numpy as np
1616

1717
import pandas._libs.lib as lib
18-
from pandas._typing import FrameOrSeriesUnion
18+
from pandas._typing import (
19+
ArrayLike,
20+
FrameOrSeriesUnion,
21+
)
1922
from pandas.util._decorators import Appender
2023

2124
from pandas.core.dtypes.common import (
@@ -3084,9 +3087,9 @@ def _get_group_names(regex: Pattern) -> List[Hashable]:
30843087
return [names.get(1 + i, i) for i in range(regex.groups)]
30853088

30863089

3087-
def _str_extract_noexpand(arr, pat, flags=0):
3090+
def _str_extract_noexpand(arr: ArrayLike, pat: str, flags=0):
30883091
"""
3089-
Find groups in each string in the Series/Index using passed regular expression.
3092+
Find groups in each string in the array using passed regular expression.
30903093
30913094
This function is called from str_extract(expand=False) when there is a single group
30923095
in the regex.
@@ -3095,65 +3098,69 @@ def _str_extract_noexpand(arr, pat, flags=0):
30953098
-------
30963099
np.ndarray
30973100
"""
3098-
from pandas import array as pd_array
3099-
31003101
regex = re.compile(pat, flags=flags)
31013102
groups_or_na = _groups_or_na_fun(regex)
3102-
result_dtype = _result_dtype(arr)
31033103

31043104
result = np.array([groups_or_na(val)[0] for val in np.asarray(arr)], dtype=object)
3105-
# not dispatching, so we have to reconstruct here.
3106-
result = pd_array(result, dtype=result_dtype)
31073105
return result
31083106

31093107

3110-
def _str_extract_frame(arr, pat, flags=0):
3108+
def _str_extract_expand(arr: ArrayLike, pat: str, flags: int = 0) -> List[List]:
31113109
"""
3112-
Find groups in each string in the Series/Index using passed regular expression.
3110+
Find groups in each string in the array using passed regular expression.
31133111
3114-
For each subject string in the Series/Index, extract groups from the first match of
3112+
For each subject string in the array, extract groups from the first match of
31153113
regular expression pat. This function is called from str_extract(expand=True) or
31163114
str_extract(expand=False) when there is more than one group in the regex.
31173115
31183116
Returns
31193117
-------
3120-
DataFrame
3118+
list of lists
31213119
31223120
"""
3123-
from pandas import DataFrame
3124-
31253121
regex = re.compile(pat, flags=flags)
31263122
groups_or_na = _groups_or_na_fun(regex)
3127-
columns = _get_group_names(regex)
3128-
result_dtype = _result_dtype(arr)
31293123

3130-
if arr.size == 0:
3131-
return DataFrame(columns=columns, dtype=result_dtype)
3124+
return [groups_or_na(val) for val in np.asarray(arr)]
31323125

3133-
result_index: Optional["Index"]
3134-
if isinstance(arr, ABCSeries):
3135-
result_index = arr.index
3136-
else:
3137-
result_index = None
3138-
return DataFrame(
3139-
[groups_or_na(val) for val in np.asarray(arr)],
3140-
columns=columns,
3141-
index=result_index,
3142-
dtype=result_dtype,
3143-
)
31443126

3127+
def str_extract(accessor: StringMethods, pat: str, flags: int = 0, expand: bool = True):
3128+
from pandas import (
3129+
DataFrame,
3130+
array as pd_array,
3131+
)
31453132

3146-
def str_extract(arr, pat, flags=0, expand=True):
3133+
obj = accessor._data
3134+
result_dtype = _result_dtype(obj)
31473135
regex = re.compile(pat, flags=flags)
31483136
returns_df = regex.groups > 1 or expand
31493137

31503138
if returns_df:
31513139
name = None
3152-
result = _str_extract_frame(arr._orig, pat, flags=flags)
3140+
columns = _get_group_names(regex)
3141+
3142+
if obj.array.size == 0:
3143+
result = DataFrame(columns=columns, dtype=result_dtype)
3144+
3145+
else:
3146+
result_list = _str_extract_expand(obj.array, pat, flags=flags)
3147+
3148+
result_index: Optional["Index"]
3149+
if isinstance(obj, ABCSeries):
3150+
result_index = obj.index
3151+
else:
3152+
result_index = None
3153+
3154+
result = DataFrame(
3155+
result_list, columns=columns, index=result_index, dtype=result_dtype
3156+
)
3157+
31533158
else:
31543159
name = _get_single_group_name(regex)
3155-
result = _str_extract_noexpand(arr._orig, pat, flags=flags)
3156-
return arr._wrap_result(result, name=name)
3160+
result_arr = _str_extract_noexpand(obj.array, pat, flags=flags)
3161+
# not dispatching, so we have to reconstruct here.
3162+
result = pd_array(result_arr, dtype=result_dtype)
3163+
return accessor._wrap_result(result, name=name)
31573164

31583165

31593166
def str_extractall(arr, pat, flags=0):

0 commit comments

Comments
 (0)