|
2 | 2 | from functools import wraps
|
3 | 3 | import re
|
4 | 4 | from typing import (
|
| 5 | + TYPE_CHECKING, |
5 | 6 | Dict,
|
6 | 7 | Hashable,
|
7 | 8 | List,
|
8 | 9 | Optional,
|
9 | 10 | Pattern,
|
| 11 | + Union, |
10 | 12 | )
|
11 | 13 | import warnings
|
12 | 14 |
|
13 | 15 | import numpy as np
|
14 | 16 |
|
15 | 17 | import pandas._libs.lib as lib
|
| 18 | +from pandas._typing import FrameOrSeriesUnion |
16 | 19 | from pandas.util._decorators import Appender
|
17 | 20 |
|
18 | 21 | from pandas.core.dtypes.common import (
|
|
33 | 36 |
|
34 | 37 | from pandas.core.base import NoNewAttributesMixin
|
35 | 38 |
|
| 39 | +if TYPE_CHECKING: |
| 40 | + from pandas import Index |
| 41 | + |
36 | 42 | _shared_docs: Dict[str, str] = {}
|
37 | 43 | _cpython_optimized_encoders = (
|
38 | 44 | "utf-8",
|
@@ -2276,7 +2282,9 @@ def findall(self, pat, flags=0):
|
2276 | 2282 | return self._wrap_result(result, returns_string=False)
|
2277 | 2283 |
|
2278 | 2284 | @forbid_nonstring_types(["bytes"])
|
2279 |
| - def extract(self, pat, flags=0, expand=True): |
| 2285 | + def extract( |
| 2286 | + self, pat: str, flags: int = 0, expand: bool = True |
| 2287 | + ) -> Union[FrameOrSeriesUnion, "Index"]: |
2280 | 2288 | r"""
|
2281 | 2289 | Extract capture groups in the regex `pat` as columns in a DataFrame.
|
2282 | 2290 |
|
@@ -2357,6 +2365,16 @@ def extract(self, pat, flags=0, expand=True):
|
2357 | 2365 | 2 NaN
|
2358 | 2366 | dtype: object
|
2359 | 2367 | """
|
| 2368 | + if not isinstance(expand, bool): |
| 2369 | + raise ValueError("expand must be True or False") |
| 2370 | + |
| 2371 | + regex = re.compile(pat, flags=flags) |
| 2372 | + if regex.groups == 0: |
| 2373 | + raise ValueError("pattern contains no capture groups") |
| 2374 | + |
| 2375 | + if not expand and regex.groups > 1 and isinstance(self._data, ABCIndex): |
| 2376 | + raise ValueError("only one regex group is supported with Index") |
| 2377 | + |
2360 | 2378 | # TODO: dispatch
|
2361 | 2379 | return str_extract(self, pat, flags, expand=expand)
|
2362 | 2380 |
|
@@ -3010,8 +3028,6 @@ def cat_core(list_of_columns: List, sep: str):
|
3010 | 3028 |
|
3011 | 3029 | def _groups_or_na_fun(regex):
|
3012 | 3030 | """Used in both extract_noexpand and extract_frame"""
|
3013 |
| - if regex.groups == 0: |
3014 |
| - raise ValueError("pattern contains no capture groups") |
3015 | 3031 | empty_row = [np.nan] * regex.groups
|
3016 | 3032 |
|
3017 | 3033 | def f(x):
|
@@ -3086,8 +3102,6 @@ def _str_extract_noexpand(arr, pat, flags=0):
|
3086 | 3102 | # not dispatching, so we have to reconstruct here.
|
3087 | 3103 | result = pd_array(result, dtype=result_dtype)
|
3088 | 3104 | else:
|
3089 |
| - if isinstance(arr, ABCIndex): |
3090 |
| - raise ValueError("only one regex group is supported with Index") |
3091 | 3105 | name = None
|
3092 | 3106 | columns = _get_group_names(regex)
|
3093 | 3107 | if arr.size == 0:
|
@@ -3138,8 +3152,6 @@ def _str_extract_frame(arr, pat, flags=0):
|
3138 | 3152 |
|
3139 | 3153 |
|
3140 | 3154 | def str_extract(arr, pat, flags=0, expand=True):
|
3141 |
| - if not isinstance(expand, bool): |
3142 |
| - raise ValueError("expand must be True or False") |
3143 | 3155 | if expand:
|
3144 | 3156 | result = _str_extract_frame(arr._orig, pat, flags=flags)
|
3145 | 3157 | return result.__finalize__(arr._orig, method="str_extract")
|
|
0 commit comments