Skip to content

Commit de28385

Browse files
committed
preserve type in series.str
1 parent 427a707 commit de28385

File tree

3 files changed

+97
-55
lines changed

3 files changed

+97
-55
lines changed

pandas-stubs/core/series.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
11561156
@property
11571157
def str(
11581158
self,
1159-
) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ...
1159+
) -> StringMethods[Series[S1], DataFrame, Series[bool], Series[list[str]]]: ...
11601160
@property
11611161
def dt(self) -> CombinedDatetimelikeProperties: ...
11621162
@property

pandas-stubs/core/strings.pyi

+28-12
Original file line numberDiff line numberDiff line change
@@ -95,19 +95,23 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
9595
@overload
9696
def partition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
9797
@overload
98-
def partition(self, sep: str, expand: Literal[False]) -> T: ...
98+
def partition(
99+
self, sep: str, expand: Literal[False]
100+
) -> pd.Series[type[object]]: ...
99101
@overload
100-
def partition(self, *, expand: Literal[False]) -> T: ...
102+
def partition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ...
101103
@overload
102104
def rpartition(self, sep: str = ...) -> pd.DataFrame: ...
103105
@overload
104106
def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ...
105107
@overload
106108
def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
107109
@overload
108-
def rpartition(self, sep: str, expand: Literal[False]) -> T: ...
110+
def rpartition(
111+
self, sep: str, expand: Literal[False]
112+
) -> pd.Series[type[object]]: ...
109113
@overload
110-
def rpartition(self, *, expand: Literal[False]) -> T: ...
114+
def rpartition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ...
111115
def get(self, i: int) -> T: ...
112116
def join(self, sep: str) -> T: ...
113117
def contains(
@@ -147,8 +151,8 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
147151
def slice_replace(
148152
self, start: int | None = ..., stop: int | None = ..., repl: str | None = ...
149153
) -> T: ...
150-
def decode(self, encoding: str, errors: str = ...) -> T: ...
151-
def encode(self, encoding: str, errors: str = ...) -> T: ...
154+
def decode(self, encoding: str, errors: str = ...) -> Series[str]: ...
155+
def encode(self, encoding: str, errors: str = ...) -> Series[bytes]: ...
152156
def strip(self, to_strip: str | None = ...) -> T: ...
153157
def lstrip(self, to_strip: str | None = ...) -> T: ...
154158
def rstrip(self, to_strip: str | None = ...) -> T: ...
@@ -172,15 +176,27 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]):
172176
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
173177
) -> pd.DataFrame: ...
174178
@overload
175-
def extract(self, pat: str, flags: int, expand: Literal[False]) -> T: ...
179+
def extract(
180+
self, pat: str, flags: int, expand: Literal[False]
181+
) -> Series[type[object]]: ...
176182
@overload
177-
def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> T: ...
183+
def extract(
184+
self, pat: str, flags: int = ..., *, expand: Literal[False]
185+
) -> Series[type[object]]: ...
178186
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
179-
def find(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
180-
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
187+
def find(
188+
self, sub: str, start: int = ..., end: int | None = ...
189+
) -> Series[int]: ...
190+
def rfind(
191+
self, sub: str, start: int = ..., end: int | None = ...
192+
) -> Series[int]: ...
181193
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ...
182-
def index(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
183-
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> T: ...
194+
def index(
195+
self, sub: str, start: int = ..., end: int | None = ...
196+
) -> Series[int]: ...
197+
def rindex(
198+
self, sub: str, start: int = ..., end: int | None = ...
199+
) -> Series[int]: ...
184200
def len(self) -> Series[int]: ...
185201
def lower(self) -> T: ...
186202
def upper(self) -> T: ...

tests/test_series.py

+68-42
Original file line numberDiff line numberDiff line change
@@ -1577,31 +1577,32 @@ def test_string_accessors():
15771577
)
15781578
s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
15791579
s3 = pd.Series(["a1", "b2", "c3"])
1580-
check(assert_type(s.str.capitalize(), pd.Series), pd.Series)
1581-
check(assert_type(s.str.casefold(), pd.Series), pd.Series)
1580+
s4 = pd.Series([b"a1", b"b2", b"c3"])
1581+
check(assert_type(s.str.capitalize(), "pd.Series[str]"), pd.Series, str)
1582+
check(assert_type(s.str.casefold(), "pd.Series[str]"), pd.Series, str)
15821583
check(assert_type(s.str.cat(sep="X"), str), str)
1583-
check(assert_type(s.str.center(10), pd.Series), pd.Series)
1584+
check(assert_type(s.str.center(10), "pd.Series[str]"), pd.Series, str)
15841585
check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_)
15851586
check(
15861587
assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"),
15871588
pd.Series,
15881589
np.bool_,
15891590
)
15901591
check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer)
1591-
check(assert_type(s.str.decode("utf-8"), pd.Series), pd.Series)
1592-
check(assert_type(s.str.encode("latin-1"), pd.Series), pd.Series)
1592+
check(assert_type(s4.str.decode("utf-8"), "pd.Series[str]"), pd.Series, str)
1593+
check(assert_type(s.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes)
15931594
check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_)
15941595
check(
15951596
assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_
15961597
)
15971598
check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
15981599
check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
1599-
check(assert_type(s.str.find("p"), pd.Series), pd.Series)
1600+
check(assert_type(s.str.find("p"), "pd.Series[int]"), pd.Series, np.int64)
16001601
check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list)
16011602
check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_)
1602-
check(assert_type(s.str.get(2), pd.Series), pd.Series)
1603+
check(assert_type(s.str.get(2), "pd.Series[str]"), pd.Series, str)
16031604
check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame)
1604-
check(assert_type(s.str.index("p"), pd.Series), pd.Series)
1605+
check(assert_type(s.str.index("p"), "pd.Series[int]"), pd.Series, np.int64)
16051606
check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_)
16061607
check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_)
16071608
check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_)
@@ -1613,20 +1614,20 @@ def test_string_accessors():
16131614
check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_)
16141615
check(assert_type(s2.str.join("-"), pd.Series), pd.Series)
16151616
check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer)
1616-
check(assert_type(s.str.ljust(80), pd.Series), pd.Series)
1617-
check(assert_type(s.str.lower(), pd.Series), pd.Series)
1618-
check(assert_type(s.str.lstrip("a"), pd.Series), pd.Series)
1617+
check(assert_type(s.str.ljust(80), "pd.Series[str]"), pd.Series, str)
1618+
check(assert_type(s.str.lower(), "pd.Series[str]"), pd.Series, str)
1619+
check(assert_type(s.str.lstrip("a"), "pd.Series[str]"), pd.Series, str)
16191620
check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_)
1620-
check(assert_type(s.str.normalize("NFD"), pd.Series), pd.Series)
1621-
check(assert_type(s.str.pad(80, "right"), pd.Series), pd.Series)
1621+
check(assert_type(s.str.normalize("NFD"), "pd.Series[str]"), pd.Series, str)
1622+
check(assert_type(s.str.pad(80, "right"), "pd.Series[str]"), pd.Series, str)
16221623
check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame)
1623-
check(assert_type(s.str.removeprefix("a"), pd.Series), pd.Series)
1624-
check(assert_type(s.str.removesuffix("e"), pd.Series), pd.Series)
1625-
check(assert_type(s.str.repeat(2), pd.Series), pd.Series)
1626-
check(assert_type(s.str.replace("a", "X"), pd.Series), pd.Series)
1627-
check(assert_type(s.str.rfind("e"), pd.Series), pd.Series)
1628-
check(assert_type(s.str.rindex("p"), pd.Series), pd.Series)
1629-
check(assert_type(s.str.rjust(80), pd.Series), pd.Series)
1624+
check(assert_type(s.str.removeprefix("a"), "pd.Series[str]"), pd.Series, str)
1625+
check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"), pd.Series, str)
1626+
check(assert_type(s.str.repeat(2), "pd.Series[str]"), pd.Series, str)
1627+
check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"), pd.Series, str)
1628+
check(assert_type(s.str.rfind("e"), "pd.Series[int]"), pd.Series, np.int64)
1629+
check(assert_type(s.str.rindex("p"), "pd.Series[int]"), pd.Series, np.int64)
1630+
check(assert_type(s.str.rjust(80), "pd.Series[str]"), pd.Series, str)
16301631
check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame)
16311632
check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list)
16321633
check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame)
@@ -1635,9 +1636,11 @@ def test_string_accessors():
16351636
pd.Series,
16361637
list,
16371638
)
1638-
check(assert_type(s.str.rstrip(), pd.Series), pd.Series)
1639-
check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series)
1640-
check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series)
1639+
check(assert_type(s.str.rstrip(), "pd.Series[str]"), pd.Series, str)
1640+
check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]"), pd.Series, str)
1641+
check(
1642+
assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"), pd.Series, str
1643+
)
16411644
check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list)
16421645
# GH 194
16431646
check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame)
@@ -1652,13 +1655,19 @@ def test_string_accessors():
16521655
pd.Series,
16531656
np.bool_,
16541657
)
1655-
check(assert_type(s.str.strip(), pd.Series), pd.Series)
1656-
check(assert_type(s.str.swapcase(), pd.Series), pd.Series)
1657-
check(assert_type(s.str.title(), pd.Series), pd.Series)
1658-
check(assert_type(s.str.translate(None), pd.Series), pd.Series)
1659-
check(assert_type(s.str.upper(), pd.Series), pd.Series)
1660-
check(assert_type(s.str.wrap(80), pd.Series), pd.Series)
1661-
check(assert_type(s.str.zfill(10), pd.Series), pd.Series)
1658+
check(assert_type(s.str.strip(), "pd.Series[str]"), pd.Series, str)
1659+
check(assert_type(s.str.swapcase(), "pd.Series[str]"), pd.Series, str)
1660+
check(assert_type(s.str.title(), "pd.Series[str]"), pd.Series, str)
1661+
check(
1662+
assert_type(
1663+
s.str.translate(str.maketrans({"ñ": "n", "ç": "c"})), "pd.Series[str]"
1664+
),
1665+
pd.Series,
1666+
str,
1667+
)
1668+
check(assert_type(s.str.upper(), "pd.Series[str]"), pd.Series, str)
1669+
check(assert_type(s.str.wrap(80), "pd.Series[str]"), pd.Series, str)
1670+
check(assert_type(s.str.zfill(10), "pd.Series[str]"), pd.Series, str)
16621671

