diff --git a/pandas/core/strings/accessor.py b/pandas/core/strings/accessor.py index 55a12a301c6e6..2646ddfa45b58 100644 --- a/pandas/core/strings/accessor.py +++ b/pandas/core/strings/accessor.py @@ -2,17 +2,20 @@ from functools import wraps import re from typing import ( + TYPE_CHECKING, Dict, Hashable, List, Optional, Pattern, + Union, ) import warnings import numpy as np import pandas._libs.lib as lib +from pandas._typing import FrameOrSeriesUnion from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -33,6 +36,9 @@ from pandas.core.base import NoNewAttributesMixin +if TYPE_CHECKING: + from pandas import Index + _shared_docs: Dict[str, str] = {} _cpython_optimized_encoders = ( "utf-8", @@ -2276,7 +2282,9 @@ def findall(self, pat, flags=0): return self._wrap_result(result, returns_string=False) @forbid_nonstring_types(["bytes"]) - def extract(self, pat, flags=0, expand=True): + def extract( + self, pat: str, flags: int = 0, expand: bool = True + ) -> Union[FrameOrSeriesUnion, "Index"]: r""" Extract capture groups in the regex `pat` as columns in a DataFrame. @@ -2357,6 +2365,16 @@ def extract(self, pat, flags=0, expand=True): 2 NaN dtype: object """ + if not isinstance(expand, bool): + raise ValueError("expand must be True or False") + + regex = re.compile(pat, flags=flags) + if regex.groups == 0: + raise ValueError("pattern contains no capture groups") + + if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex): + raise ValueError("only one regex group is supported with Index") + # TODO: dispatch return str_extract(self, pat, flags, expand=expand) @@ -3010,8 +3028,6 @@ def cat_core(list_of_columns: List, sep: str): def _groups_or_na_fun(regex): """Used in both extract_noexpand and extract_frame""" - if regex.groups == 0: - raise ValueError("pattern contains no capture groups") empty_row = [np.nan] * regex.groups def f(x): @@ -3086,8 +3102,6 @@ def _str_extract_noexpand(arr, pat, flags=0): # not dispatching, so we have to reconstruct here. result = pd_array(result, dtype=result_dtype) else: - if isinstance(arr, ABCIndex): - raise ValueError("only one regex group is supported with Index") name = None columns = _get_group_names(regex) if arr.size == 0: @@ -3138,8 +3152,6 @@ def _str_extract_frame(arr, pat, flags=0): def str_extract(arr, pat, flags=0, expand=True): - if not isinstance(expand, bool): - raise ValueError("expand must be True or False") if expand: result = _str_extract_frame(arr._orig, pat, flags=flags) return result.__finalize__(arr._orig, method="str_extract")