Skip to content

Commit 005759c

Browse files
committed
keep fixing
1 parent 3d581a8 commit 005759c

File tree

4 files changed

+90
-16
lines changed

4 files changed

+90
-16
lines changed

pandas-stubs/core/indexes/base.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ class Index(IndexOpsMixin[S1]):
272272
Index[int],
273273
Index[bytes],
274274
Index[str],
275+
Index[type[object]],
275276
]: ...
276277
def is_(self, other) -> bool: ...
277278
def __len__(self) -> int: ...

pandas-stubs/core/series.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
11641164
Series[int],
11651165
Series[bytes],
11661166
Series[str],
1167+
Series[type[object]],
11671168
]: ...
11681169
@property
11691170
def dt(self) -> CombinedDatetimelikeProperties: ...

pandas-stubs/core/strings.pyi

+13-13
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,12 @@ _TI = TypeVar("_TI", bound=Series[int] | Index[int])
4343
_TE = TypeVar("_TE", bound=Series[bytes] | Index[bytes])
4444
# The _TD type is what is used for str-valued results such as str.join (presumably also str.decode -- TODO confirm)
4545
_TD = TypeVar("_TD", bound=Series[str] | Index[str])
46+
# The _TO type is what is used for the result of str.partition / str.rpartition with expand=False
47+
_TO = TypeVar("_TO", bound=Series[type[object]] | Index[type[object]])
4648

47-
class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD]):
49+
class StringMethods(
50+
NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD, _TO]
51+
):
4852
def __init__(self, data: T) -> None: ...
4953
def __getitem__(self, key: slice | int) -> T: ...
5054
def __iter__(self) -> T: ...
@@ -101,23 +105,19 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _
101105
@overload
102106
def partition(self, sep: str, expand: Literal[True]) -> _TS: ...
103107
@overload
104-
def partition(
105-
self, sep: str, expand: Literal[False]
106-
) -> pd.Series[type[object]]: ...
108+
def partition(self, sep: str, expand: Literal[False]) -> _TO: ...
107109
@overload
108-
def partition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ...
110+
def partition(self, *, expand: Literal[False]) -> _TO: ...
109111
@overload
110112
def rpartition(self, sep: str = ...) -> _TS: ...
111113
@overload
112-
def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ...
114+
def rpartition(self, *, expand: Literal[True]) -> _TS: ...
113115
@overload
114-
def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ...
116+
def rpartition(self, sep: str, expand: Literal[True]) -> _TS: ...
115117
@overload
116-
def rpartition(
117-
self, sep: str, expand: Literal[False]
118-
) -> pd.Series[type[object]]: ...
118+
def rpartition(self, sep: str, expand: Literal[False]) -> _TO: ...
119119
@overload
120-
def rpartition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ...
120+
def rpartition(self, *, expand: Literal[False]) -> _TO: ...
121121
def get(self, i: int) -> T: ...
122122
def join(self, sep: str) -> _TD: ...
123123
def contains(
@@ -180,7 +180,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _
180180
@overload
181181
def extract(
182182
self, pat: str, flags: int = ..., *, expand: Literal[True] = ...
183-
) -> _TS: ...
183+
) -> pd.DataFrame: ...
184184
@overload
185185
def extract(
186186
self, pat: str, flags: int, expand: Literal[False]
@@ -189,7 +189,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _
189189
def extract(
190190
self, pat: str, flags: int = ..., *, expand: Literal[False]
191191
) -> Series[type[object]]: ...
192-
def extractall(self, pat: str, flags: int = ...) -> _TS: ...
192+
def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ...
193193
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
194194
def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ...
195195
def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ...

tests/test_string_accessors.py

+75-3
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def test_string_accessors_type_preserving_index() -> None:
8181

8282
def test_string_accessors_boolean_series():
8383
s = pd.Series(DATA)
84-
_check = functools.partial(check, klass=pd.Series, dtype=bool)
84+
_check = functools.partial(check, klass=pd.Series, dtype=np.bool_)
8585
_check(assert_type(s.str.startswith("a"), "pd.Series[bool]"))
8686
_check(
8787
assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"),
@@ -220,10 +220,82 @@ def test_string_accessors_expanding_series():
220220
def test_string_accessors_expanding_index():
221221
idx = pd.Index(["a1", "b2", "c3"])
222222
_check = functools.partial(check, klass=pd.MultiIndex)
223-
_check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.MultiIndex))
224-
_check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.MultiIndex))
225223
_check(assert_type(idx.str.get_dummies(), pd.MultiIndex))
226224
_check(assert_type(idx.str.partition("p"), pd.MultiIndex))
227225
_check(assert_type(idx.str.rpartition("p"), pd.MultiIndex))
228226
_check(assert_type(idx.str.rsplit("a", expand=True), pd.MultiIndex))
229227
_check(assert_type(idx.str.split("a", expand=True), pd.MultiIndex))
228+
229+
# Unlike the methods above, extract and extractall on an Index return a DataFrame rather than a MultiIndex.
230+
check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
231+
check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame)
232+
233+
234+
def test_series_overloads_partition():
235+
s = pd.Series(
236+
[
237+
"ap;pl;ep",
238+
"ban;an;ap",
239+
"Che;rr;yp",
240+
"DA;TEp",
241+
"eGGp;LANT;p",
242+
"12;3p",
243+
"23.45p",
244+
]
245+
)
246+
check(assert_type(s.str.partition(sep=";"), pd.DataFrame), pd.DataFrame)
247+
check(
248+
assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame
249+
)
250+
check(
251+
assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"),
252+
pd.Series,
253+
object,
254+
)
255+
256+
check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame)
257+
check(
258+
assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame
259+
)
260+
check(
261+
assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"),
262+
pd.Series,
263+
object,
264+
)
265+
266+
267+
def test_index_overloads_partition():
268+
idx = pd.Index(
269+
[
270+
"ap;pl;ep",
271+
"ban;an;ap",
272+
"Che;rr;yp",
273+
"DA;TEp",
274+
"eGGp;LANT;p",
275+
"12;3p",
276+
"23.45p",
277+
]
278+
)
279+
check(assert_type(idx.str.partition(sep=";"), pd.MultiIndex), pd.MultiIndex)
280+
check(
281+
assert_type(idx.str.partition(sep=";", expand=True), pd.MultiIndex),
282+
pd.MultiIndex,
283+
)
284+
check(
285+
assert_type(idx.str.partition(sep=";", expand=False), "pd.Index[type[object]]"),
286+
pd.Index,
287+
object,
288+
)
289+
290+
check(assert_type(idx.str.rpartition(sep=";"), pd.MultiIndex), pd.MultiIndex)
291+
check(
292+
assert_type(idx.str.rpartition(sep=";", expand=True), pd.MultiIndex),
293+
pd.MultiIndex,
294+
)
295+
check(
296+
assert_type(
297+
idx.str.rpartition(sep=";", expand=False), "pd.Index[type[object]]"
298+
),
299+
pd.Index,
300+
object,
301+
)

0 commit comments

Comments
 (0)