Skip to content

Commit 208a55c

Browse files
committed
return _T_STR, except for slice because that one preserves the input types
1 parent 17e280f commit 208a55c

File tree

2 files changed

+110
-93
lines changed

2 files changed

+110
-93
lines changed

pandas-stubs/core/strings.pyi

+26-26
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class StringMethods(
5151
Generic[T, _T_EXPANDING, _T_BOOL, _T_LIST_STR, _T_INT, _T_BYTES, _T_STR, _T_OBJECT],
5252
):
5353
def __init__(self, data: T) -> None: ...
54-
def __getitem__(self, key: slice | int) -> T: ...
55-
def __iter__(self) -> T: ...
54+
def __getitem__(self, key: slice | int) -> _T_STR: ...
55+
def __iter__(self) -> _T_STR: ...
5656
@overload
5757
def cat(
5858
self,
@@ -79,7 +79,7 @@ class StringMethods(
7979
sep: str = ...,
8080
na_rep: str | None = ...,
8181
join: JoinHow = ...,
82-
) -> T: ...
82+
) -> _T_STR: ...
8383
@overload
8484
def split(
8585
self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ...
@@ -121,7 +121,7 @@ class StringMethods(
121121
def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ...
122122
@overload
123123
def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ...
124-
def get(self, i: int) -> T: ...
124+
def get(self, i: int) -> _T_STR: ...
125125
def join(self, sep: str) -> _T_STR: ...
126126
def contains(
127127
self,
@@ -142,29 +142,29 @@ class StringMethods(
142142
case: bool | None = ...,
143143
flags: int = ...,
144144
regex: bool = ...,
145-
) -> T: ...
146-
def repeat(self, repeats: int | Sequence[int]) -> T: ...
145+
) -> _T_STR: ...
146+
def repeat(self, repeats: int | Sequence[int]) -> _T_STR: ...
147147
def pad(
148148
self,
149149
width: int,
150150
side: Literal["left", "right", "both"] = ...,
151151
fillchar: str = ...,
152-
) -> T: ...
153-
def center(self, width: int, fillchar: str = ...) -> T: ...
154-
def ljust(self, width: int, fillchar: str = ...) -> T: ...
155-
def rjust(self, width: int, fillchar: str = ...) -> T: ...
156-
def zfill(self, width: int) -> T: ...
152+
) -> _T_STR: ...
153+
def center(self, width: int, fillchar: str = ...) -> _T_STR: ...
154+
def ljust(self, width: int, fillchar: str = ...) -> _T_STR: ...
155+
def rjust(self, width: int, fillchar: str = ...) -> _T_STR: ...
156+
def zfill(self, width: int) -> _T_STR: ...
157157
def slice(
158158
self, start: int | None = ..., stop: int | None = ..., step: int | None = ...
159159
) -> T: ...
160160
def slice_replace(
161161
self, start: int | None = ..., stop: int | None = ..., repl: str | None = ...
162-
) -> T: ...
162+
) -> _T_STR: ...
163163
def decode(self, encoding: str, errors: str = ...) -> _T_STR: ...
164164
def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ...
165-
def strip(self, to_strip: str | None = ...) -> T: ...
166-
def lstrip(self, to_strip: str | None = ...) -> T: ...
167-
def rstrip(self, to_strip: str | None = ...) -> T: ...
165+
def strip(self, to_strip: str | None = ...) -> _T_STR: ...
166+
def lstrip(self, to_strip: str | None = ...) -> _T_STR: ...
167+
def rstrip(self, to_strip: str | None = ...) -> _T_STR: ...
168168
def wrap(
169169
self,
170170
width: int,
@@ -173,9 +173,9 @@ class StringMethods(
173173
drop_whitespace: bool | None = ...,
174174
break_long_words: bool | None = ...,
175175
break_on_hyphens: bool | None = ...,
176-
) -> T: ...
176+
) -> _T_STR: ...
177177
def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ...
178-
def translate(self, table: dict[int, int | str | None] | None) -> T: ...
178+
def translate(self, table: dict[int, int | str | None] | None) -> _T_STR: ...
179179
def count(self, pat: str, flags: int = ...) -> _T_INT: ...
180180
def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
181181
def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ...
@@ -193,16 +193,16 @@ class StringMethods(
193193
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
194194
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
195195
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
196-
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ...
196+
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ...
197197
def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
198198
def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ...
199199
def len(self) -> _T_INT: ...
200-
def lower(self) -> T: ...
201-
def upper(self) -> T: ...
202-
def title(self) -> T: ...
203-
def capitalize(self) -> T: ...
204-
def swapcase(self) -> T: ...
205-
def casefold(self) -> T: ...
200+
def lower(self) -> _T_STR: ...
201+
def upper(self) -> _T_STR: ...
202+
def title(self) -> _T_STR: ...
203+
def capitalize(self) -> _T_STR: ...
204+
def swapcase(self) -> _T_STR: ...
205+
def casefold(self) -> _T_STR: ...
206206
def isalnum(self) -> _T_BOOL: ...
207207
def isalpha(self) -> _T_BOOL: ...
208208
def isdigit(self) -> _T_BOOL: ...
@@ -215,5 +215,5 @@ class StringMethods(
215215
def fullmatch(
216216
self, pat: str, case: bool = ..., flags: int = ..., na: Any = ...
217217
) -> _T_BOOL: ...
218-
def removeprefix(self, prefix: str) -> T: ...
219-
def removesuffix(self, suffix: str) -> T: ...
218+
def removeprefix(self, prefix: str) -> _T_STR: ...
219+
def removesuffix(self, suffix: str) -> _T_STR: ...

tests/test_string_accessors.py

+84-67
Original file line numberDiff line numberDiff line change
@@ -13,70 +13,21 @@
1313

1414

1515
DATA = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"]
16+
DATA_BYTES = [b"applep", b"bananap"]
1617

1718

1819
def test_string_accessors_type_preserving_series() -> None:
19-
s = pd.Series(DATA)
20-
_check = functools.partial(check, klass=pd.Series, dtype=str)
21-
_check(assert_type(s.str.capitalize(), "pd.Series[str]"))
22-
_check(assert_type(s.str.casefold(), "pd.Series[str]"))
23-
check(assert_type(s.str.cat(sep="X"), str), str)
24-
_check(assert_type(s.str.center(10), "pd.Series[str]"))
25-
_check(assert_type(s.str.get(2), "pd.Series[str]"))
26-
_check(assert_type(s.str.ljust(80), "pd.Series[str]"))
27-
_check(assert_type(s.str.lower(), "pd.Series[str]"))
28-
_check(assert_type(s.str.lstrip("a"), "pd.Series[str]"))
29-
_check(assert_type(s.str.normalize("NFD"), "pd.Series[str]"))
30-
_check(assert_type(s.str.pad(80, "right"), "pd.Series[str]"))
31-
_check(assert_type(s.str.removeprefix("a"), "pd.Series[str]"))
32-
_check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"))
33-
_check(assert_type(s.str.repeat(2), "pd.Series[str]"))
34-
_check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"))
35-
_check(assert_type(s.str.rjust(80), "pd.Series[str]"))
36-
_check(assert_type(s.str.rstrip(), "pd.Series[str]"))
37-
_check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]"))
38-
_check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"))
39-
_check(assert_type(s.str.strip(), "pd.Series[str]"))
40-
_check(assert_type(s.str.swapcase(), "pd.Series[str]"))
41-
_check(assert_type(s.str.title(), "pd.Series[str]"))
42-
_check(
43-
assert_type(s.str.translate({241: "n"}), "pd.Series[str]"),
44-
)
45-
_check(assert_type(s.str.upper(), "pd.Series[str]"))
46-
_check(assert_type(s.str.wrap(80), "pd.Series[str]"))
47-
_check(assert_type(s.str.zfill(10), "pd.Series[str]"))
20+
s_str = pd.Series(DATA)
21+
s_bytes = pd.Series(DATA_BYTES)
22+
check(assert_type(s_str.str.slice(0, 4, 2), "pd.Series[str]"), pd.Series, str)
23+
check(assert_type(s_bytes.str.slice(0, 4, 2), "pd.Series[bytes]"), pd.Series, bytes)
4824

4925

5026
def test_string_accessors_type_preserving_index() -> None:
51-
idx = pd.Index(DATA)
52-
_check = functools.partial(check, klass=pd.Index, dtype=str)
53-
_check(assert_type(idx.str.capitalize(), "pd.Index[str]"))
54-
_check(assert_type(idx.str.casefold(), "pd.Index[str]"))
55-
check(assert_type(idx.str.cat(sep="X"), str), str)
56-
_check(assert_type(idx.str.center(10), "pd.Index[str]"))
57-
_check(assert_type(idx.str.get(2), "pd.Index[str]"))
58-
_check(assert_type(idx.str.ljust(80), "pd.Index[str]"))
59-
_check(assert_type(idx.str.lower(), "pd.Index[str]"))
60-
_check(assert_type(idx.str.lstrip("a"), "pd.Index[str]"))
61-
_check(assert_type(idx.str.normalize("NFD"), "pd.Index[str]"))
62-
_check(assert_type(idx.str.pad(80, "right"), "pd.Index[str]"))
63-
_check(assert_type(idx.str.removeprefix("a"), "pd.Index[str]"))
64-
_check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]"))
65-
_check(assert_type(idx.str.repeat(2), "pd.Index[str]"))
66-
_check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]"))
67-
_check(assert_type(idx.str.rjust(80), "pd.Index[str]"))
68-
_check(assert_type(idx.str.rstrip(), "pd.Index[str]"))
69-
_check(assert_type(idx.str.slice(0, 4, 2), "pd.Index[str]"))
70-
_check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]"))
71-
_check(assert_type(idx.str.strip(), "pd.Index[str]"))
72-
_check(assert_type(idx.str.swapcase(), "pd.Index[str]"))
73-
_check(assert_type(idx.str.title(), "pd.Index[str]"))
74-
_check(
75-
assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"),
76-
)
77-
_check(assert_type(idx.str.upper(), "pd.Index[str]"))
78-
_check(assert_type(idx.str.wrap(80), "pd.Index[str]"))
79-
_check(assert_type(idx.str.zfill(10), "pd.Index[str]"))
27+
idx_str = pd.Index(DATA)
28+
idx_bytes = pd.Index(DATA_BYTES)
29+
check(assert_type(idx_str.str.slice(0, 4, 2), "pd.Index[str]"), pd.Index, str)
30+
check(assert_type(idx_bytes.str.slice(0, 4, 2), "pd.Index[bytes]"), pd.Index, bytes)
8031

