Skip to content

Commit 0d1fc59

Browse files
committed
fixup str.extract
1 parent b244308 commit 0d1fc59

File tree

3 files changed

+47
-30
lines changed

3 files changed

+47
-30
lines changed

pandas-stubs/core/strings.pyi

+2-6
Original file line numberDiff line numberDiff line change
@@ -182,13 +182,9 @@ class StringMethods(
182182
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
183183
) -> pd.DataFrame: ...
184184
@overload
185-
def extract(
186-
self, pat: str, flags: int, expand: Literal[False]
187-
) -> Series[type[object]]: ...
185+
def extract(self, pat: str, flags: int, expand: Literal[False]) -> _TO: ...
188186
@overload
189-
def extract(
190-
self, pat: str, flags: int = ..., *, expand: Literal[False]
191-
) -> Series[type[object]]: ...
187+
def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> _TO: ...
192188
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
193189
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
194190
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...

tests/test_series.py

-24
Original file line numberDiff line numberDiff line change
@@ -1571,30 +1571,6 @@ def test_categorical_codes():
15711571
assert_type(cat.codes, "np_ndarray_int")
15721572

15731573

1574-
def test_series_overloads_extract():
1575-
s = pd.Series(
1576-
["appl;ep", "ban;anap", "Cherr;yp", "DATEp", "eGGp;LANTp", "12;3p", "23.45p"]
1577-
)
1578-
check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame)
1579-
check(
1580-
assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame
1581-
)
1582-
check(
1583-
assert_type(
1584-
s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]"
1585-
),
1586-
pd.Series,
1587-
object,
1588-
)
1589-
check(
1590-
assert_type(
1591-
s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]"
1592-
),
1593-
pd.Series,
1594-
object,
1595-
)
1596-
1597-
15981574
def test_relops() -> None:
15991575
# GH 175
16001576
s: str = "abc"

tests/test_string_accessors.py

+45
Original file line numberDiff line numberDiff line change
@@ -349,3 +349,48 @@ def test_index_overloads_cat():
349349
check(
350350
assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str
351351
)
352+
353+
354+
def test_series_overloads_extract():
355+
s = pd.Series(DATA)
356+
check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame)
357+
check(
358+
assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame
359+
)
360+
check(
361+
assert_type(
362+
s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]"
363+
),
364+
pd.Series,
365+
object,
366+
)
367+
check(
368+
assert_type(
369+
s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]"
370+
),
371+
pd.Series,
372+
object,
373+
)
374+
375+
376+
def test_index_overloads_extract():
377+
idx = pd.Index(DATA)
378+
check(assert_type(idx.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame)
379+
check(
380+
assert_type(idx.str.extract(r"[ab](\d)", expand=True), pd.DataFrame),
381+
pd.DataFrame,
382+
)
383+
check(
384+
assert_type(
385+
idx.str.extract(r"[ab](\d)", expand=False), "pd.Index[type[object]]"
386+
),
387+
pd.Index,
388+
object,
389+
)
390+
check(
391+
assert_type(
392+
idx.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Index[type[object]]"
393+
),
394+
pd.Index,
395+
object,
396+
)

0 commit comments

Comments
 (0)