From 0d48f36ac1b2d3854d56a3733e30232933e5e0c1 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:18:35 +0000 Subject: [PATCH 01/39] make typing in pandas_stubs.core.strings.pyi strict, add UnknownSeries and UnknownIndex --- pandas-stubs/_typing.pyi | 3 +++ pandas-stubs/core/strings.pyi | 16 ++++++++++------ pyproject.toml | 6 ++++++ 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 478f60da0..273f1dd54 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -576,6 +576,9 @@ S2 = TypeVar( | list[str], ) +UnknownSeries: TypeAlias = Series[Any] +UnknownIndex: TypeAlias = Index[Any] + IndexingInt: TypeAlias = ( int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8 ) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index c12851705..8f1352e49 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -11,7 +11,7 @@ from typing import ( overload, ) -import numpy as np +import numpy.typing as npt import pandas as pd from pandas import ( DataFrame, @@ -21,9 +21,13 @@ from pandas import ( ) from pandas.core.base import NoNewAttributesMixin +from pandas._libs.tslibs.nattype import NaTType from pandas._typing import ( JoinHow, + Scalar, T, + UnknownIndex, + UnknownSeries, np_ndarray_bool, ) @@ -58,7 +62,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): @overload def cat( self, - others: Series | pd.Index | pd.DataFrame | np.ndarray | list[Any], + others: UnknownIndex | pd.DataFrame | npt.NDArray[Any] | list[Any], sep: str = ..., na_rep: str | None = ..., join: JoinHow = ..., @@ -106,10 +110,10 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def join(self, sep: str) -> T: ... 
def contains( self, - pat: str | re.Pattern, + pat: str | re.Pattern[str], case: bool = ..., flags: int = ..., - na=..., + na: Scalar | NaTType | None = ..., regex: bool = ..., ) -> Series[bool]: ... def match( @@ -118,7 +122,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def replace( self, pat: str, - repl: str | Callable[[re.Match], str], + repl: str | Callable[[re.Match[str]], str], n: int = ..., case: bool | None = ..., flags: int = ..., @@ -160,7 +164,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def count(self, pat: str, flags: int = ...) -> Series[int]: ... def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... - def findall(self, pat: str, flags: int = ...) -> Series: ... + def findall(self, pat: str, flags: int = ...) -> UnknownSeries: ... @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... diff --git a/pyproject.toml b/pyproject.toml index c71bebaf6..060669351 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,6 +201,12 @@ show_error_context = false show_column_numbers = false show_error_codes = true +[[tool.mypy.overrides]] +module = [ + "pandas-stubs.core.strings.*", +] +strict = true + [tool.pyright] typeCheckingMode = "strict" stubPath = "." 
From ca10bf295849e29830d79f938dc72d37683120cf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 14:57:47 +0000 Subject: [PATCH 02/39] undo pyproject.toml changes --- pandas-stubs/core/strings.pyi | 2 +- pyproject.toml | 6 ------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 8f1352e49..e6947c321 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -110,7 +110,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def join(self, sep: str) -> T: ... def contains( self, - pat: str | re.Pattern[str], + pat: str | re.Pattern, case: bool = ..., flags: int = ..., na: Scalar | NaTType | None = ..., diff --git a/pyproject.toml b/pyproject.toml index 060669351..c71bebaf6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -201,12 +201,6 @@ show_error_context = false show_column_numbers = false show_error_codes = true -[[tool.mypy.overrides]] -module = [ - "pandas-stubs.core.strings.*", -] -strict = true - [tool.pyright] typeCheckingMode = "strict" stubPath = "." From 4b8183db34bd7d2aa1244e9ed76e5520ca451914 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:09:21 +0000 Subject: [PATCH 03/39] use class, use pyright: strict --- pandas-stubs/_typing.pyi | 4 ++-- pandas-stubs/core/indexes/base.pyi | 2 +- pandas-stubs/core/series.pyi | 2 +- pandas-stubs/core/strings.pyi | 3 ++- tests/test_series.py | 3 ++- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 273f1dd54..2c3a88b6b 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -576,8 +576,8 @@ S2 = TypeVar( | list[str], ) -UnknownSeries: TypeAlias = Series[Any] -UnknownIndex: TypeAlias = Index[Any] +class UnknownSeries(Series[Any]): ... +class UnknownIndex(Index[Any]): ... 
IndexingInt: TypeAlias = ( int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8 diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index a28ec8c1b..5ebdcd224 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -67,7 +67,7 @@ class Index(IndexOpsMixin[S1]): __hash__: ClassVar[None] # type: ignore[assignment] # overloads with additional dtypes @overload - def __new__( # pyright: ignore[reportOverlappingOverload] + def __new__( cls, data: Sequence[int | np.integer] | IndexOpsMixin[int] | np_ndarray_anyint, *, diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 85044ed7f..3987bdd8e 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -249,7 +249,7 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: bool = ..., ) -> Series[float]: ... @overload - def __new__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __new__( # type: ignore[overload-overlap] cls, data: Sequence[Never], index: Axes | None = ..., diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index e6947c321..c0ea70a85 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -1,3 +1,4 @@ +# pyright: strict from collections.abc import ( Callable, Sequence, @@ -110,7 +111,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def join(self, sep: str) -> T: ... 
def contains( self, - pat: str | re.Pattern, + pat: str | re.Pattern[str], case: bool = ..., flags: int = ..., na: Scalar | NaTType | None = ..., diff --git a/tests/test_series.py b/tests/test_series.py index c76fc9856..8e2b34929 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -49,6 +49,7 @@ from pandas._typing import ( DtypeObj, Scalar, + UnknownSeries, ) from tests import ( @@ -1594,7 +1595,7 @@ def test_string_accessors(): check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.find("p"), pd.Series), pd.Series) - check(assert_type(s.str.findall("pp"), pd.Series), pd.Series) + check(assert_type(s.str.findall("pp"), UnknownSeries), pd.Series) check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s.str.get(2), pd.Series), pd.Series) check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) From def6eea9954e9a8aa3f7423e6a87c25a38da33b9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:17:06 +0000 Subject: [PATCH 04/39] update pyright --- pandas-stubs/core/indexes/base.pyi | 2 +- pandas-stubs/core/series.pyi | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 5ebdcd224..a28ec8c1b 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -67,7 +67,7 @@ class Index(IndexOpsMixin[S1]): __hash__: ClassVar[None] # type: ignore[assignment] # overloads with additional dtypes @overload - def __new__( + def __new__( # pyright: ignore[reportOverlappingOverload] cls, data: Sequence[int | np.integer] | IndexOpsMixin[int] | np_ndarray_anyint, *, diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 3987bdd8e..87444e098 100644 --- a/pandas-stubs/core/series.pyi 
+++ b/pandas-stubs/core/series.pyi @@ -249,7 +249,7 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: bool = ..., ) -> Series[float]: ... @overload - def __new__( # type: ignore[overload-overlap] + def __new__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] cls, data: Sequence[Never], index: Axes | None = ..., From 9c5b33a32cdea4db1d83d8ffe829ac375a08d2a5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:20:27 +0000 Subject: [PATCH 05/39] reduce diff --- pandas-stubs/core/series.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 87444e098..85044ed7f 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -249,7 +249,7 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: bool = ..., ) -> Series[float]: ... @overload - def __new__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __new__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] cls, data: Sequence[Never], index: Axes | None = ..., From 9b63e3f0a11efad267bff52b2a3a412717a0f80a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:24:50 +0000 Subject: [PATCH 06/39] fixup --- tests/test_series.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_series.py b/tests/test_series.py index 8e2b34929..a032cd00c 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -49,7 +49,6 @@ from pandas._typing import ( DtypeObj, Scalar, - UnknownSeries, ) from tests import ( @@ -67,6 +66,7 @@ TimedeltaSeries, TimestampSeries, ) + else: TimedeltaSeries: TypeAlias = pd.Series TimestampSeries: TypeAlias = pd.Series @@ -1595,7 +1595,7 @@ def test_string_accessors(): check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) 
check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.find("p"), pd.Series), pd.Series) - check(assert_type(s.str.findall("pp"), UnknownSeries), pd.Series) + check(assert_type(s.str.findall("pp"), "UnknownSeries"), pd.Series) check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s.str.get(2), pd.Series), pd.Series) check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) From fd6188aec2e823679bbdb9416c88ce9e3b9f988a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 15:27:39 +0000 Subject: [PATCH 07/39] fixup --- tests/test_series.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_series.py b/tests/test_series.py index a032cd00c..1afaa11b8 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -87,6 +87,7 @@ UIntDtypeArg, VoidDtypeArg, ) + from pandas._typing import UnknownSeries # noqa: F401 from pandas._typing import np_ndarray_int # noqa: F401 # Tests will use numpy 2.1 in python 3.10 or later From bcdd40eafb09648763b86e3ab902f0c90fa035cc Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 17:37:00 +0000 Subject: [PATCH 08/39] include UnknownSeries in str.cat --- pandas-stubs/core/strings.pyi | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index c0ea70a85..3a4131da2 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -63,7 +63,9 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): @overload def cat( self, - others: UnknownIndex | pd.DataFrame | npt.NDArray[Any] | list[Any], + others: ( + UnknownSeries | UnknownIndex | pd.DataFrame | npt.NDArray[Any] | list[Any] + ), sep: str = ..., na_rep: str | None = ..., join: JoinHow = ..., From dba1bdab9fede5827031256fde65374714c0eb4e 
Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Thu, 6 Mar 2025 17:58:19 +0000 Subject: [PATCH 09/39] move UnknownSeries and UnknownIndex location --- pandas-stubs/_typing.pyi | 3 --- pandas-stubs/core/indexes/base.pyi | 2 ++ pandas-stubs/core/series.pyi | 2 ++ pandas-stubs/core/strings.pyi | 4 ++-- tests/test_series.py | 3 ++- 5 files changed, 8 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 2c3a88b6b..478f60da0 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -576,9 +576,6 @@ S2 = TypeVar( | list[str], ) -class UnknownSeries(Series[Any]): ... -class UnknownIndex(Index[Any]): ... - IndexingInt: TypeAlias = ( int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8 ) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index a28ec8c1b..843987739 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -455,6 +455,8 @@ class Index(IndexOpsMixin[S1]): ), ) -> Self: ... +class UnknownIndex(Index[Any]): ... + def ensure_index_from_sequences( sequences: Sequence[Sequence[Dtype]], names: list[str] = ... ) -> Index: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 85044ed7f..10b866452 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -2133,6 +2133,8 @@ class Series(IndexOpsMixin[S1], NDFrame): ) -> Self: ... def __bool__(self) -> NoReturn: ... +class UnknownSeries(Series[Any]): ... + class TimestampSeries(Series[Timestamp]): @property def dt(self) -> TimestampProperties: ... 
# type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 3a4131da2..18f42831d 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -21,14 +21,14 @@ from pandas import ( Series, ) from pandas.core.base import NoNewAttributesMixin +from pandas.core.indexes.base import UnknownIndex +from pandas.core.series import UnknownSeries from pandas._libs.tslibs.nattype import NaTType from pandas._typing import ( JoinHow, Scalar, T, - UnknownIndex, - UnknownSeries, np_ndarray_bool, ) diff --git a/tests/test_series.py b/tests/test_series.py index 1afaa11b8..1b395ee92 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -73,6 +73,8 @@ OffsetSeries: TypeAlias = pd.Series if TYPE_CHECKING: + from pandas.core.series import UnknownSeries # noqa: F401 + from pandas._typing import ( BooleanDtypeArg, BytesDtypeArg, @@ -87,7 +89,6 @@ UIntDtypeArg, VoidDtypeArg, ) - from pandas._typing import UnknownSeries # noqa: F401 from pandas._typing import np_ndarray_int # noqa: F401 # Tests will use numpy 2.1 in python 3.10 or later From 6a31e8749640c8f0cbf6f0adc4af5802a12c4acf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:14:01 +0000 Subject: [PATCH 10/39] use typealias --- pandas-stubs/_typing.pyi | 3 +++ pandas-stubs/core/series.pyi | 2 -- pandas-stubs/core/strings.pyi | 7 +++---- tests/test_series.py | 21 ++++++++++++++++----- 4 files changed, 22 insertions(+), 11 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 478f60da0..273f1dd54 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -576,6 +576,9 @@ S2 = TypeVar( | list[str], ) +UnknownSeries: TypeAlias = Series[Any] +UnknownIndex: TypeAlias = Index[Any] + IndexingInt: TypeAlias = ( int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8 ) diff --git 
a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 10b866452..85044ed7f 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -2133,8 +2133,6 @@ class Series(IndexOpsMixin[S1], NDFrame): ) -> Self: ... def __bool__(self) -> NoReturn: ... -class UnknownSeries(Series[Any]): ... - class TimestampSeries(Series[Timestamp]): @property def dt(self) -> TimestampProperties: ... # type: ignore[override] # pyright: ignore[reportIncompatibleMethodOverride] diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 18f42831d..f6bc547c5 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -12,6 +12,7 @@ from typing import ( overload, ) +import numpy as np import numpy.typing as npt import pandas as pd from pandas import ( @@ -21,8 +22,6 @@ from pandas import ( Series, ) from pandas.core.base import NoNewAttributesMixin -from pandas.core.indexes.base import UnknownIndex -from pandas.core.series import UnknownSeries from pandas._libs.tslibs.nattype import NaTType from pandas._typing import ( @@ -64,7 +63,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def cat( self, others: ( - UnknownSeries | UnknownIndex | pd.DataFrame | npt.NDArray[Any] | list[Any] + Series[str] | Index[str] | pd.DataFrame | npt.NDArray[np.str_] | list[str] ), sep: str = ..., na_rep: str | None = ..., @@ -167,7 +166,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def count(self, pat: str, flags: int = ...) -> Series[int]: ... def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... - def findall(self, pat: str, flags: int = ...) -> UnknownSeries: ... + def findall(self, pat: str, flags: int = ...) -> Series[list[str]]: ... @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... 
diff --git a/tests/test_series.py b/tests/test_series.py index 1b395ee92..25574f9f0 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -67,14 +67,15 @@ TimestampSeries, ) + from pandas._typing import UnknownSeries + else: TimedeltaSeries: TypeAlias = pd.Series TimestampSeries: TypeAlias = pd.Series OffsetSeries: TypeAlias = pd.Series + UnknownSeries: TypeAlias = pd.Series if TYPE_CHECKING: - from pandas.core.series import UnknownSeries # noqa: F401 - from pandas._typing import ( BooleanDtypeArg, BytesDtypeArg, @@ -91,6 +92,7 @@ ) from pandas._typing import np_ndarray_int # noqa: F401 + # Tests will use numpy 2.1 in python 3.10 or later # From Numpy 2.1 __init__.pyi _DTypeKind: TypeAlias = Literal[ @@ -1597,7 +1599,7 @@ def test_string_accessors(): check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.find("p"), pd.Series), pd.Series) - check(assert_type(s.str.findall("pp"), "UnknownSeries"), pd.Series) + check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list) check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s.str.get(2), pd.Series), pd.Series) check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) @@ -1668,8 +1670,17 @@ def test_series_overloads_cat(): check(assert_type(s.str.cat(sep=";"), str), str) check(assert_type(s.str.cat(None, sep=";"), str), str) check( - assert_type(s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), pd.Series), - pd.Series, + assert_type( + s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), UnknownSeries + ), + UnknownSeries, + ) + check( + assert_type( + s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), + UnknownSeries, + ), + UnknownSeries, ) From 5edf9827faa2b0f0c6100aa9558e43df609cd8f3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli 
<33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:21:48 +0000 Subject: [PATCH 11/39] use Series[str] as .cat return type --- pandas-stubs/core/strings.pyi | 2 +- tests/test_series.py | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index f6bc547c5..288031f85 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -68,7 +68,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): sep: str = ..., na_rep: str | None = ..., join: JoinHow = ..., - ) -> T: ... + ) -> Series[str]: ... @overload def split( self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ... diff --git a/tests/test_series.py b/tests/test_series.py index 25574f9f0..28e1f5f9a 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -66,14 +66,10 @@ TimedeltaSeries, TimestampSeries, ) - - from pandas._typing import UnknownSeries - else: TimedeltaSeries: TypeAlias = pd.Series TimestampSeries: TypeAlias = pd.Series OffsetSeries: TypeAlias = pd.Series - UnknownSeries: TypeAlias = pd.Series if TYPE_CHECKING: from pandas._typing import ( @@ -1671,16 +1667,18 @@ def test_series_overloads_cat(): check(assert_type(s.str.cat(None, sep=";"), str), str) check( assert_type( - s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), UnknownSeries + s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), "pd.Series[str]" ), - UnknownSeries, + pd.Series, + str, ) check( assert_type( s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), - UnknownSeries, + "pd.Series[str]", ), - UnknownSeries, + pd.Series, + str, ) From 9a47508c5f38a1642b7c84a94ee1e98f23df8e8e Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:23:58 +0000 Subject: [PATCH 12/39] use -> T so it matches other .str methods like .str.uppercase --- pandas-stubs/core/strings.pyi | 2 +- 
tests/test_series.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 288031f85..f6bc547c5 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -68,7 +68,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): sep: str = ..., na_rep: str | None = ..., join: JoinHow = ..., - ) -> Series[str]: ... + ) -> T: ... @overload def split( self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ... diff --git a/tests/test_series.py b/tests/test_series.py index 28e1f5f9a..8d04c5d1f 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -66,10 +66,13 @@ TimedeltaSeries, TimestampSeries, ) + + from pandas._typing import UnknownSeries else: TimedeltaSeries: TypeAlias = pd.Series TimestampSeries: TypeAlias = pd.Series OffsetSeries: TypeAlias = pd.Series + UnknownSeries: TypeAlias = pd.Series if TYPE_CHECKING: from pandas._typing import ( @@ -1667,18 +1670,16 @@ def test_series_overloads_cat(): check(assert_type(s.str.cat(None, sep=";"), str), str) check( assert_type( - s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), "pd.Series[str]" + s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), UnknownSeries ), - pd.Series, - str, + UnknownSeries, ) check( assert_type( s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), - "pd.Series[str]", + UnknownSeries, ), - pd.Series, - str, + UnknownSeries, ) From 0fabb998784e6391b006ecd344466c1291bbe300 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 10:41:30 +0000 Subject: [PATCH 13/39] use _TS2 for findall --- pandas-stubs/core/strings.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index f6bc547c5..5149ecc9f 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ 
-166,7 +166,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def count(self, pat: str, flags: int = ...) -> Series[int]: ... def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... - def findall(self, pat: str, flags: int = ...) -> Series[list[str]]: ... + def findall(self, pat: str, flags: int = ...) -> _TS2: ... @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... From 427a70788f733b7c8c5cb3ae9b4d703cef39185c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 21:29:23 +0000 Subject: [PATCH 14/39] add test to cover passing UnknownSeries to cat --- pandas-stubs/_typing.pyi | 3 --- pandas-stubs/core/indexes/base.pyi | 3 ++- pandas-stubs/core/series.pyi | 2 ++ tests/test_series.py | 8 ++++++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 273f1dd54..478f60da0 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -576,9 +576,6 @@ S2 = TypeVar( | list[str], ) -UnknownSeries: TypeAlias = Series[Any] -UnknownIndex: TypeAlias = Index[Any] - IndexingInt: TypeAlias = ( int | np.int_ | np.integer | np.unsignedinteger | np.signedinteger | np.int8 ) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 843987739..22f2f2b0d 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -13,6 +13,7 @@ from typing import ( Any, ClassVar, Literal, + TypeAlias, final, overload, ) @@ -455,7 +456,7 @@ class Index(IndexOpsMixin[S1]): ), ) -> Self: ... -class UnknownIndex(Index[Any]): ... +UnknownIndex: TypeAlias = Index[Any] def ensure_index_from_sequences( sequences: Sequence[Sequence[Dtype]], names: list[str] = ... 
diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 85044ed7f..94c979427 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -2295,3 +2295,5 @@ class IntervalSeries(Series[Interval[_OrderableT]], Generic[_OrderableT]): @property def array(self) -> IntervalArray: ... def diff(self, periods: int = ...) -> Never: ... + +UnknownSeries: TypeAlias = Series[Any] diff --git a/tests/test_series.py b/tests/test_series.py index 8d04c5d1f..7129b8256 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -65,9 +65,8 @@ OffsetSeries, TimedeltaSeries, TimestampSeries, + UnknownSeries, ) - - from pandas._typing import UnknownSeries else: TimedeltaSeries: TypeAlias = pd.Series TimestampSeries: TypeAlias = pd.Series @@ -1681,6 +1680,11 @@ def test_series_overloads_cat(): ), UnknownSeries, ) + unknown_s: UnknownSeries = pd.DataFrame({"a": ["a", "b"]})["a"] + check( + assert_type(s.str.cat(unknown_s, sep=";"), UnknownSeries), + UnknownSeries, + ) def test_series_overloads_partition(): From de28385162723fa9e1f68e161adaad2bb2fe87bf Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 22:18:48 +0000 Subject: [PATCH 15/39] preserve type in series.str --- pandas-stubs/core/series.pyi | 2 +- pandas-stubs/core/strings.pyi | 40 +++++++++---- tests/test_series.py | 110 +++++++++++++++++++++------------- 3 files changed, 97 insertions(+), 55 deletions(-) diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 94c979427..7dc08514f 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1156,7 +1156,7 @@ class Series(IndexOpsMixin[S1], NDFrame): @property def str( self, - ) -> StringMethods[Series, DataFrame, Series[bool], Series[list[str]]]: ... + ) -> StringMethods[Series[S1], DataFrame, Series[bool], Series[list[str]]]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... 
@property diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 5149ecc9f..8ecea6d18 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -95,9 +95,11 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): @overload def partition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ... @overload - def partition(self, sep: str, expand: Literal[False]) -> T: ... + def partition( + self, sep: str, expand: Literal[False] + ) -> pd.Series[type[object]]: ... @overload - def partition(self, *, expand: Literal[False]) -> T: ... + def partition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ... @overload def rpartition(self, sep: str = ...) -> pd.DataFrame: ... @overload @@ -105,9 +107,11 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): @overload def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ... @overload - def rpartition(self, sep: str, expand: Literal[False]) -> T: ... + def rpartition( + self, sep: str, expand: Literal[False] + ) -> pd.Series[type[object]]: ... @overload - def rpartition(self, *, expand: Literal[False]) -> T: ... + def rpartition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ... def get(self, i: int) -> T: ... def join(self, sep: str) -> T: ... def contains( @@ -147,8 +151,8 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def slice_replace( self, start: int | None = ..., stop: int | None = ..., repl: str | None = ... ) -> T: ... - def decode(self, encoding: str, errors: str = ...) -> T: ... - def encode(self, encoding: str, errors: str = ...) -> T: ... + def decode(self, encoding: str, errors: str = ...) -> Series[str]: ... + def encode(self, encoding: str, errors: str = ...) -> Series[bytes]: ... def strip(self, to_strip: str | None = ...) -> T: ... def lstrip(self, to_strip: str | None = ...) -> T: ... def rstrip(self, to_strip: str | None = ...) -> T: ... 
@@ -172,15 +176,27 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): self, pat: str, flags: int = ..., *, expand: Literal[True] = ... ) -> pd.DataFrame: ... @overload - def extract(self, pat: str, flags: int, expand: Literal[False]) -> T: ... + def extract( + self, pat: str, flags: int, expand: Literal[False] + ) -> Series[type[object]]: ... @overload - def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> T: ... + def extract( + self, pat: str, flags: int = ..., *, expand: Literal[False] + ) -> Series[type[object]]: ... def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... - def find(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... - def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... + def find( + self, sub: str, start: int = ..., end: int | None = ... + ) -> Series[int]: ... + def rfind( + self, sub: str, start: int = ..., end: int | None = ... + ) -> Series[int]: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... - def index(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... - def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> T: ... + def index( + self, sub: str, start: int = ..., end: int | None = ... + ) -> Series[int]: ... + def rindex( + self, sub: str, start: int = ..., end: int | None = ... + ) -> Series[int]: ... def len(self) -> Series[int]: ... def lower(self) -> T: ... def upper(self) -> T: ... 
diff --git a/tests/test_series.py b/tests/test_series.py index 7129b8256..3a782ed82 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1577,10 +1577,11 @@ def test_string_accessors(): ) s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) s3 = pd.Series(["a1", "b2", "c3"]) - check(assert_type(s.str.capitalize(), pd.Series), pd.Series) - check(assert_type(s.str.casefold(), pd.Series), pd.Series) + s4 = pd.Series([b"a1", b"b2", b"c3"]) + check(assert_type(s.str.capitalize(), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.casefold(), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.cat(sep="X"), str), str) - check(assert_type(s.str.center(10), pd.Series), pd.Series) + check(assert_type(s.str.center(10), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_) check( assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"), @@ -1588,20 +1589,20 @@ def test_string_accessors(): np.bool_, ) check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s.str.decode("utf-8"), pd.Series), pd.Series) - check(assert_type(s.str.encode("latin-1"), pd.Series), pd.Series) + check(assert_type(s4.str.decode("utf-8"), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes) check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_) check( assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_ ) check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.find("p"), pd.Series), pd.Series) + check(assert_type(s.str.find("p"), "pd.Series[int]"), pd.Series, np.int64) check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list) check(assert_type(s.str.fullmatch("apple"), 
"pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.get(2), pd.Series), pd.Series) + check(assert_type(s.str.get(2), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.index("p"), pd.Series), pd.Series) + check(assert_type(s.str.index("p"), "pd.Series[int]"), pd.Series, np.int64) check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_) @@ -1613,20 +1614,20 @@ def test_string_accessors(): check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_) check(assert_type(s2.str.join("-"), pd.Series), pd.Series) check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s.str.ljust(80), pd.Series), pd.Series) - check(assert_type(s.str.lower(), pd.Series), pd.Series) - check(assert_type(s.str.lstrip("a"), pd.Series), pd.Series) + check(assert_type(s.str.ljust(80), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.lower(), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.lstrip("a"), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.normalize("NFD"), pd.Series), pd.Series) - check(assert_type(s.str.pad(80, "right"), pd.Series), pd.Series) + check(assert_type(s.str.normalize("NFD"), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.pad(80, "right"), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.removeprefix("a"), pd.Series), pd.Series) - check(assert_type(s.str.removesuffix("e"), pd.Series), pd.Series) - check(assert_type(s.str.repeat(2), pd.Series), pd.Series) - check(assert_type(s.str.replace("a", "X"), pd.Series), pd.Series) - 
check(assert_type(s.str.rfind("e"), pd.Series), pd.Series) - check(assert_type(s.str.rindex("p"), pd.Series), pd.Series) - check(assert_type(s.str.rjust(80), pd.Series), pd.Series) + check(assert_type(s.str.removeprefix("a"), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.repeat(2), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.rfind("e"), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.str.rindex("p"), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.str.rjust(80), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame) check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list) check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame) @@ -1635,9 +1636,11 @@ def test_string_accessors(): pd.Series, list, ) - check(assert_type(s.str.rstrip(), pd.Series), pd.Series) - check(assert_type(s.str.slice(0, 4, 2), pd.Series), pd.Series) - check(assert_type(s.str.slice_replace(0, 2, "XX"), pd.Series), pd.Series) + check(assert_type(s.str.rstrip(), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]"), pd.Series, str) + check( + assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"), pd.Series, str + ) check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list) # GH 194 check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) @@ -1652,13 +1655,19 @@ def test_string_accessors(): pd.Series, np.bool_, ) - check(assert_type(s.str.strip(), pd.Series), pd.Series) - check(assert_type(s.str.swapcase(), pd.Series), pd.Series) - check(assert_type(s.str.title(), pd.Series), pd.Series) - check(assert_type(s.str.translate(None), pd.Series), pd.Series) - check(assert_type(s.str.upper(), pd.Series), 
pd.Series) - check(assert_type(s.str.wrap(80), pd.Series), pd.Series) - check(assert_type(s.str.zfill(10), pd.Series), pd.Series) + check(assert_type(s.str.strip(), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.swapcase(), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.title(), "pd.Series[str]"), pd.Series, str) + check( + assert_type( + s.str.translate(str.maketrans({"ñ": "n", "ç": "c"})), "pd.Series[str]" + ), + pd.Series, + str, + ) + check(assert_type(s.str.upper(), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.wrap(80), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.str.zfill(10), "pd.Series[str]"), pd.Series, str) def test_series_overloads_cat(): @@ -1669,22 +1678,22 @@ def test_series_overloads_cat(): check(assert_type(s.str.cat(None, sep=";"), str), str) check( assert_type( - s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), UnknownSeries + s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), + "pd.Series[str]", ), - UnknownSeries, + pd.Series, + str, ) check( assert_type( s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), - UnknownSeries, + "pd.Series[str]", ), - UnknownSeries, + pd.Series, + str, ) unknown_s: UnknownSeries = pd.DataFrame({"a": ["a", "b"]})["a"] - check( - assert_type(s.str.cat(unknown_s, sep=";"), UnknownSeries), - UnknownSeries, - ) + check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str) def test_series_overloads_partition(): @@ -1703,13 +1712,21 @@ def test_series_overloads_partition(): check( assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame ) - check(assert_type(s.str.partition(sep=";", expand=False), pd.Series), pd.Series) + check( + assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"), + pd.Series, + object, + ) check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame) check( assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), 
pd.DataFrame ) - check(assert_type(s.str.rpartition(sep=";", expand=False), pd.Series), pd.Series) + check( + assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"), + pd.Series, + object, + ) def test_series_overloads_extract(): @@ -1720,10 +1737,19 @@ def test_series_overloads_extract(): check( assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame ) - check(assert_type(s.str.extract(r"[ab](\d)", expand=False), pd.Series), pd.Series) check( - assert_type(s.str.extract(r"[ab](\d)", re.IGNORECASE, False), pd.Series), + assert_type( + s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]" + ), pd.Series, + object, + ) + check( + assert_type( + s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]" + ), + pd.Series, + object, ) From e40d245a7e2c1cc10135f08e39a3df3f355001e2 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 22:24:30 +0000 Subject: [PATCH 16/39] simplify --- pandas-stubs/core/series.pyi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 7dc08514f..e6203e806 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1156,7 +1156,7 @@ class Series(IndexOpsMixin[S1], NDFrame): @property def str( self, - ) -> StringMethods[Series[S1], DataFrame, Series[bool], Series[list[str]]]: ... + ) -> StringMethods[Self, DataFrame, Series[bool], Series[list[str]]]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... 
@property From 92dc75ddaf452f30d682cf2103f3168755baa0c6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 22:30:13 +0000 Subject: [PATCH 17/39] use Mapping instead of dict as it is invariant --- tests/test_series.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_series.py b/tests/test_series.py index 3a782ed82..3abde8853 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -3,6 +3,7 @@ from collections.abc import ( Iterable, Iterator, + Mapping, Sequence, ) import datetime @@ -1658,10 +1659,9 @@ def test_string_accessors(): check(assert_type(s.str.strip(), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.swapcase(), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.title(), "pd.Series[str]"), pd.Series, str) + translation_table: Mapping = str.maketrans({"ñ": "n", "ç": "c"}) check( - assert_type( - s.str.translate(str.maketrans({"ñ": "n", "ç": "c"})), "pd.Series[str]" - ), + assert_type(s.str.translate(translation_table), "pd.Series[str]"), pd.Series, str, ) From 231b54d594b70875496ddc7d96251e91262580d0 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Fri, 7 Mar 2025 22:45:07 +0000 Subject: [PATCH 18/39] fixup --- tests/test_series.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/test_series.py b/tests/test_series.py index 3abde8853..008638afe 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -3,7 +3,6 @@ from collections.abc import ( Iterable, Iterator, - Mapping, Sequence, ) import datetime @@ -1659,9 +1658,8 @@ def test_string_accessors(): check(assert_type(s.str.strip(), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.swapcase(), "pd.Series[str]"), pd.Series, str) check(assert_type(s.str.title(), "pd.Series[str]"), pd.Series, str) - translation_table: Mapping = str.maketrans({"ñ": "n", "ç": "c"}) check( - 
assert_type(s.str.translate(translation_table), "pd.Series[str]"), + assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), pd.Series, str, ) From 45b8da09d09f2598a8742bc87bb501e82b7f4a41 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 10:39:07 +0000 Subject: [PATCH 19/39] split out into separate file --- pandas-stubs/core/indexes/base.pyi | 10 +- pandas-stubs/core/series.pyi | 10 +- pandas-stubs/core/strings.pyi | 60 ++++++------ test | 2 + tests/test_string_accessors.py | 142 +++++++++++++++++++++++++++++ 5 files changed, 191 insertions(+), 33 deletions(-) create mode 100644 test create mode 100644 tests/test_string_accessors.py diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 22f2f2b0d..9af84b9ce 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -264,7 +264,15 @@ class Index(IndexOpsMixin[S1]): @property def str( self, - ) -> StringMethods[Self, MultiIndex, np_ndarray_bool, Index[list[str]]]: ... + ) -> StringMethods[ + Self, + MultiIndex, + np_ndarray_bool, + Index[list[str]], + Index[int], + Index[bytes], + Index[str], + ]: ... def is_(self, other) -> bool: ... def __len__(self) -> int: ... def __array__(self, dtype=...) -> np.ndarray: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index e6203e806..dff3dab1f 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1156,7 +1156,15 @@ class Series(IndexOpsMixin[S1], NDFrame): @property def str( self, - ) -> StringMethods[Self, DataFrame, Series[bool], Series[list[str]]]: ... + ) -> StringMethods[ + Self, + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + ]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... 
@property diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 8ecea6d18..7b4eae71c 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -37,8 +37,14 @@ _TS = TypeVar("_TS", bound=DataFrame | MultiIndex) _TS2 = TypeVar("_TS2", bound=Series[list[str]] | Index[list[str]]) # The _TM type is what is used for the result of str.match _TM = TypeVar("_TM", bound=Series[bool] | np_ndarray_bool) +# The _TI type is what is used for the result of str.index / str.find +_TI = TypeVar("_TI", bound=Series[int] | Index[int]) +# The _TE type is what is used for the result of str.encode +_TE = TypeVar("_TE", bound=Series[bytes] | Index[bytes]) +# The _TD type is what is used for the result of str.decode +_TD = TypeVar("_TD", bound=Series[str] | Index[str]) -class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): +class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD]): def __init__(self, data: T) -> None: ... def __getitem__(self, key: slice | int) -> T: ... def __iter__(self) -> T: ... @@ -113,7 +119,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): @overload def rpartition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ... def get(self, i: int) -> T: ... - def join(self, sep: str) -> T: ... + def join(self, sep: str) -> _TD: ... def contains( self, pat: str | re.Pattern[str], @@ -121,7 +127,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): flags: int = ..., na: Scalar | NaTType | None = ..., regex: bool = ..., - ) -> Series[bool]: ... + ) -> _TM: ... def match( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... ) -> _TM: ... @@ -151,8 +157,8 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): def slice_replace( self, start: int | None = ..., stop: int | None = ..., repl: str | None = ... ) -> T: ... - def decode(self, encoding: str, errors: str = ...) -> Series[str]: ... 
- def encode(self, encoding: str, errors: str = ...) -> Series[bytes]: ... + def decode(self, encoding: str, errors: str = ...) -> _TD: ... + def encode(self, encoding: str, errors: str = ...) -> _TE: ... def strip(self, to_strip: str | None = ...) -> T: ... def lstrip(self, to_strip: str | None = ...) -> T: ... def rstrip(self, to_strip: str | None = ...) -> T: ... @@ -167,9 +173,9 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): ) -> T: ... def get_dummies(self, sep: str = ...) -> pd.DataFrame: ... def translate(self, table: dict[int, int | str | None] | None) -> T: ... - def count(self, pat: str, flags: int = ...) -> Series[int]: ... - def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... - def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> Series[bool]: ... + def count(self, pat: str, flags: int = ...) -> _TI: ... + def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ... + def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ... def findall(self, pat: str, flags: int = ...) -> _TS2: ... @overload def extract( @@ -184,37 +190,29 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2]): self, pat: str, flags: int = ..., *, expand: Literal[False] ) -> Series[type[object]]: ... def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... - def find( - self, sub: str, start: int = ..., end: int | None = ... - ) -> Series[int]: ... - def rfind( - self, sub: str, start: int = ..., end: int | None = ... - ) -> Series[int]: ... + def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... + def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... - def index( - self, sub: str, start: int = ..., end: int | None = ... - ) -> Series[int]: ... - def rindex( - self, sub: str, start: int = ..., end: int | None = ... 
- ) -> Series[int]: ... - def len(self) -> Series[int]: ... + def index(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... + def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... + def len(self) -> _TI: ... def lower(self) -> T: ... def upper(self) -> T: ... def title(self) -> T: ... def capitalize(self) -> T: ... def swapcase(self) -> T: ... def casefold(self) -> T: ... - def isalnum(self) -> Series[bool]: ... - def isalpha(self) -> Series[bool]: ... - def isdigit(self) -> Series[bool]: ... - def isspace(self) -> Series[bool]: ... - def islower(self) -> Series[bool]: ... - def isupper(self) -> Series[bool]: ... - def istitle(self) -> Series[bool]: ... - def isnumeric(self) -> Series[bool]: ... - def isdecimal(self) -> Series[bool]: ... + def isalnum(self) -> _TM: ... + def isalpha(self) -> _TM: ... + def isdigit(self) -> _TM: ... + def isspace(self) -> _TM: ... + def islower(self) -> _TM: ... + def isupper(self) -> _TM: ... + def istitle(self) -> _TM: ... + def isnumeric(self) -> _TM: ... + def isdecimal(self) -> _TM: ... def fullmatch( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... - ) -> Series[bool]: ... + ) -> _TM: ... def removeprefix(self, prefix: str) -> T: ... def removesuffix(self, suffix: str) -> T: ... 
diff --git a/test b/test new file mode 100644 index 000000000..e446f3ac8 --- /dev/null +++ b/test @@ -0,0 +1,2 @@ + test +ind abc \ No newline at end of file diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py new file mode 100644 index 000000000..6bcfff37b --- /dev/null +++ b/tests/test_string_accessors.py @@ -0,0 +1,142 @@ +import functools +import re +from typing import Any + +import numpy as np +import pandas as pd +import pytest +from typing_extensions import assert_type + +from tests import check + + +@pytest.mark.parametrize("constructor", ["series", "index"]) +@pytest.mark.parametrize( + ("method", "kwargs"), + [ + ("capitalize", {}), + ], +) +def test_string_accessors_type_preserving_series( + constructor: Any, method: str, kwargs: Any +) -> None: + data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + s = pd.Series(data) + _check = functools.partial(check, klass=pd.Series, dtype=str) + _check(assert_type(s.str.capitalize(), "pd.Series[str]")) + _check(assert_type(s.str.casefold(), "pd.Series[str]")) + check(assert_type(s.str.cat(sep="X"), str), str) + _check(assert_type(s.str.center(10), "pd.Series[str]")) + _check(assert_type(s.str.get(2), "pd.Series[str]")) + _check(assert_type(s.str.ljust(80), "pd.Series[str]")) + _check(assert_type(s.str.lower(), "pd.Series[str]")) + _check(assert_type(s.str.lstrip("a"), "pd.Series[str]")) + _check(assert_type(s.str.normalize("NFD"), "pd.Series[str]")) + _check(assert_type(s.str.pad(80, "right"), "pd.Series[str]")) + _check(assert_type(s.str.removeprefix("a"), "pd.Series[str]")) + _check(assert_type(s.str.removesuffix("e"), "pd.Series[str]")) + _check(assert_type(s.str.repeat(2), "pd.Series[str]")) + _check(assert_type(s.str.replace("a", "X"), "pd.Series[str]")) + _check(assert_type(s.str.rjust(80), "pd.Series[str]")) + _check(assert_type(s.str.rstrip(), "pd.Series[str]")) + _check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]")) + 
_check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]")) + _check(assert_type(s.str.strip(), "pd.Series[str]")) + _check(assert_type(s.str.swapcase(), "pd.Series[str]")) + _check(assert_type(s.str.title(), "pd.Series[str]")) + _check( + assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), + ) + _check(assert_type(s.str.upper(), "pd.Series[str]")) + _check(assert_type(s.str.wrap(80), "pd.Series[str]")) + _check(assert_type(s.str.zfill(10), "pd.Series[str]")) + + +def test_string_accessors_type_boolean(): + s = pd.Series( + ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + ) + check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_) + check( + assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), + pd.Series, + np.bool_, + ) + check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_) + check( + assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"), + pd.Series, + np.bool_, + ) + check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_) + check( + assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_ + ) + check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_) + 
check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_) + + +def test_string_accessors_type_integer(): + s = pd.Series( + ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + ) + check(assert_type(s.str.find("p"), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.str.index("p"), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.str.rfind("e"), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.str.rindex("p"), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer) + check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer) + + +def test_string_accessors_encode_decode(): + s_str = pd.Series(["a1", "b2", "c3"]) + s_bytes = pd.Series([b"a1", b"b2", b"c3"]) + s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + check( + assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"), + "pd.Series[str]", + str, + ) + check( + assert_type(s_str.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes + ) + check(assert_type(s2.str.join("-"), "pd.Series[str]"), pd.Series, str) + + +def test_string_accessors_list(): + s = pd.Series( + ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + ) + check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list) + check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list) + # GH 194 + check( + assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"), + pd.Series, + list, + ) + check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list) + check( + assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"), + pd.Series, + list, + ) + + +# def test_string_accessors_expanding(): +# check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) +# check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) +# 
check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) +# check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame) +# check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame) +# check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame) +# check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) From 385b1bd4f60fcd6476223f534708b2799e5a72f3 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 10:40:40 +0000 Subject: [PATCH 20/39] split out into separate file --- tests/test_string_accessors.py | 46 +++++++++++++++++++++++++--------- 1 file changed, 34 insertions(+), 12 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 6bcfff37b..e05956b6d 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -1,25 +1,14 @@ import functools import re -from typing import Any import numpy as np import pandas as pd -import pytest from typing_extensions import assert_type from tests import check -@pytest.mark.parametrize("constructor", ["series", "index"]) -@pytest.mark.parametrize( - ("method", "kwargs"), - [ - ("capitalize", {}), - ], -) -def test_string_accessors_type_preserving_series( - constructor: Any, method: str, kwargs: Any -) -> None: +def test_string_accessors_type_preserving_series() -> None: data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] s = pd.Series(data) _check = functools.partial(check, klass=pd.Series, dtype=str) @@ -52,6 +41,39 @@ def test_string_accessors_type_preserving_series( _check(assert_type(s.str.zfill(10), "pd.Series[str]")) +def test_string_accessors_type_preserving_index() -> None: + data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + idx = pd.Index(data) + _check = functools.partial(check, klass=pd.Index, dtype=str) + _check(assert_type(idx.str.capitalize(), "pd.Index[str]")) + 
_check(assert_type(idx.str.casefold(), "pd.Index[str]")) + check(assert_type(idx.str.cat(sep="X"), str), str) + _check(assert_type(idx.str.center(10), "pd.Index[str]")) + _check(assert_type(idx.str.get(2), "pd.Index[str]")) + _check(assert_type(idx.str.ljust(80), "pd.Index[str]")) + _check(assert_type(idx.str.lower(), "pd.Index[str]")) + _check(assert_type(idx.str.lstrip("a"), "pd.Index[str]")) + _check(assert_type(idx.str.normalize("NFD"), "pd.Index[str]")) + _check(assert_type(idx.str.pad(80, "right"), "pd.Index[str]")) + _check(assert_type(idx.str.removeprefix("a"), "pd.Index[str]")) + _check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]")) + _check(assert_type(idx.str.repeat(2), "pd.Index[str]")) + _check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]")) + _check(assert_type(idx.str.rjust(80), "pd.Index[str]")) + _check(assert_type(idx.str.rstrip(), "pd.Index[str]")) + _check(assert_type(idx.str.slice(0, 4, 2), "pd.Index[str]")) + _check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]")) + _check(assert_type(idx.str.strip(), "pd.Index[str]")) + _check(assert_type(idx.str.swapcase(), "pd.Index[str]")) + _check(assert_type(idx.str.title(), "pd.Index[str]")) + _check( + assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"), + ) + _check(assert_type(idx.str.upper(), "pd.Index[str]")) + _check(assert_type(idx.str.wrap(80), "pd.Index[str]")) + _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) + + def test_string_accessors_type_boolean(): s = pd.Series( ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] From 412b1ab90fff71cf314c908700695237c6465654 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 10:45:23 +0000 Subject: [PATCH 21/39] type check boolean return values --- tests/test_string_accessors.py | 79 ++++++++++++++++++++++------------ 1 file changed, 52 insertions(+), 27 deletions(-) diff --git a/tests/test_string_accessors.py 
b/tests/test_string_accessors.py index e05956b6d..54ce50525 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -74,37 +74,62 @@ def test_string_accessors_type_preserving_index() -> None: _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) -def test_string_accessors_type_boolean(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_) - check( +def test_string_accessors_type_boolean_series(): + data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + s = pd.Series(data) + _check = functools.partial(check, klass=pd.Series, dtype=bool) + _check(assert_type(s.str.startswith("a"), "pd.Series[bool]")) + _check( assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), - pd.Series, - np.bool_, ) - check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_) - check( + _check( + assert_type(s.str.contains("a"), "pd.Series[bool]"), + ) + _check( assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"), - pd.Series, - np.bool_, ) - check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_ + _check(assert_type(s.str.endswith("e"), "pd.Series[bool]")) + _check(assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]")) + _check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]")) + _check(assert_type(s.str.isalnum(), "pd.Series[bool]")) + _check(assert_type(s.str.isalpha(), "pd.Series[bool]")) + _check(assert_type(s.str.isdecimal(), "pd.Series[bool]")) + _check(assert_type(s.str.isdigit(), "pd.Series[bool]")) + _check(assert_type(s.str.isnumeric(), "pd.Series[bool]")) + _check(assert_type(s.str.islower(), "pd.Series[bool]")) + _check(assert_type(s.str.isspace(), "pd.Series[bool]")) + _check(assert_type(s.str.istitle(), "pd.Series[bool]")) + 
_check(assert_type(s.str.isupper(), "pd.Series[bool]")) + _check(assert_type(s.str.match("pp"), "pd.Series[bool]")) + + +def test_string_accessors_type_boolean_index(): + data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + idx = pd.Index(data) + _check = functools.partial(check, klass=np.ndarray, dtype=np.bool_) + _check(assert_type(idx.str.startswith("a"), "npt.NDArray[np.bool_]")) + _check( + assert_type(idx.str.startswith(("a", "b")), "npt.NDArray[np.bool_]"), ) - check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_) + _check( + assert_type(idx.str.contains("a"), "npt.NDArray[np.bool_]"), + ) + _check( + assert_type(idx.str.contains(re.compile(r"a")), "npt.NDArray[np.bool_]"), + ) + _check(assert_type(idx.str.endswith("e"), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.endswith(("e", "f")), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.fullmatch("apple"), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.isalnum(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.isalpha(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.isdecimal(), "npt.NDArray[np.bool_]")) + 
_check(assert_type(idx.str.isdigit(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.isnumeric(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.islower(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.isspace(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.istitle(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.isupper(), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.match("pp"), "npt.NDArray[np.bool_]")) def test_string_accessors_type_integer(): @@ -125,7 +150,7 @@ def test_string_accessors_encode_decode(): s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) check( assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"), - "pd.Series[str]", + pd.Series, str, ) check( From 2463ce977f7692bf4f2ce0b85bb53ddc8abb3b53 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:05:41 +0000 Subject: [PATCH 22/39] integer return type --- tests/test_string_accessors.py | 44 ++++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 54ce50525..0c267669e 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -7,10 +7,11 @@ from tests import check +DATA = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] + def test_string_accessors_type_preserving_series() -> None: - data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - s = pd.Series(data) + s = pd.Series(DATA) _check = functools.partial(check, klass=pd.Series, dtype=str) _check(assert_type(s.str.capitalize(), "pd.Series[str]")) _check(assert_type(s.str.casefold(), "pd.Series[str]")) @@ -42,8 +43,7 @@ def test_string_accessors_type_preserving_series() -> None: def test_string_accessors_type_preserving_index() -> None: - data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - idx 
= pd.Index(data) + idx = pd.Index(DATA) _check = functools.partial(check, klass=pd.Index, dtype=str) _check(assert_type(idx.str.capitalize(), "pd.Index[str]")) _check(assert_type(idx.str.casefold(), "pd.Index[str]")) @@ -75,8 +75,7 @@ def test_string_accessors_type_preserving_index() -> None: def test_string_accessors_type_boolean_series(): - data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - s = pd.Series(data) + s = pd.Series(DATA) _check = functools.partial(check, klass=pd.Series, dtype=bool) _check(assert_type(s.str.startswith("a"), "pd.Series[bool]")) _check( @@ -104,8 +103,7 @@ def test_string_accessors_type_boolean_series(): def test_string_accessors_type_boolean_index(): - data = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - idx = pd.Index(data) + idx = pd.Index(DATA) _check = functools.partial(check, klass=np.ndarray, dtype=np.bool_) _check(assert_type(idx.str.startswith("a"), "npt.NDArray[np.bool_]")) _check( @@ -132,16 +130,26 @@ def test_string_accessors_type_boolean_index(): _check(assert_type(idx.str.match("pp"), "npt.NDArray[np.bool_]")) -def test_string_accessors_type_integer(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - check(assert_type(s.str.find("p"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.index("p"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.rfind("e"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.rindex("p"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer) +def test_string_accessors_type_integer_series(): + s = pd.Series(DATA) + _check = functools.partial(check, klass=pd.Series, dtype=np.integer) + _check(assert_type(s.str.find("p"), "pd.Series[int]")) + _check(assert_type(s.str.index("p"), "pd.Series[int]")) + 
_check(assert_type(s.str.rfind("e"), "pd.Series[int]")) + _check(assert_type(s.str.rindex("p"), "pd.Series[int]")) + _check(assert_type(s.str.count("pp"), "pd.Series[int]")) + _check(assert_type(s.str.len(), "pd.Series[int]")) + + +def test_string_accessors_type_integer_index(): + idx = pd.Index(DATA) + _check = functools.partial(check, klass=pd.Index, dtype=np.integer) + _check(assert_type(idx.str.find("p"), "pd.Index[int]")) + _check(assert_type(idx.str.index("p"), "pd.Index[int]")) + _check(assert_type(idx.str.rfind("e"), "pd.Index[int]")) + _check(assert_type(idx.str.rindex("p"), "pd.Index[int]")) + _check(assert_type(idx.str.count("pp"), "pd.Index[int]")) + _check(assert_type(idx.str.len(), "pd.Index[int]")) def test_string_accessors_encode_decode(): From b0cade629da45b868d016cf06a72e8eed55196cd Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:08:06 +0000 Subject: [PATCH 23/39] integer return type --- tests/test_string_accessors.py | 39 +++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 0c267669e..b9ebfe985 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -2,11 +2,16 @@ import re import numpy as np +import numpy.typing as npt import pandas as pd from typing_extensions import assert_type from tests import check +# Separately define here so pytest works +np_ndarray_bool = npt.NDArray[np.bool_] + + DATA = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] @@ -105,29 +110,29 @@ def test_string_accessors_type_boolean_series(): def test_string_accessors_type_boolean_index(): idx = pd.Index(DATA) _check = functools.partial(check, klass=np.ndarray, dtype=np.bool_) - _check(assert_type(idx.str.startswith("a"), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.startswith("a"), np_ndarray_bool)) _check( - 
assert_type(idx.str.startswith(("a", "b")), "npt.NDArray[np.bool_]"), + assert_type(idx.str.startswith(("a", "b")), np_ndarray_bool), ) _check( - assert_type(idx.str.contains("a"), "npt.NDArray[np.bool_]"), + assert_type(idx.str.contains("a"), np_ndarray_bool), ) _check( - assert_type(idx.str.contains(re.compile(r"a")), "npt.NDArray[np.bool_]"), + assert_type(idx.str.contains(re.compile(r"a")), np_ndarray_bool), ) - _check(assert_type(idx.str.endswith("e"), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.endswith(("e", "f")), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.fullmatch("apple"), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isalnum(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isalpha(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isdecimal(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isdigit(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isnumeric(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.islower(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isspace(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.istitle(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.isupper(), "npt.NDArray[np.bool_]")) - _check(assert_type(idx.str.match("pp"), "npt.NDArray[np.bool_]")) + _check(assert_type(idx.str.endswith("e"), np_ndarray_bool)) + _check(assert_type(idx.str.endswith(("e", "f")), np_ndarray_bool)) + _check(assert_type(idx.str.fullmatch("apple"), np_ndarray_bool)) + _check(assert_type(idx.str.isalnum(), np_ndarray_bool)) + _check(assert_type(idx.str.isalpha(), np_ndarray_bool)) + _check(assert_type(idx.str.isdecimal(), np_ndarray_bool)) + _check(assert_type(idx.str.isdigit(), np_ndarray_bool)) + _check(assert_type(idx.str.isnumeric(), np_ndarray_bool)) + _check(assert_type(idx.str.islower(), np_ndarray_bool)) + _check(assert_type(idx.str.isspace(), np_ndarray_bool)) + _check(assert_type(idx.str.istitle(), np_ndarray_bool)) + 
_check(assert_type(idx.str.isupper(), np_ndarray_bool)) + _check(assert_type(idx.str.match("pp"), np_ndarray_bool)) def test_string_accessors_type_integer_series(): From 29710a4eb9e4cacbe599442002911c46f87f065b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:13:55 +0000 Subject: [PATCH 24/39] strings and bytes --- tests/test_string_accessors.py | 43 +++++++++++++++++++++------------- 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index b9ebfe985..1571b7f96 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -79,7 +79,7 @@ def test_string_accessors_type_preserving_index() -> None: _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) -def test_string_accessors_type_boolean_series(): +def test_string_accessors_boolean_series(): s = pd.Series(DATA) _check = functools.partial(check, klass=pd.Series, dtype=bool) _check(assert_type(s.str.startswith("a"), "pd.Series[bool]")) @@ -107,7 +107,7 @@ def test_string_accessors_type_boolean_series(): _check(assert_type(s.str.match("pp"), "pd.Series[bool]")) -def test_string_accessors_type_boolean_index(): +def test_string_accessors_boolean_index(): idx = pd.Index(DATA) _check = functools.partial(check, klass=np.ndarray, dtype=np.bool_) _check(assert_type(idx.str.startswith("a"), np_ndarray_bool)) @@ -135,7 +135,7 @@ def test_string_accessors_type_boolean_index(): _check(assert_type(idx.str.match("pp"), np_ndarray_bool)) -def test_string_accessors_type_integer_series(): +def test_string_accessors_integer_series(): s = pd.Series(DATA) _check = functools.partial(check, klass=pd.Series, dtype=np.integer) _check(assert_type(s.str.find("p"), "pd.Series[int]")) @@ -146,7 +146,7 @@ def test_string_accessors_type_integer_series(): _check(assert_type(s.str.len(), "pd.Series[int]")) -def test_string_accessors_type_integer_index(): +def 
test_string_accessors_integer_index(): idx = pd.Index(DATA) _check = functools.partial(check, klass=pd.Index, dtype=np.integer) _check(assert_type(idx.str.find("p"), "pd.Index[int]")) @@ -157,19 +157,30 @@ def test_string_accessors_type_integer_index(): _check(assert_type(idx.str.len(), "pd.Index[int]")) -def test_string_accessors_encode_decode(): - s_str = pd.Series(["a1", "b2", "c3"]) - s_bytes = pd.Series([b"a1", b"b2", b"c3"]) +def test_string_accessors_string_series(): + s = pd.Series([b"a1", b"b2", b"c3"]) + _check = functools.partial(check, klass=pd.Series, dtype=str) + _check(assert_type(s.str.decode("utf-8"), "pd.Series[str]")) s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) - check( - assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]"), - pd.Series, - str, - ) - check( - assert_type(s_str.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes - ) - check(assert_type(s2.str.join("-"), "pd.Series[str]"), pd.Series, str) + _check(assert_type(s2.str.join("-"), "pd.Series[str]")) + + +def test_string_accessors_string_index(): + idx = pd.Index([b"a1", b"b2", b"c3"]) + _check = functools.partial(check, klass=pd.Index, dtype=str) + _check(assert_type(idx.str.decode("utf-8"), "pd.Index[str]")) + idx2 = pd.Index([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + _check(assert_type(idx2.str.join("-"), "pd.Index[str]")) + + +def test_string_accessors_bytes_series(): + s = pd.Series(["a1", "b2", "c3"]) + check(assert_type(s.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes) + + +def test_string_accessors_bytes_index(): + s = pd.Index(["a1", "b2", "c3"]) + check(assert_type(s.str.encode("latin-1"), "pd.Index[bytes]"), pd.Index, bytes) def test_string_accessors_list(): From 32988683ed7375de5e9c59234c9e151a58a00875 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:15:30 +0000 Subject: [PATCH 25/39] list --- tests/test_string_accessors.py | 36 
++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 1571b7f96..e3e3efef5 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -183,24 +183,26 @@ def test_string_accessors_bytes_index(): check(assert_type(s.str.encode("latin-1"), "pd.Index[bytes]"), pd.Index, bytes) -def test_string_accessors_list(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list) - check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list) +def test_string_accessors_list_series(): + s = pd.Series(DATA) + _check = functools.partial(check, klass=pd.Series, dtype=list) + _check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]")) + _check(assert_type(s.str.split("a"), "pd.Series[list[str]]")) # GH 194 - check( - assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"), - pd.Series, - list, - ) - check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list) - check( - assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"), - pd.Series, - list, - ) + _check(assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]")) + _check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]")) + _check(assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]")) + + +def test_string_accessors_list_index(): + idx = pd.Index(DATA) + _check = functools.partial(check, klass=pd.Index, dtype=list) + _check(assert_type(idx.str.findall("pp"), "pd.Index[list[str]]")) + _check(assert_type(idx.str.split("a"), "pd.Index[list[str]]")) + # GH 194 + _check(assert_type(idx.str.split("a", expand=False), "pd.Index[list[str]]")) + _check(assert_type(idx.str.rsplit("a"), "pd.Index[list[str]]")) + _check(assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]")) # def 
test_string_accessors_expanding(): From 5dfa7fa6f280765d53689caf425e511a7858efd6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:18:05 +0000 Subject: [PATCH 26/39] expanding --- tests/test_string_accessors.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index e3e3efef5..1374a6797 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -205,11 +205,25 @@ def test_string_accessors_list_index(): _check(assert_type(idx.str.rsplit("a", expand=False), "pd.Index[list[str]]")) -# def test_string_accessors_expanding(): -# check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) -# check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) -# check(assert_type(s.str.get_dummies(), pd.DataFrame), pd.DataFrame) -# check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame) -# check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame) -# check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame) -# check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) +def test_string_accessors_expanding_series(): + s = pd.Series(["a1", "b2", "c3"]) + _check = functools.partial(check, klass=pd.DataFrame) + _check(assert_type(s.str.extract(r"([ab])?(\d)"), pd.DataFrame)) + _check(assert_type(s.str.extractall(r"([ab])?(\d)"), pd.DataFrame)) + _check(assert_type(s.str.get_dummies(), pd.DataFrame)) + _check(assert_type(s.str.partition("p"), pd.DataFrame)) + _check(assert_type(s.str.rpartition("p"), pd.DataFrame)) + _check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame)) + _check(assert_type(s.str.split("a", expand=True), pd.DataFrame)) + + +def test_string_accessors_expanding_index(): + idx = pd.Index(["a1", "b2", "c3"]) + _check = functools.partial(check, 
klass=pd.MultiIndex) + _check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.MultiIndex)) + _check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.MultiIndex)) + _check(assert_type(idx.str.get_dummies(), pd.MultiIndex)) + _check(assert_type(idx.str.partition("p"), pd.MultiIndex)) + _check(assert_type(idx.str.rpartition("p"), pd.MultiIndex)) + _check(assert_type(idx.str.rsplit("a", expand=True), pd.MultiIndex)) + _check(assert_type(idx.str.split("a", expand=True), pd.MultiIndex)) From 3d581a8e1d8daf885c9abbcd8982650c6b78289c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:20:19 +0000 Subject: [PATCH 27/39] fixup --- pandas-stubs/core/strings.pyi | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 7b4eae71c..300428c9f 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -95,11 +95,11 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ... ) -> _TS2: ... @overload - def partition(self, sep: str = ...) -> pd.DataFrame: ... + def partition(self, sep: str = ...) -> _TS: ... @overload - def partition(self, *, expand: Literal[True]) -> pd.DataFrame: ... + def partition(self, *, expand: Literal[True]) -> _TS: ... @overload - def partition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ... + def partition(self, sep: str, expand: Literal[True]) -> _TS: ... @overload def partition( self, sep: str, expand: Literal[False] @@ -107,7 +107,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ @overload def partition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ... @overload - def rpartition(self, sep: str = ...) -> pd.DataFrame: ... + def rpartition(self, sep: str = ...) -> _TS: ... 
@overload def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ... @overload @@ -171,7 +171,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ break_long_words: bool | None = ..., break_on_hyphens: bool | None = ..., ) -> T: ... - def get_dummies(self, sep: str = ...) -> pd.DataFrame: ... + def get_dummies(self, sep: str = ...) -> _TS: ... def translate(self, table: dict[int, int | str | None] | None) -> T: ... def count(self, pat: str, flags: int = ...) -> _TI: ... def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ... @@ -180,7 +180,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... - ) -> pd.DataFrame: ... + ) -> _TS: ... @overload def extract( self, pat: str, flags: int, expand: Literal[False] @@ -189,7 +189,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ def extract( self, pat: str, flags: int = ..., *, expand: Literal[False] ) -> Series[type[object]]: ... - def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... + def extractall(self, pat: str, flags: int = ...) -> _TS: ... def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... 
From 005759c08f31bd1aa69add0311ae4ddf5288153c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:27:46 +0000 Subject: [PATCH 28/39] keep fixing --- pandas-stubs/core/indexes/base.pyi | 1 + pandas-stubs/core/series.pyi | 1 + pandas-stubs/core/strings.pyi | 26 +++++----- tests/test_string_accessors.py | 78 ++++++++++++++++++++++++++++-- 4 files changed, 90 insertions(+), 16 deletions(-) diff --git a/pandas-stubs/core/indexes/base.pyi b/pandas-stubs/core/indexes/base.pyi index 9af84b9ce..2708e2247 100644 --- a/pandas-stubs/core/indexes/base.pyi +++ b/pandas-stubs/core/indexes/base.pyi @@ -272,6 +272,7 @@ class Index(IndexOpsMixin[S1]): Index[int], Index[bytes], Index[str], + Index[type[object]], ]: ... def is_(self, other) -> bool: ... def __len__(self) -> int: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index dff3dab1f..4863a2ee7 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -1164,6 +1164,7 @@ class Series(IndexOpsMixin[S1], NDFrame): Series[int], Series[bytes], Series[str], + Series[type[object]], ]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 300428c9f..354989c4f 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -43,8 +43,12 @@ _TI = TypeVar("_TI", bound=Series[int] | Index[int]) _TE = TypeVar("_TE", bound=Series[bytes] | Index[bytes]) # The _TD type is what is used for the result of str.encode _TD = TypeVar("_TD", bound=Series[str] | Index[str]) +# The _TO type is what is used for the result of str.encode +_TO = TypeVar("_TO", bound=Series[type[object]] | Index[type[object]]) -class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD]): +class StringMethods( + NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD, _TO] +): def __init__(self, data: T) -> None: ... 
def __getitem__(self, key: slice | int) -> T: ... def __iter__(self) -> T: ... @@ -101,23 +105,19 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ @overload def partition(self, sep: str, expand: Literal[True]) -> _TS: ... @overload - def partition( - self, sep: str, expand: Literal[False] - ) -> pd.Series[type[object]]: ... + def partition(self, sep: str, expand: Literal[False]) -> _TO: ... @overload - def partition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ... + def partition(self, *, expand: Literal[False]) -> _TO: ... @overload def rpartition(self, sep: str = ...) -> _TS: ... @overload - def rpartition(self, *, expand: Literal[True]) -> pd.DataFrame: ... + def rpartition(self, *, expand: Literal[True]) -> _TS: ... @overload - def rpartition(self, sep: str, expand: Literal[True]) -> pd.DataFrame: ... + def rpartition(self, sep: str, expand: Literal[True]) -> _TS: ... @overload - def rpartition( - self, sep: str, expand: Literal[False] - ) -> pd.Series[type[object]]: ... + def rpartition(self, sep: str, expand: Literal[False]) -> _TO: ... @overload - def rpartition(self, *, expand: Literal[False]) -> pd.Series[type[object]]: ... + def rpartition(self, *, expand: Literal[False]) -> _TO: ... def get(self, i: int) -> T: ... def join(self, sep: str) -> _TD: ... def contains( @@ -180,7 +180,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... - ) -> _TS: ... + ) -> pd.DataFrame: ... @overload def extract( self, pat: str, flags: int, expand: Literal[False] @@ -189,7 +189,7 @@ class StringMethods(NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _ def extract( self, pat: str, flags: int = ..., *, expand: Literal[False] ) -> Series[type[object]]: ... - def extractall(self, pat: str, flags: int = ...) -> _TS: ... + def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... 
def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 1374a6797..6ec349cbf 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -81,7 +81,7 @@ def test_string_accessors_type_preserving_index() -> None: def test_string_accessors_boolean_series(): s = pd.Series(DATA) - _check = functools.partial(check, klass=pd.Series, dtype=bool) + _check = functools.partial(check, klass=pd.Series, dtype=np.bool_) _check(assert_type(s.str.startswith("a"), "pd.Series[bool]")) _check( assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), @@ -220,10 +220,82 @@ def test_string_accessors_expanding_series(): def test_string_accessors_expanding_index(): idx = pd.Index(["a1", "b2", "c3"]) _check = functools.partial(check, klass=pd.MultiIndex) - _check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.MultiIndex)) - _check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.MultiIndex)) _check(assert_type(idx.str.get_dummies(), pd.MultiIndex)) _check(assert_type(idx.str.partition("p"), pd.MultiIndex)) _check(assert_type(idx.str.rpartition("p"), pd.MultiIndex)) _check(assert_type(idx.str.rsplit("a", expand=True), pd.MultiIndex)) _check(assert_type(idx.str.split("a", expand=True), pd.MultiIndex)) + + # These ones are the odd ones out? 
+ check(assert_type(idx.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) + check(assert_type(idx.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) + + +def test_series_overloads_partition(): + s = pd.Series( + [ + "ap;pl;ep", + "ban;an;ap", + "Che;rr;yp", + "DA;TEp", + "eGGp;LANT;p", + "12;3p", + "23.45p", + ] + ) + check(assert_type(s.str.partition(sep=";"), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"), + pd.Series, + object, + ) + + check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame + ) + check( + assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"), + pd.Series, + object, + ) + + +def test_index_overloads_partition(): + idx = pd.Index( + [ + "ap;pl;ep", + "ban;an;ap", + "Che;rr;yp", + "DA;TEp", + "eGGp;LANT;p", + "12;3p", + "23.45p", + ] + ) + check(assert_type(idx.str.partition(sep=";"), pd.MultiIndex), pd.MultiIndex) + check( + assert_type(idx.str.partition(sep=";", expand=True), pd.MultiIndex), + pd.MultiIndex, + ) + check( + assert_type(idx.str.partition(sep=";", expand=False), "pd.Index[type[object]]"), + pd.Index, + object, + ) + + check(assert_type(idx.str.rpartition(sep=";"), pd.MultiIndex), pd.MultiIndex) + check( + assert_type(idx.str.rpartition(sep=";", expand=True), pd.MultiIndex), + pd.MultiIndex, + ) + check( + assert_type( + idx.str.rpartition(sep=";", expand=False), "pd.Index[type[object]]" + ), + pd.Index, + object, + ) From aca32d534fab1fbe79c6d44794478dd646cd4f8b Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:30:40 +0000 Subject: [PATCH 29/39] keep fixing --- tests/test_series.py | 156 --------------------------------- tests/test_string_accessors.py 
| 50 +++++++++++ 2 files changed, 50 insertions(+), 156 deletions(-) diff --git a/tests/test_series.py b/tests/test_series.py index 008638afe..1fd5d81a3 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1571,162 +1571,6 @@ def test_categorical_codes(): assert_type(cat.codes, "np_ndarray_int") -def test_string_accessors(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) - s3 = pd.Series(["a1", "b2", "c3"]) - s4 = pd.Series([b"a1", b"b2", b"c3"]) - check(assert_type(s.str.capitalize(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.casefold(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.cat(sep="X"), str), str) - check(assert_type(s.str.center(10), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.contains("a"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.contains(re.compile(r"a")), "pd.Series[bool]"), - pd.Series, - np.bool_, - ) - check(assert_type(s.str.count("pp"), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s4.str.decode("utf-8"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.encode("latin-1"), "pd.Series[bytes]"), pd.Series, bytes) - check(assert_type(s.str.endswith("e"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.endswith(("e", "f")), "pd.Series[bool]"), pd.Series, np.bool_ - ) - check(assert_type(s3.str.extract(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) - check(assert_type(s3.str.extractall(r"([ab])?(\d)"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.find("p"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.findall("pp"), "pd.Series[list[str]]"), pd.Series, list) - check(assert_type(s.str.fullmatch("apple"), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.get(2), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.get_dummies(), 
pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.index("p"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.isalnum(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isalpha(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isdecimal(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isdigit(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isnumeric(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.islower(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isspace(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.istitle(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.isupper(), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s2.str.join("-"), pd.Series), pd.Series) - check(assert_type(s.str.len(), "pd.Series[int]"), pd.Series, np.integer) - check(assert_type(s.str.ljust(80), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.lower(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.lstrip("a"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.match("pp"), "pd.Series[bool]"), pd.Series, np.bool_) - check(assert_type(s.str.normalize("NFD"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.pad(80, "right"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.partition("p"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.removeprefix("a"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.removesuffix("e"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.repeat(2), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.replace("a", "X"), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.rfind("e"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.rindex("p"), "pd.Series[int]"), pd.Series, np.int64) - check(assert_type(s.str.rjust(80), "pd.Series[str]"), 
pd.Series, str) - check(assert_type(s.str.rpartition("p"), pd.DataFrame), pd.DataFrame) - check(assert_type(s.str.rsplit("a"), "pd.Series[list[str]]"), pd.Series, list) - check(assert_type(s.str.rsplit("a", expand=True), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.rsplit("a", expand=False), "pd.Series[list[str]]"), - pd.Series, - list, - ) - check(assert_type(s.str.rstrip(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]"), pd.Series, str) - check( - assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]"), pd.Series, str - ) - check(assert_type(s.str.split("a"), "pd.Series[list[str]]"), pd.Series, list) - # GH 194 - check(assert_type(s.str.split("a", expand=True), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.split("a", expand=False), "pd.Series[list[str]]"), - pd.Series, - list, - ) - check(assert_type(s.str.startswith("a"), "pd.Series[bool]"), pd.Series, np.bool_) - check( - assert_type(s.str.startswith(("a", "b")), "pd.Series[bool]"), - pd.Series, - np.bool_, - ) - check(assert_type(s.str.strip(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.swapcase(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.title(), "pd.Series[str]"), pd.Series, str) - check( - assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), - pd.Series, - str, - ) - check(assert_type(s.str.upper(), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.wrap(80), "pd.Series[str]"), pd.Series, str) - check(assert_type(s.str.zfill(10), "pd.Series[str]"), pd.Series, str) - - -def test_series_overloads_cat(): - s = pd.Series( - ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] - ) - check(assert_type(s.str.cat(sep=";"), str), str) - check(assert_type(s.str.cat(None, sep=";"), str), str) - check( - assert_type( - s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), - "pd.Series[str]", - ), - pd.Series, - str, - ) - check( - assert_type( - 
s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), - "pd.Series[str]", - ), - pd.Series, - str, - ) - unknown_s: UnknownSeries = pd.DataFrame({"a": ["a", "b"]})["a"] - check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str) - - -def test_series_overloads_partition(): - s = pd.Series( - [ - "ap;pl;ep", - "ban;an;ap", - "Che;rr;yp", - "DA;TEp", - "eGGp;LANT;p", - "12;3p", - "23.45p", - ] - ) - check(assert_type(s.str.partition(sep=";"), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.partition(sep=";", expand=True), pd.DataFrame), pd.DataFrame - ) - check( - assert_type(s.str.partition(sep=";", expand=False), "pd.Series[type[object]]"), - pd.Series, - object, - ) - - check(assert_type(s.str.rpartition(sep=";"), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.rpartition(sep=";", expand=True), pd.DataFrame), pd.DataFrame - ) - check( - assert_type(s.str.rpartition(sep=";", expand=False), "pd.Series[type[object]]"), - pd.Series, - object, - ) - - def test_series_overloads_extract(): s = pd.Series( ["appl;ep", "ban;anap", "Cherr;yp", "DATEp", "eGGp;LANTp", "12;3p", "23.45p"] diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 6ec349cbf..9848b77df 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -299,3 +299,53 @@ def test_index_overloads_partition(): pd.Index, object, ) + + +def test_series_overloads_cat(): + s = pd.Series(DATA) + check(assert_type(s.str.cat(sep=";"), str), str) + check(assert_type(s.str.cat(None, sep=";"), str), str) + check( + assert_type( + s.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), + "pd.Series[str]", + ), + pd.Series, + str, + ) + check( + assert_type( + s.str.cat(pd.Series(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), + "pd.Series[str]", + ), + pd.Series, + str, + ) + unknown_s = pd.DataFrame({"a": ["a", "b"]})["a"] + check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str) + + 
+def test_index_overloads_cat(): + idx = pd.Index(DATA) + check(assert_type(idx.str.cat(sep=";"), str), str) + check(assert_type(idx.str.cat(None, sep=";"), str), str) + check( + assert_type( + idx.str.cat(["A", "B", "C", "D", "E", "F", "G"], sep=";"), + "pd.Index[str]", + ), + pd.Index, + str, + ) + check( + assert_type( + idx.str.cat(pd.Index(["A", "B", "C", "D", "E", "F", "G"]), sep=";"), + "pd.Index[str]", + ), + pd.Index, + str, + ) + unknown_idx = pd.DataFrame({"a": ["a", "b"]}).set_index("a").index + check( + assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str + ) From b24430897ad6f240dad997c22aeaa5429ce8abfb Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:34:27 +0000 Subject: [PATCH 30/39] overloads cat --- tests/test_string_accessors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 9848b77df..305f9052b 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -321,7 +321,7 @@ def test_series_overloads_cat(): pd.Series, str, ) - unknown_s = pd.DataFrame({"a": ["a", "b"]})["a"] + unknown_s = pd.DataFrame({"a": list("abcdefg")})["a"] check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str) @@ -345,7 +345,7 @@ def test_index_overloads_cat(): pd.Index, str, ) - unknown_idx = pd.DataFrame({"a": ["a", "b"]}).set_index("a").index + unknown_idx = pd.DataFrame({"a": list("abcdefg")}).set_index("a").index check( assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str ) From 0d1fc59fd21be73a5bf281ef28325c0510619913 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:36:09 +0000 Subject: [PATCH 31/39] fixup str.extract --- pandas-stubs/core/strings.pyi | 8 ++---- tests/test_series.py | 24 ------------------ tests/test_string_accessors.py | 
45 ++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 30 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 354989c4f..4a98844e9 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -182,13 +182,9 @@ class StringMethods( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... ) -> pd.DataFrame: ... @overload - def extract( - self, pat: str, flags: int, expand: Literal[False] - ) -> Series[type[object]]: ... + def extract(self, pat: str, flags: int, expand: Literal[False]) -> _TO: ... @overload - def extract( - self, pat: str, flags: int = ..., *, expand: Literal[False] - ) -> Series[type[object]]: ... + def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> _TO: ... def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... 
diff --git a/tests/test_series.py b/tests/test_series.py index 1fd5d81a3..217ecd692 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1571,30 +1571,6 @@ def test_categorical_codes(): assert_type(cat.codes, "np_ndarray_int") -def test_series_overloads_extract(): - s = pd.Series( - ["appl;ep", "ban;anap", "Cherr;yp", "DATEp", "eGGp;LANTp", "12;3p", "23.45p"] - ) - check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame) - check( - assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame - ) - check( - assert_type( - s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]" - ), - pd.Series, - object, - ) - check( - assert_type( - s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]" - ), - pd.Series, - object, - ) - - def test_relops() -> None: # GH 175 s: str = "abc" diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 305f9052b..57eb38933 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -349,3 +349,48 @@ def test_index_overloads_cat(): check( assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str ) + + +def test_series_overloads_extract(): + s = pd.Series(DATA) + check(assert_type(s.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame) + check( + assert_type(s.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), pd.DataFrame + ) + check( + assert_type( + s.str.extract(r"[ab](\d)", expand=False), "pd.Series[type[object]]" + ), + pd.Series, + object, + ) + check( + assert_type( + s.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Series[type[object]]" + ), + pd.Series, + object, + ) + + +def test_index_overloads_extract(): + idx = pd.Index(DATA) + check(assert_type(idx.str.extract(r"[ab](\d)"), pd.DataFrame), pd.DataFrame) + check( + assert_type(idx.str.extract(r"[ab](\d)", expand=True), pd.DataFrame), + pd.DataFrame, + ) + check( + assert_type( + idx.str.extract(r"[ab](\d)", 
expand=False), "pd.Index[type[object]]" + ), + pd.Index, + object, + ) + check( + assert_type( + idx.str.extract(r"[ab](\d)", re.IGNORECASE, False), "pd.Index[type[object]]" + ), + pd.Index, + object, + ) From 7ccfa0d9378c76c64fdbb2ef10285e062c554220 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:49:48 +0000 Subject: [PATCH 32/39] rename for clarity --- pandas-stubs/core/strings.pyi | 117 ++++++++++++++++++---------------- 1 file changed, 61 insertions(+), 56 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index 4a98844e9..fddb15208 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -31,23 +31,24 @@ from pandas._typing import ( np_ndarray_bool, ) -# The _TS type is what is used for the result of str.split with expand=True -_TS = TypeVar("_TS", bound=DataFrame | MultiIndex) -# The _TS2 type is what is used for the result of str.split with expand=False -_TS2 = TypeVar("_TS2", bound=Series[list[str]] | Index[list[str]]) -# The _TM type is what is used for the result of str.match -_TM = TypeVar("_TM", bound=Series[bool] | np_ndarray_bool) -# The _TI type is what is used for the result of str.index / str.find -_TI = TypeVar("_TI", bound=Series[int] | Index[int]) -# The _TE type is what is used for the result of str.encode -_TE = TypeVar("_TE", bound=Series[bytes] | Index[bytes]) -# The _TD type is what is used for the result of str.encode -_TD = TypeVar("_TD", bound=Series[str] | Index[str]) -# The _TO type is what is used for the result of str.encode -_TO = TypeVar("_TO", bound=Series[type[object]] | Index[type[object]]) +# Used for the result of str.split with expand=True +_T_EXPANDING = TypeVar("_T_EXPANDING", bound=DataFrame | MultiIndex) +# Used for the result of str.split with expand=False +_T_LIST_STR = TypeVar("_T_LIST_STR", bound=Series[list[str]] | Index[list[str]]) +# Used for the result of str.match +_T_BOOL = 
TypeVar("_T_BOOL", bound=Series[bool] | np_ndarray_bool) +# Used for the result of str.index / str.find +_T_INT = TypeVar("_T_INT", bound=Series[int] | Index[int]) +# Used for the result of str.encode +_T_BYTES = TypeVar("_T_BYTES", bound=Series[bytes] | Index[bytes]) +# Used for the result of str.decode +_T_STR = TypeVar("_T_STR", bound=Series[str] | Index[str]) +# Used for the result of str.partition +_T_OBJECT = TypeVar("_T_OBJECT", bound=Series[type[object]] | Index[type[object]]) class StringMethods( - NoNewAttributesMixin, Generic[T, _TS, _TM, _TS2, _TI, _TE, _TD, _TO] + NoNewAttributesMixin, + Generic[T, _T_EXPANDING, _T_BOOL, _T_LIST_STR, _T_INT, _T_BYTES, _T_STR, _T_OBJECT], ): def __init__(self, data: T) -> None: ... def __getitem__(self, key: slice | int) -> T: ... @@ -82,7 +83,7 @@ class StringMethods( @overload def split( self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ... - ) -> _TS: ... + ) -> _T_EXPANDING: ... @overload def split( self, @@ -91,35 +92,37 @@ class StringMethods( n: int = ..., expand: Literal[False] = ..., regex: bool = ..., - ) -> _TS2: ... + ) -> _T_LIST_STR: ... @overload - def rsplit(self, pat: str = ..., *, n: int = ..., expand: Literal[True]) -> _TS: ... + def rsplit( + self, pat: str = ..., *, n: int = ..., expand: Literal[True] + ) -> _T_EXPANDING: ... @overload def rsplit( self, pat: str = ..., *, n: int = ..., expand: Literal[False] = ... - ) -> _TS2: ... + ) -> _T_LIST_STR: ... @overload - def partition(self, sep: str = ...) -> _TS: ... + def partition(self, sep: str = ...) -> _T_EXPANDING: ... @overload - def partition(self, *, expand: Literal[True]) -> _TS: ... + def partition(self, *, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def partition(self, sep: str, expand: Literal[True]) -> _TS: ... + def partition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def partition(self, sep: str, expand: Literal[False]) -> _TO: ... 
+ def partition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... @overload - def partition(self, *, expand: Literal[False]) -> _TO: ... + def partition(self, *, expand: Literal[False]) -> _T_OBJECT: ... @overload - def rpartition(self, sep: str = ...) -> _TS: ... + def rpartition(self, sep: str = ...) -> _T_EXPANDING: ... @overload - def rpartition(self, *, expand: Literal[True]) -> _TS: ... + def rpartition(self, *, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def rpartition(self, sep: str, expand: Literal[True]) -> _TS: ... + def rpartition(self, sep: str, expand: Literal[True]) -> _T_EXPANDING: ... @overload - def rpartition(self, sep: str, expand: Literal[False]) -> _TO: ... + def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... @overload - def rpartition(self, *, expand: Literal[False]) -> _TO: ... + def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ... def get(self, i: int) -> T: ... - def join(self, sep: str) -> _TD: ... + def join(self, sep: str) -> _T_STR: ... def contains( self, pat: str | re.Pattern[str], @@ -127,10 +130,10 @@ class StringMethods( flags: int = ..., na: Scalar | NaTType | None = ..., regex: bool = ..., - ) -> _TM: ... + ) -> _T_BOOL: ... def match( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... - ) -> _TM: ... + ) -> _T_BOOL: ... def replace( self, pat: str, @@ -157,8 +160,8 @@ class StringMethods( def slice_replace( self, start: int | None = ..., stop: int | None = ..., repl: str | None = ... ) -> T: ... - def decode(self, encoding: str, errors: str = ...) -> _TD: ... - def encode(self, encoding: str, errors: str = ...) -> _TE: ... + def decode(self, encoding: str, errors: str = ...) -> _T_STR: ... + def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ... def strip(self, to_strip: str | None = ...) -> T: ... def lstrip(self, to_strip: str | None = ...) -> T: ... def rstrip(self, to_strip: str | None = ...) -> T: ... 
@@ -171,44 +174,46 @@ class StringMethods( break_long_words: bool | None = ..., break_on_hyphens: bool | None = ..., ) -> T: ... - def get_dummies(self, sep: str = ...) -> _TS: ... + def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ... def translate(self, table: dict[int, int | str | None] | None) -> T: ... - def count(self, pat: str, flags: int = ...) -> _TI: ... - def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ... - def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _TM: ... - def findall(self, pat: str, flags: int = ...) -> _TS2: ... + def count(self, pat: str, flags: int = ...) -> _T_INT: ... + def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... + def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... + def findall(self, pat: str, flags: int = ...) -> _T_LIST_STR: ... @overload def extract( self, pat: str, flags: int = ..., *, expand: Literal[True] = ... ) -> pd.DataFrame: ... @overload - def extract(self, pat: str, flags: int, expand: Literal[False]) -> _TO: ... + def extract(self, pat: str, flags: int, expand: Literal[False]) -> _T_OBJECT: ... @overload - def extract(self, pat: str, flags: int = ..., *, expand: Literal[False]) -> _TO: ... + def extract( + self, pat: str, flags: int = ..., *, expand: Literal[False] + ) -> _T_OBJECT: ... def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... - def find(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... - def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... + def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def rfind(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... - def index(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... - def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _TI: ... 
- def len(self) -> _TI: ... + def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... + def len(self) -> _T_INT: ... def lower(self) -> T: ... def upper(self) -> T: ... def title(self) -> T: ... def capitalize(self) -> T: ... def swapcase(self) -> T: ... def casefold(self) -> T: ... - def isalnum(self) -> _TM: ... - def isalpha(self) -> _TM: ... - def isdigit(self) -> _TM: ... - def isspace(self) -> _TM: ... - def islower(self) -> _TM: ... - def isupper(self) -> _TM: ... - def istitle(self) -> _TM: ... - def isnumeric(self) -> _TM: ... - def isdecimal(self) -> _TM: ... + def isalnum(self) -> _T_BOOL: ... + def isalpha(self) -> _T_BOOL: ... + def isdigit(self) -> _T_BOOL: ... + def isspace(self) -> _T_BOOL: ... + def islower(self) -> _T_BOOL: ... + def isupper(self) -> _T_BOOL: ... + def istitle(self) -> _T_BOOL: ... + def isnumeric(self) -> _T_BOOL: ... + def isdecimal(self) -> _T_BOOL: ... def fullmatch( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... - ) -> _TM: ... + ) -> _T_BOOL: ... def removeprefix(self, prefix: str) -> T: ... def removesuffix(self, suffix: str) -> T: ... 
From b4839a01f92856427cd7dcf29344e55b13fde7a6 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:50:58 +0000 Subject: [PATCH 33/39] lint --- tests/test_series.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_series.py b/tests/test_series.py index 217ecd692..6fd1950e9 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -65,13 +65,11 @@ OffsetSeries, TimedeltaSeries, TimestampSeries, - UnknownSeries, ) else: TimedeltaSeries: TypeAlias = pd.Series TimestampSeries: TypeAlias = pd.Series OffsetSeries: TypeAlias = pd.Series - UnknownSeries: TypeAlias = pd.Series if TYPE_CHECKING: from pandas._typing import ( From 17e280f6e1054a9e896e30eba7764f7568da0f10 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 8 Mar 2025 11:57:37 +0000 Subject: [PATCH 34/39] annotate idx2 as per mypys request --- test | 2 -- tests/test_string_accessors.py | 4 +++- 2 files changed, 3 insertions(+), 3 deletions(-) delete mode 100644 test diff --git a/test b/test deleted file mode 100644 index e446f3ac8..000000000 --- a/test +++ /dev/null @@ -1,2 +0,0 @@ - test -ind abc \ No newline at end of file diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 57eb38933..3906b9c05 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -169,7 +169,9 @@ def test_string_accessors_string_index(): idx = pd.Index([b"a1", b"b2", b"c3"]) _check = functools.partial(check, klass=pd.Index, dtype=str) _check(assert_type(idx.str.decode("utf-8"), "pd.Index[str]")) - idx2 = pd.Index([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + idx2: "pd.Index[list]" = pd.Index( + [["apple", "banana"], ["cherry", "date"], [1, "eggplant"]] + ) _check(assert_type(idx2.str.join("-"), "pd.Index[str]")) From 208a55ca3427e5e1571d39c851e139e72204fd88 Mon Sep 17 00:00:00 2001 From: Marco Gorelli 
<33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 10 Mar 2025 12:11:03 +0000 Subject: [PATCH 35/39] return _T_STR, except for `slice` because that one preserves the input types --- pandas-stubs/core/strings.pyi | 52 ++++++------ tests/test_string_accessors.py | 151 ++++++++++++++++++--------------- 2 files changed, 110 insertions(+), 93 deletions(-) diff --git a/pandas-stubs/core/strings.pyi b/pandas-stubs/core/strings.pyi index fddb15208..4d215e82a 100644 --- a/pandas-stubs/core/strings.pyi +++ b/pandas-stubs/core/strings.pyi @@ -51,8 +51,8 @@ class StringMethods( Generic[T, _T_EXPANDING, _T_BOOL, _T_LIST_STR, _T_INT, _T_BYTES, _T_STR, _T_OBJECT], ): def __init__(self, data: T) -> None: ... - def __getitem__(self, key: slice | int) -> T: ... - def __iter__(self) -> T: ... + def __getitem__(self, key: slice | int) -> _T_STR: ... + def __iter__(self) -> _T_STR: ... @overload def cat( self, @@ -79,7 +79,7 @@ class StringMethods( sep: str = ..., na_rep: str | None = ..., join: JoinHow = ..., - ) -> T: ... + ) -> _T_STR: ... @overload def split( self, pat: str = ..., *, n: int = ..., expand: Literal[True], regex: bool = ... @@ -121,7 +121,7 @@ class StringMethods( def rpartition(self, sep: str, expand: Literal[False]) -> _T_OBJECT: ... @overload def rpartition(self, *, expand: Literal[False]) -> _T_OBJECT: ... - def get(self, i: int) -> T: ... + def get(self, i: int) -> _T_STR: ... def join(self, sep: str) -> _T_STR: ... def contains( self, @@ -142,29 +142,29 @@ class StringMethods( case: bool | None = ..., flags: int = ..., regex: bool = ..., - ) -> T: ... - def repeat(self, repeats: int | Sequence[int]) -> T: ... + ) -> _T_STR: ... + def repeat(self, repeats: int | Sequence[int]) -> _T_STR: ... def pad( self, width: int, side: Literal["left", "right", "both"] = ..., fillchar: str = ..., - ) -> T: ... - def center(self, width: int, fillchar: str = ...) -> T: ... - def ljust(self, width: int, fillchar: str = ...) -> T: ... 
- def rjust(self, width: int, fillchar: str = ...) -> T: ... - def zfill(self, width: int) -> T: ... + ) -> _T_STR: ... + def center(self, width: int, fillchar: str = ...) -> _T_STR: ... + def ljust(self, width: int, fillchar: str = ...) -> _T_STR: ... + def rjust(self, width: int, fillchar: str = ...) -> _T_STR: ... + def zfill(self, width: int) -> _T_STR: ... def slice( self, start: int | None = ..., stop: int | None = ..., step: int | None = ... ) -> T: ... def slice_replace( self, start: int | None = ..., stop: int | None = ..., repl: str | None = ... - ) -> T: ... + ) -> _T_STR: ... def decode(self, encoding: str, errors: str = ...) -> _T_STR: ... def encode(self, encoding: str, errors: str = ...) -> _T_BYTES: ... - def strip(self, to_strip: str | None = ...) -> T: ... - def lstrip(self, to_strip: str | None = ...) -> T: ... - def rstrip(self, to_strip: str | None = ...) -> T: ... + def strip(self, to_strip: str | None = ...) -> _T_STR: ... + def lstrip(self, to_strip: str | None = ...) -> _T_STR: ... + def rstrip(self, to_strip: str | None = ...) -> _T_STR: ... def wrap( self, width: int, @@ -173,9 +173,9 @@ class StringMethods( drop_whitespace: bool | None = ..., break_long_words: bool | None = ..., break_on_hyphens: bool | None = ..., - ) -> T: ... + ) -> _T_STR: ... def get_dummies(self, sep: str = ...) -> _T_EXPANDING: ... - def translate(self, table: dict[int, int | str | None] | None) -> T: ... + def translate(self, table: dict[int, int | str | None] | None) -> _T_STR: ... def count(self, pat: str, flags: int = ...) -> _T_INT: ... def startswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... def endswith(self, pat: str | tuple[str, ...], na: Any = ...) -> _T_BOOL: ... @@ -193,16 +193,16 @@ class StringMethods( def extractall(self, pat: str, flags: int = ...) -> pd.DataFrame: ... def find(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... def rfind(self, sub: str, start: int = ..., end: int | None = ...) 
-> _T_INT: ... - def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> T: ... + def normalize(self, form: Literal["NFC", "NFKC", "NFD", "NFKD"]) -> _T_STR: ... def index(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... def rindex(self, sub: str, start: int = ..., end: int | None = ...) -> _T_INT: ... def len(self) -> _T_INT: ... - def lower(self) -> T: ... - def upper(self) -> T: ... - def title(self) -> T: ... - def capitalize(self) -> T: ... - def swapcase(self) -> T: ... - def casefold(self) -> T: ... + def lower(self) -> _T_STR: ... + def upper(self) -> _T_STR: ... + def title(self) -> _T_STR: ... + def capitalize(self) -> _T_STR: ... + def swapcase(self) -> _T_STR: ... + def casefold(self) -> _T_STR: ... def isalnum(self) -> _T_BOOL: ... def isalpha(self) -> _T_BOOL: ... def isdigit(self) -> _T_BOOL: ... @@ -215,5 +215,5 @@ class StringMethods( def fullmatch( self, pat: str, case: bool = ..., flags: int = ..., na: Any = ... ) -> _T_BOOL: ... - def removeprefix(self, prefix: str) -> T: ... - def removesuffix(self, suffix: str) -> T: ... + def removeprefix(self, prefix: str) -> _T_STR: ... + def removesuffix(self, suffix: str) -> _T_STR: ... 
diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 3906b9c05..54fd79909 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -13,70 +13,21 @@ DATA = ["applep", "bananap", "Cherryp", "DATEp", "eGGpLANTp", "123p", "23.45p"] +DATA_BYTES = [b"applep", b"bananap"] def test_string_accessors_type_preserving_series() -> None: - s = pd.Series(DATA) - _check = functools.partial(check, klass=pd.Series, dtype=str) - _check(assert_type(s.str.capitalize(), "pd.Series[str]")) - _check(assert_type(s.str.casefold(), "pd.Series[str]")) - check(assert_type(s.str.cat(sep="X"), str), str) - _check(assert_type(s.str.center(10), "pd.Series[str]")) - _check(assert_type(s.str.get(2), "pd.Series[str]")) - _check(assert_type(s.str.ljust(80), "pd.Series[str]")) - _check(assert_type(s.str.lower(), "pd.Series[str]")) - _check(assert_type(s.str.lstrip("a"), "pd.Series[str]")) - _check(assert_type(s.str.normalize("NFD"), "pd.Series[str]")) - _check(assert_type(s.str.pad(80, "right"), "pd.Series[str]")) - _check(assert_type(s.str.removeprefix("a"), "pd.Series[str]")) - _check(assert_type(s.str.removesuffix("e"), "pd.Series[str]")) - _check(assert_type(s.str.repeat(2), "pd.Series[str]")) - _check(assert_type(s.str.replace("a", "X"), "pd.Series[str]")) - _check(assert_type(s.str.rjust(80), "pd.Series[str]")) - _check(assert_type(s.str.rstrip(), "pd.Series[str]")) - _check(assert_type(s.str.slice(0, 4, 2), "pd.Series[str]")) - _check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]")) - _check(assert_type(s.str.strip(), "pd.Series[str]")) - _check(assert_type(s.str.swapcase(), "pd.Series[str]")) - _check(assert_type(s.str.title(), "pd.Series[str]")) - _check( - assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), - ) - _check(assert_type(s.str.upper(), "pd.Series[str]")) - _check(assert_type(s.str.wrap(80), "pd.Series[str]")) - _check(assert_type(s.str.zfill(10), "pd.Series[str]")) + s_str = pd.Series(DATA) + s_bytes = 
pd.Series(DATA_BYTES) + check(assert_type(s_str.str.slice(0, 4, 2), "pd.Series[str]"), pd.Series, str) + check(assert_type(s_bytes.str.slice(0, 4, 2), "pd.Series[bytes]"), pd.Series, bytes) def test_string_accessors_type_preserving_index() -> None: - idx = pd.Index(DATA) - _check = functools.partial(check, klass=pd.Index, dtype=str) - _check(assert_type(idx.str.capitalize(), "pd.Index[str]")) - _check(assert_type(idx.str.casefold(), "pd.Index[str]")) - check(assert_type(idx.str.cat(sep="X"), str), str) - _check(assert_type(idx.str.center(10), "pd.Index[str]")) - _check(assert_type(idx.str.get(2), "pd.Index[str]")) - _check(assert_type(idx.str.ljust(80), "pd.Index[str]")) - _check(assert_type(idx.str.lower(), "pd.Index[str]")) - _check(assert_type(idx.str.lstrip("a"), "pd.Index[str]")) - _check(assert_type(idx.str.normalize("NFD"), "pd.Index[str]")) - _check(assert_type(idx.str.pad(80, "right"), "pd.Index[str]")) - _check(assert_type(idx.str.removeprefix("a"), "pd.Index[str]")) - _check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]")) - _check(assert_type(idx.str.repeat(2), "pd.Index[str]")) - _check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]")) - _check(assert_type(idx.str.rjust(80), "pd.Index[str]")) - _check(assert_type(idx.str.rstrip(), "pd.Index[str]")) - _check(assert_type(idx.str.slice(0, 4, 2), "pd.Index[str]")) - _check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]")) - _check(assert_type(idx.str.strip(), "pd.Index[str]")) - _check(assert_type(idx.str.swapcase(), "pd.Index[str]")) - _check(assert_type(idx.str.title(), "pd.Index[str]")) - _check( - assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"), - ) - _check(assert_type(idx.str.upper(), "pd.Index[str]")) - _check(assert_type(idx.str.wrap(80), "pd.Index[str]")) - _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) + idx_str = pd.Index(DATA) + idx_bytes = pd.Index(DATA_BYTES) + check(assert_type(idx_str.str.slice(0, 4, 2), "pd.Index[str]"), pd.Index, str) + 
check(assert_type(idx_bytes.str.slice(0, 4, 2), "pd.Index[bytes]"), pd.Index, bytes) def test_string_accessors_boolean_series(): @@ -158,21 +109,73 @@ def test_string_accessors_integer_index(): def test_string_accessors_string_series(): - s = pd.Series([b"a1", b"b2", b"c3"]) + s = pd.Series(DATA) _check = functools.partial(check, klass=pd.Series, dtype=str) - _check(assert_type(s.str.decode("utf-8"), "pd.Series[str]")) - s2 = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) - _check(assert_type(s2.str.join("-"), "pd.Series[str]")) + _check(assert_type(s.str.capitalize(), "pd.Series[str]")) + _check(assert_type(s.str.casefold(), "pd.Series[str]")) + check(assert_type(s.str.cat(sep="X"), str), str) + _check(assert_type(s.str.center(10), "pd.Series[str]")) + _check(assert_type(s.str.get(2), "pd.Series[str]")) + _check(assert_type(s.str.ljust(80), "pd.Series[str]")) + _check(assert_type(s.str.lower(), "pd.Series[str]")) + _check(assert_type(s.str.lstrip("a"), "pd.Series[str]")) + _check(assert_type(s.str.normalize("NFD"), "pd.Series[str]")) + _check(assert_type(s.str.pad(80, "right"), "pd.Series[str]")) + _check(assert_type(s.str.removeprefix("a"), "pd.Series[str]")) + _check(assert_type(s.str.removesuffix("e"), "pd.Series[str]")) + _check(assert_type(s.str.repeat(2), "pd.Series[str]")) + _check(assert_type(s.str.replace("a", "X"), "pd.Series[str]")) + _check(assert_type(s.str.rjust(80), "pd.Series[str]")) + _check(assert_type(s.str.rstrip(), "pd.Series[str]")) + _check(assert_type(s.str.slice_replace(0, 2, "XX"), "pd.Series[str]")) + _check(assert_type(s.str.strip(), "pd.Series[str]")) + _check(assert_type(s.str.swapcase(), "pd.Series[str]")) + _check(assert_type(s.str.title(), "pd.Series[str]")) + _check( + assert_type(s.str.translate({241: "n"}), "pd.Series[str]"), + ) + _check(assert_type(s.str.upper(), "pd.Series[str]")) + _check(assert_type(s.str.wrap(80), "pd.Series[str]")) + _check(assert_type(s.str.zfill(10), "pd.Series[str]")) + s_bytes = 
pd.Series([b"a1", b"b2", b"c3"]) + _check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]")) + s_list = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + _check(assert_type(s_list.str.join("-"), "pd.Series[str]")) def test_string_accessors_string_index(): - idx = pd.Index([b"a1", b"b2", b"c3"]) + idx = pd.Index(DATA) _check = functools.partial(check, klass=pd.Index, dtype=str) - _check(assert_type(idx.str.decode("utf-8"), "pd.Index[str]")) - idx2: "pd.Index[list]" = pd.Index( - [["apple", "banana"], ["cherry", "date"], [1, "eggplant"]] + _check(assert_type(idx.str.capitalize(), "pd.Index[str]")) + _check(assert_type(idx.str.casefold(), "pd.Index[str]")) + check(assert_type(idx.str.cat(sep="X"), str), str) + _check(assert_type(idx.str.center(10), "pd.Index[str]")) + _check(assert_type(idx.str.get(2), "pd.Index[str]")) + _check(assert_type(idx.str.ljust(80), "pd.Index[str]")) + _check(assert_type(idx.str.lower(), "pd.Index[str]")) + _check(assert_type(idx.str.lstrip("a"), "pd.Index[str]")) + _check(assert_type(idx.str.normalize("NFD"), "pd.Index[str]")) + _check(assert_type(idx.str.pad(80, "right"), "pd.Index[str]")) + _check(assert_type(idx.str.removeprefix("a"), "pd.Index[str]")) + _check(assert_type(idx.str.removesuffix("e"), "pd.Index[str]")) + _check(assert_type(idx.str.repeat(2), "pd.Index[str]")) + _check(assert_type(idx.str.replace("a", "X"), "pd.Index[str]")) + _check(assert_type(idx.str.rjust(80), "pd.Index[str]")) + _check(assert_type(idx.str.rstrip(), "pd.Index[str]")) + _check(assert_type(idx.str.slice_replace(0, 2, "XX"), "pd.Index[str]")) + _check(assert_type(idx.str.strip(), "pd.Index[str]")) + _check(assert_type(idx.str.swapcase(), "pd.Index[str]")) + _check(assert_type(idx.str.title(), "pd.Index[str]")) + _check( + assert_type(idx.str.translate({241: "n"}), "pd.Index[str]"), ) - _check(assert_type(idx2.str.join("-"), "pd.Index[str]")) + _check(assert_type(idx.str.upper(), "pd.Index[str]")) + 
_check(assert_type(idx.str.wrap(80), "pd.Index[str]")) + _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) + idx_bytes = pd.Index([b"a1", b"b2", b"c3"]) + _check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]")) + idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + _check(assert_type(idx_list.str.join("-"), "pd.Index[str]")) def test_string_accessors_bytes_series(): @@ -325,6 +328,12 @@ def test_series_overloads_cat(): ) unknown_s = pd.DataFrame({"a": list("abcdefg")})["a"] check(assert_type(s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), pd.Series, str) + check(assert_type(unknown_s.str.cat(s, sep=";"), "pd.Series[str]"), pd.Series, str) + check( + assert_type(unknown_s.str.cat(unknown_s, sep=";"), "pd.Series[str]"), + pd.Series, + str, + ) def test_index_overloads_cat(): @@ -351,6 +360,14 @@ def test_index_overloads_cat(): check( assert_type(idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), pd.Index, str ) + check( + assert_type(unknown_idx.str.cat(idx, sep=";"), "pd.Index[str]"), pd.Index, str + ) + check( + assert_type(unknown_idx.str.cat(unknown_idx, sep=";"), "pd.Index[str]"), + pd.Index, + str, + ) def test_series_overloads_extract(): From 3dc660e6ad70e6351ae1985918b842726074a835 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Mon, 10 Mar 2025 13:50:43 +0000 Subject: [PATCH 36/39] mypy fixup --- tests/test_string_accessors.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index 54fd79909..b664fa798 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -174,7 +174,9 @@ def test_string_accessors_string_index(): _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) idx_bytes = pd.Index([b"a1", b"b2", b"c3"]) _check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]")) - idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], [1, 
"eggplant"]]) + idx_list: "pd.Index[list]" = pd.Index( + [["apple", "banana"], ["cherry", "date"], [1, "eggplant"]] + ) _check(assert_type(idx_list.str.join("-"), "pd.Index[str]")) From b2d4657186ea6f08b6cbf1286a531b0adfa70eb2 Mon Sep 17 00:00:00 2001 From: Irv Lustig Date: Mon, 10 Mar 2025 15:55:05 -0400 Subject: [PATCH 37/39] disallow .str on certain series types --- pandas-stubs/core/frame.pyi | 13 ++++--- pandas-stubs/core/series.pyi | 62 +++++++++++++++++++++++++++------- tests/test_string_accessors.py | 14 ++++++-- 3 files changed, 69 insertions(+), 20 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 03fab417e..59a411af9 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -44,7 +44,10 @@ from pandas.core.indexing import ( _LocIndexer, ) from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg -from pandas.core.series import Series +from pandas.core.series import ( + Series, + UnknownSeries, +) from pandas.core.window import ( Expanding, ExponentialMovingWindow, @@ -244,24 +247,24 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): if sys.version_info >= (3, 12): class _GetItemHack: @overload - def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] self, key: Iterable[Hashable] | slice ) -> Self: ... @overload - def __getitem__(self, key: Hashable) -> Series: ... + def __getitem__(self, key: Hashable) -> UnknownSeries: ... else: class _GetItemHack: @overload - def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... 
# type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __getitem__( # pyright: ignore[reportOverlappingOverload] self, key: Iterable[Hashable] | slice ) -> Self: ... @overload - def __getitem__(self, key: Hashable) -> Series: ... + def __getitem__(self, key: Hashable) -> UnknownSeries: ... class DataFrame(NDFrame, OpsMixin, _GetItemHack): diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 4863a2ee7..34f685401 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -231,6 +231,54 @@ class _LocIndexerSeries(_LocIndexer, Generic[S1]): value: S1 | ArrayLike | Series[S1] | None, ) -> None: ... +class _StrMethods: + @overload + def __get__(self, instance: Series[str], owner: Any) -> StringMethods[ + Series[str], + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + Series[type[object]], + ]: ... + @overload + def __get__(self, instance: Series[bytes], owner: Any) -> StringMethods[ + Series[bytes], + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + Series[type[object]], + ]: ... + @overload + def __get__(self, instance: Series[list[str]], owner: Any) -> StringMethods[ + Series[list[str]], + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + Series[type[object]], + ]: ... + @overload + def __get__(self, instance: Series[S1], owner: Any) -> NoReturn: ... + @overload + def __get__(self, instance: UnknownSeries, owner: Any) -> StringMethods[ + Series, + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + Series[type[object]], + ]: ... 
+ _ListLike: TypeAlias = ( ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | IndexOpsMixin[S1] ) @@ -1153,19 +1201,7 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: _bool = ..., ) -> Series[S1]: ... def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ... - @property - def str( - self, - ) -> StringMethods[ - Self, - DataFrame, - Series[bool], - Series[list[str]], - Series[int], - Series[bytes], - Series[str], - Series[type[object]], - ]: ... + str: _StrMethods @property def dt(self) -> CombinedDatetimelikeProperties: ... @property diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index b664fa798..af85a5c16 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -6,7 +6,10 @@ import pandas as pd from typing_extensions import assert_type -from tests import check +from tests import ( + TYPE_CHECKING_INVALID_USAGE, + check, +) # Separately define here so pytest works np_ndarray_bool = npt.NDArray[np.bool_] @@ -139,7 +142,7 @@ def test_string_accessors_string_series(): _check(assert_type(s.str.zfill(10), "pd.Series[str]")) s_bytes = pd.Series([b"a1", b"b2", b"c3"]) _check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]")) - s_list = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["foo", "eggplant"]]) _check(assert_type(s_list.str.join("-"), "pd.Series[str]")) @@ -415,3 +418,10 @@ def test_index_overloads_extract(): pd.Index, object, ) + + +def test_series_unknown(): + if TYPE_CHECKING_INVALID_USAGE: + s = pd.Series([1, 2, 3]) + s.str.startswith("a") # type:ignore[attr-defined] + s.str.slice(2, 4) # type:ignore[attr-defined] From ce7575ee89014c48880bd366af4b6d17ebebbb52 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 11 Mar 2025 11:00:30 +0000 Subject: [PATCH 38/39] Revert "disallow .str on certain series types" This reverts commit 
b2d4657186ea6f08b6cbf1286a531b0adfa70eb2. --- pandas-stubs/core/frame.pyi | 13 +++---- pandas-stubs/core/series.pyi | 62 +++++++--------------------------- tests/test_string_accessors.py | 14 ++------ 3 files changed, 20 insertions(+), 69 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 59a411af9..03fab417e 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -44,10 +44,7 @@ from pandas.core.indexing import ( _LocIndexer, ) from pandas.core.interchange.dataframe_protocol import DataFrame as DataFrameXchg -from pandas.core.series import ( - Series, - UnknownSeries, -) +from pandas.core.series import Series from pandas.core.window import ( Expanding, ExponentialMovingWindow, @@ -247,24 +244,24 @@ class _LocIndexerFrame(_LocIndexer, Generic[_T]): if sys.version_info >= (3, 12): class _GetItemHack: @overload - def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __getitem__( # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] self, key: Iterable[Hashable] | slice ) -> Self: ... @overload - def __getitem__(self, key: Hashable) -> UnknownSeries: ... + def __getitem__(self, key: Hashable) -> Series: ... else: class _GetItemHack: @overload - def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> UnknownSeries: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] + def __getitem__(self, key: Scalar | tuple[Hashable, ...]) -> Series: ... # type: ignore[overload-overlap] # pyright: ignore[reportOverlappingOverload] @overload def __getitem__( # pyright: ignore[reportOverlappingOverload] self, key: Iterable[Hashable] | slice ) -> Self: ... 
@overload - def __getitem__(self, key: Hashable) -> UnknownSeries: ... + def __getitem__(self, key: Hashable) -> Series: ... class DataFrame(NDFrame, OpsMixin, _GetItemHack): diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 34f685401..4863a2ee7 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -231,54 +231,6 @@ class _LocIndexerSeries(_LocIndexer, Generic[S1]): value: S1 | ArrayLike | Series[S1] | None, ) -> None: ... -class _StrMethods: - @overload - def __get__(self, instance: Series[str], owner: Any) -> StringMethods[ - Series[str], - DataFrame, - Series[bool], - Series[list[str]], - Series[int], - Series[bytes], - Series[str], - Series[type[object]], - ]: ... - @overload - def __get__(self, instance: Series[bytes], owner: Any) -> StringMethods[ - Series[bytes], - DataFrame, - Series[bool], - Series[list[str]], - Series[int], - Series[bytes], - Series[str], - Series[type[object]], - ]: ... - @overload - def __get__(self, instance: Series[list[str]], owner: Any) -> StringMethods[ - Series[list[str]], - DataFrame, - Series[bool], - Series[list[str]], - Series[int], - Series[bytes], - Series[str], - Series[type[object]], - ]: ... - @overload - def __get__(self, instance: Series[S1], owner: Any) -> NoReturn: ... - @overload - def __get__(self, instance: UnknownSeries, owner: Any) -> StringMethods[ - Series, - DataFrame, - Series[bool], - Series[list[str]], - Series[int], - Series[bytes], - Series[str], - Series[type[object]], - ]: ... - _ListLike: TypeAlias = ( ArrayLike | dict[_str, np.ndarray] | Sequence[S1] | IndexOpsMixin[S1] ) @@ -1201,7 +1153,19 @@ class Series(IndexOpsMixin[S1], NDFrame): copy: _bool = ..., ) -> Series[S1]: ... def to_period(self, freq: _str | None = ..., copy: _bool = ...) -> DataFrame: ... 
- str: _StrMethods + @property + def str( + self, + ) -> StringMethods[ + Self, + DataFrame, + Series[bool], + Series[list[str]], + Series[int], + Series[bytes], + Series[str], + Series[type[object]], + ]: ... @property def dt(self) -> CombinedDatetimelikeProperties: ... @property diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index af85a5c16..b664fa798 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -6,10 +6,7 @@ import pandas as pd from typing_extensions import assert_type -from tests import ( - TYPE_CHECKING_INVALID_USAGE, - check, -) +from tests import check # Separately define here so pytest works np_ndarray_bool = npt.NDArray[np.bool_] @@ -142,7 +139,7 @@ def test_string_accessors_string_series(): _check(assert_type(s.str.zfill(10), "pd.Series[str]")) s_bytes = pd.Series([b"a1", b"b2", b"c3"]) _check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]")) - s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["foo", "eggplant"]]) + s_list = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) _check(assert_type(s_list.str.join("-"), "pd.Series[str]")) @@ -418,10 +415,3 @@ def test_index_overloads_extract(): pd.Index, object, ) - - -def test_series_unknown(): - if TYPE_CHECKING_INVALID_USAGE: - s = pd.Series([1, 2, 3]) - s.str.startswith("a") # type:ignore[attr-defined] - s.str.slice(2, 4) # type:ignore[attr-defined] From 3e24de01411b79a49684c7a9f6850bea99e6419f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 11 Mar 2025 14:13:24 +0000 Subject: [PATCH 39/39] use Index of list[str] --- tests/test_string_accessors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/test_string_accessors.py b/tests/test_string_accessors.py index b664fa798..4e83160f8 100644 --- a/tests/test_string_accessors.py +++ b/tests/test_string_accessors.py @@ -139,7 +139,7 @@ def 
test_string_accessors_string_series(): _check(assert_type(s.str.zfill(10), "pd.Series[str]")) s_bytes = pd.Series([b"a1", b"b2", b"c3"]) _check(assert_type(s_bytes.str.decode("utf-8"), "pd.Series[str]")) - s_list = pd.Series([["apple", "banana"], ["cherry", "date"], [1, "eggplant"]]) + s_list = pd.Series([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]]) _check(assert_type(s_list.str.join("-"), "pd.Series[str]")) @@ -174,9 +174,7 @@ def test_string_accessors_string_index(): _check(assert_type(idx.str.zfill(10), "pd.Index[str]")) idx_bytes = pd.Index([b"a1", b"b2", b"c3"]) _check(assert_type(idx_bytes.str.decode("utf-8"), "pd.Index[str]")) - idx_list: "pd.Index[list]" = pd.Index( - [["apple", "banana"], ["cherry", "date"], [1, "eggplant"]] - ) + idx_list = pd.Index([["apple", "banana"], ["cherry", "date"], ["one", "eggplant"]]) _check(assert_type(idx_list.str.join("-"), "pd.Index[str]"))