Skip to content

Commit b2d4657

Browse files
Dr-Irv authored and MarcoGorelli committed
disallow .str on certain series types
1 parent 3dc660e commit b2d4657

File tree

3 files changed

+69
-20
lines changed

3 files changed

+69
-20
lines changed

pandas-stubs/core/frame.pyi

+8 -5
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,10 @@ from pandas.core.indexing import (
4444
_LocIndexer,
4545
)
4646
from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg
47-
from pandas.core.series import Series
47+
from pandas.core.series import (
48+
Series,
49+
UnknownSeries,
50+
)
4851
from pandas.core.window import (
4952
Expanding,
5053
ExponentialMovingWindow,
@@ -244,24 +247,24 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]):
244247
if sys.version_info >= (3, 12):
245248
class _GetItemHack:
246249
@overload
247-
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
250+
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
248251
@overload
249252
def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
250253
self, key: Iterable[Hashable] | slice
251254
) -> Self: ...
252255
@overload
253-
def __getitem__(self, key: Hashable) -> Series: ...
256+
def __getitem__(self, key: Hashable) -> UnknownSeries: ...
254257

255258
else:
256259
class _GetItemHack:
257260
@overload
258-
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
261+
def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload]
259262
@overload
260263
def __getitem__( # pyright: ignore[reportOverlappingOverload]
261264
self, key: Iterable[Hashable] | slice
262265
) -> Self: ...
263266
@overload
264-
def __getitem__(self, key: Hashable) -> Series: ...
267+
def __getitem__(self, key: Hashable) -> UnknownSeries: ...
265268

266269
class DataFrame(NDFrame, OpsMixin, _GetItemHack):
267270

pandas-stubs/core/series.pyi

+49 -13
Original file line number | Diff line number | Diff line change
@@ -231,6 +231,54 @@ class _LocIndexerSeries(_LocIndexer, Generic[S1]):
231231
value: S1 | ArrayLike | Series[S1] | None,
232232
) -> None: ...
233233

234+
class _StrMethods:
235+
@overload
236+
def __get__(self, instance: Series[str], owner: Any) -> StringMethods[
237+
Series[str],
238+
DataFrame,
239+
Series[bool],
240+
Series[list[str]],
241+
Series[int],
242+
Series[bytes],
243+
Series[str],
244+
Series[type[object]],
245+
]: ...
246+
@overload
247+
def __get__(self, instance: Series[bytes], owner: Any) -> StringMethods[
248+
Series[bytes],
249+
DataFrame,
250+
Series[bool],
251+
Series[list[str]],
252+
Series[int],
253+
Series[bytes],
254+
Series[str],
255+
Series[type[object]],
256+
]: ...
257+
@overload
258+
def __get__(self, instance: Series[list[str]], owner: Any) -> StringMethods[
259+
Series[list[str]],
260+
DataFrame,
261+
Series[bool],
262+
Series[list[str]],
263+
Series[int],
264+
Series[bytes],
265+
Series[str],
266+
Series[type[object]],
267+
]: ...
268+
@overload
269+
def __get__(self, instance: Series[S1], owner: Any) -> NoReturn: ...
270+
@overload
271+
def __get__(self, instance: UnknownSeries, owner: Any) -> StringMethods[
272+
Series,
273+
DataFrame,
274+
Series[bool],
275+
Series[list[str]],
276+
Series[int],
277+
Series[bytes],
278+
Series[str],
279+
Series[type[object]],
280+
]: ...
281+
234282
_ListLike: TypeAlias = (
235283
ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | IndexOpsMixin[S1]
236284
)
@@ -1153,19 +1201,7 @@ class Series(IndexOpsMixin[S1], NDFrame):
11531201
copy: _bool = ...,
11541202
) -> Series[S1]: ...
11551203
def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ...
1156-
@property
1157-
def str(
1158-
self,
1159-
) -> StringMethods[
1160-
Self,
1161-
DataFrame,
1162-
Series[bool],
1163-
Series[list[str]],
1164-
Series[int],
1165-
Series[bytes],
1166-
Series[str],
1167-
Series[type[object]],
1168-
]: ...
1204+
str: _StrMethods
11691205
@property
11701206
def dt(self) -> CombinedDatetimelikeProperties: ...
11711207
@property

tests/test_string_accessors.py

+12 -2
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,10 @@
66
import pandas as pd
77
from typing_extensions import assert_type
88

9-
from tests import check
9+
from tests import (
10+
TYPE_CHECKING_INVALID_USAGE,
11+
check,
12+
)
1013

1114
# Separately define here so pytest works
1215
np_ndarray_bool = npt.NDArray[np.bool_]
@@ -139,7 +142,7 @@ def test_string_accessors_string_series():
139142
_check(assert_type(s.str.zfill(10), "pd.Series[str]"))
140143
s_bytes = pd.Series([b"a1", b"b2", b"c3"])
141144
_check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"))
142-
s_list = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]])
145+
s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["foo", "eggplant"]])
143146
_check(assert_type(s_list.str.join("-"), "pd.Series[str]"))
144147

145148

@@ -415,3 +418,10 @@ def test_index_overloads_extract():
415418
pd.Index,
416419
object,
417420
)
421+
422+
423+
def test_series_unknown():
424+
if TYPE_CHECKING_INVALID_USAGE:
425+
s = pd.Series([1, 2, 3])
426+
s.str.startswith("a") # type:ignore[attr-defined]
427+
s.str.slice(2, 4) # type:ignore[attr-defined]

0 commit comments

Comments
 (0)