16631672

16641673
def test_series_overloads_cat():
@@ -1669,22 +1678,22 @@ def test_series_overloads_cat():
16691678
check(assert_type(s.str.cat(None, sep=";"), str), str)
16701679
check(
16711680
assert_type(
1672-
s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), UnknownSeries
1681+
s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"),
1682+
"pd.Series[str]",
16731683
),
1674-
UnknownSeries,
1684+
pd.Series,
1685+
str,
16751686
)
16761687
check(
16771688
assert_type(
16781689
s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"),
1679-
UnknownSeries,
1690+
"pd.Series[str]",
16801691
),
1681-
UnknownSeries,
1692+
pd.Series,
1693+
str,
16821694
)
16831695
unknown_s: UnknownSeries = pd.DataFrame({"a": ["a", "b"]})["a"]
1684-
check(
1685-
assert_type(s.str.cat(unknown_s, sep=";"), UnknownSeries),
1686-
UnknownSeries,
1687-
)
1696+
check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str)
16881697

16891698

16901699
def test_series_overloads_partition():
@@ -1703,13 +1712,21 @@ def test_series_overloads_partition():
17031712
check(
17041713
assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame
17051714
)
1706-
check(assert_type(s.str.partition(sep=";", expand=False), pd.Series), pd.Series)
1715+
check(
1716+
assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"),
1717+
pd.Series,
1718+
object,
1719+
)
17071720

17081721
check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame)
17091722
check(
17101723
assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame
17111724
)
1712-
check(assert_type(s.str.rpartition(sep=";", expand=False), pd.Series), pd.Series)
1725+
check(
1726+
assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"),
1727+
pd.Series,
1728+
object,
1729+
)
17131730

17141731

17151732
def test_series_overloads_extract():
@@ -1720,10 +1737,19 @@ def test_series_overloads_extract():
17201737
check(
17211738
assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame
17221739
)
1723-
check(assert_type(s.str.extract(r"[ab](\d)", expand=False), pd.Series), pd.Series)
17241740
check(
1725-
assert_type(s.str.extract(r"[ab](\d)", re.IGNORECASE, False), pd.Series),
1741+
assert_type(
1742+
s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]"
1743+
),
17261744
pd.Series,
1745+
object,
1746+
)
1747+
check(
1748+
assert_type(
1749+
s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]"
1750+
),
1751+
pd.Series,
1752+
object,
17271753
)
17281754

17291755

0 commit comments

Comments
 (0)