8132

8233
def test_string_accessors_boolean_series():
@@ -158,21 +109,73 @@ def test_string_accessors_integer_index():
158109

159110

160111
def test_string_accessors_string_series():
161-
s = pd.Series([b"a1", b"b2", b"c3"])
112+
s = pd.Series(DATA)
162113
_check = functools.partial(check, klass=pd.Series, dtype=str)
163-
_check(assert_type(s.str.decode("utf-8"), "pd.Series[str]"))
164-
s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
165-
_check(assert_type(s2.str.join("-"), "pd.Series[str]"))
114+
_check(assert_type(s.str.capitalize(), "pd.Series[str]"))
115+
_check(assert_type(s.str.casefold(), "pd.Series[str]"))
116+
check(assert_type(s.str.cat(sep="X"), str), str)
117+
_check(assert_type(s.str.center(10), "pd.Series[str]"))
118+
_check(assert_type(s.str.get(2), "pd.Series[str]"))
119+
_check(assert_type(s.str.ljust(80), "pd.Series[str]"))
120+
_check(assert_type(s.str.lower(), "pd.Series[str]"))
121+
_check(assert_type(s.str.lstrip("a"), "pd.Series[str]"))
122+
_check(assert_type(s.str.normalize("NFD"), "pd.Series[str]"))
123+
_check(assert_type(s.str.pad(80, "right"), "pd.Series[str]"))
124+
_check(assert_type(s.str.removeprefix("a"), "pd.Series[str]"))
125+
_check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"))
126+
_check(assert_type(s.str.repeat(2), "pd.Series[str]"))
127+
_check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"))
128+
_check(assert_type(s.str.rjust(80), "pd.Series[str]"))
129+
_check(assert_type(s.str.rstrip(), "pd.Series[str]"))
130+
_check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"))
131+
_check(assert_type(s.str.strip(), "pd.Series[str]"))
132+
_check(assert_type(s.str.swapcase(), "pd.Series[str]"))
133+
_check(assert_type(s.str.title(), "pd.Series[str]"))
134+
_check(
135+
assert_type(s.str.translate({241: "n"}), "pd.Series[str]"),
136+
)
137+
_check(assert_type(s.str.upper(), "pd.Series[str]"))
138+
_check(assert_type(s.str.wrap(80), "pd.Series[str]"))
139+
_check(assert_type(s.str.zfill(10), "pd.Series[str]"))
140+
s_bytes = pd.Series([b"a1", b"b2", b"c3"])
141+
_check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"))
142+
s_list = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
143+
_check(assert_type(s_list.str.join("-"), "pd.Series[str]"))
166144

