Skip to content

Commit 130ce53

Browse files
[ArrowStringArray] REF: str.extract argument validation (#41418)
1 parent 8170b03 commit 130ce53

File tree

1 file changed

+19
-7
lines changed

1 file changed

+19
-7
lines changed

pandas/core/strings/accessor.py

+19 -7
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,20 @@
22
from functools import wraps
33
import re
44
from typing import (
5+
TYPE_CHECKING,
56
Dict,
67
Hashable,
78
List,
89
Optional,
910
Pattern,
11+
Union,
1012
)
1113
import warnings
1214

1315
import numpy as np
1416

1517
import pandas._libs.lib as lib
18+
from pandas._typing import FrameOrSeriesUnion
1619
from pandas.util._decorators import Appender
1720

1821
from pandas.core.dtypes.common import (
@@ -33,6 +36,9 @@
3336

3437
from pandas.core.base import NoNewAttributesMixin
3538

39+
if TYPE_CHECKING:
40+
from pandas import Index
41+
3642
_shared_docs: Dict[str, str] = {}
3743
_cpython_optimized_encoders = (
3844
"utf-8",
@@ -2276,7 +2282,9 @@ def findall(self, pat, flags=0):
22762282
return self._wrap_result(result, returns_string=False)
22772283

22782284
@forbid_nonstring_types(["bytes"])
2279-
def extract(self, pat, flags=0, expand=True):
2285+
def extract(
2286+
self, pat: str, flags: int = 0, expand: bool = True
2287+
) -> Union[FrameOrSeriesUnion, "Index"]:
22802288
r"""
22812289
Extract capture groups in the regex `pat` as columns in a DataFrame.
22822290
@@ -2357,6 +2365,16 @@ def extract(self, pat, flags=0, expand=True):
23572365
2 NaN
23582366
dtype: object
23592367
"""
2368+
if not isinstance(expand, bool):
2369+
raise ValueError("expand must be True or False")
2370+
2371+
regex = re.compile(pat, flags=flags)
2372+
if regex.groups == 0:
2373+
raise ValueError("pattern contains no capture groups")
2374+
2375+
if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex):
2376+
raise ValueError("only one regex group is supported with Index")
2377+
23602378
# TODO: dispatch
23612379
return str_extract(self, pat, flags, expand=expand)
23622380

@@ -3010,8 +3028,6 @@ def cat_core(list_of_columns: List, sep: str):
30103028

30113029
def _groups_or_na_fun(regex):
30123030
"""Used in both extract_noexpand and extract_frame"""
3013-
if regex.groups == 0:
3014-
raise ValueError("pattern contains no capture groups")
30153031
empty_row = [np.nan] * regex.groups
30163032

30173033
def f(x):
@@ -3086,8 +3102,6 @@ def _str_extract_noexpand(arr, pat, flags=0):
30863102
# not dispatching, so we have to reconstruct here.
30873103
result = pd_array(result, dtype=result_dtype)
30883104
else:
3089-
if isinstance(arr, ABCIndex):
3090-
raise ValueError("only one regex group is supported with Index")
30913105
name = None
30923106
columns = _get_group_names(regex)
30933107
if arr.size == 0:
@@ -3138,8 +3152,6 @@ def _str_extract_frame(arr, pat, flags=0):
31383152

31393153

31403154
def str_extract(arr, pat, flags=0, expand=True):
3141-
if not isinstance(expand, bool):
3142-
raise ValueError("expand must be True or False")
31433155
if expand:
31443156
result = _str_extract_frame(arr._orig, pat, flags=flags)
31453157
return result.__finalize__(arr._orig, method="str_extract")

0 commit comments

Comments
 (0)