167145

168146
def test_string_accessors_string_index():
169-
idx = pd.Index([b"a1", b"b2", b"c3"])
147+
idx = pd.Index(DATA)
170148
_check = functools.partial(check, klass=pd.Index, dtype=str)
171-
_check(assert_type(idx.str.decode("utf-8"), "pd.Index[str]"))
172-
idx2: "pd.Index[list]" = pd.Index(
173-
[["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]
149+
_check(assert_type(idx.str.capitalize(), "pd.Index[str]"))
150+
_check(assert_type(idx.str.casefold(), "pd.Index[str]"))
151+
check(assert_type(idx.str.cat(sep="X"), str), str)
152+
_check(assert_type(idx.str.center(10), "pd.Index[str]"))
153+
_check(assert_type(idx.str.get(2), "pd.Index[str]"))
154+
_check(assert_type(idx.str.ljust(80), "pd.Index[str]"))
155+
_check(assert_type(idx.str.lower(), "pd.Index[str]"))
156+
_check(assert_type(idx.str.lstrip("a"), "pd.Index[str]"))
157+
_check(assert_type(idx.str.normalize("NFD"), "pd.Index[str]"))
158+
_check(assert_type(idx.str.pad(80, "right"), "pd.Index[str]"))
159+
_check(assert_type(idx.str.removeprefix("a"), "pd.Index[str]"))
160+
_check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]"))
161+
_check(assert_type(idx.str.repeat(2), "pd.Index[str]"))
162+
_check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]"))
163+
_check(assert_type(idx.str.rjust(80), "pd.Index[str]"))
164+
_check(assert_type(idx.str.rstrip(), "pd.Index[str]"))
165+
_check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]"))
166+
_check(assert_type(idx.str.strip(), "pd.Index[str]"))
167+
_check(assert_type(idx.str.swapcase(), "pd.Index[str]"))
168+
_check(assert_type(idx.str.title(), "pd.Index[str]"))
169+
_check(
170+
assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"),
174171
)
175-
_check(assert_type(idx2.str.join("-"), "pd.Index[str]"))
172+
_check(assert_type(idx.str.upper(), "pd.Index[str]"))
173+
_check(assert_type(idx.str.wrap(80), "pd.Index[str]"))
174+
_check(assert_type(idx.str.zfill(10), "pd.Index[str]"))
175+
idx_bytes = pd.Index([b"a1", b"b2", b"c3"])
176+
_check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]"))
177+
idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
178+
_check(assert_type(idx_list.str.join("-"), "pd.Index[str]"))
176179

177180

178181
def test_string_accessors_bytes_series():
@@ -325,6 +328,12 @@ def test_series_overloads_cat():
325328
)
326329
unknown_s = pd.DataFrame({"a": list("abcdefg")})["a"]
327330
check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str)
331+
check(assert_type(unknown_s.str.cat(s, sep=";"), "pd.Series[str]"), pd.Series, str)
332+
check(
333+
assert_type(unknown_s.str.cat(unknown_s, sep=";"), "pd.Series[str]"),
334+
pd.Series,
335+
str,
336+
)
328337

329338

330339
def test_index_overloads_cat():
@@ -351,6 +360,14 @@ def test_index_overloads_cat():
351360
check(
352361
assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str
353362
)
363+
check(
364+
assert_type(unknown_idx.str.cat(idx, sep=";"), "pd.Index[str]"), pd.Index, str
365+
)
366+
check(
367+
assert_type(unknown_idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"),
368+
pd.Index,
369+
str,
370+
)
354371

355372

356373
def test_series_overloads_extract():

0 commit comments

Comments
 (0)