From 378983cd1baa893ed42cbca090c4ce4af7324ae9 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:39:10 +0000 Subject: [PATCH 1/3] add failing test --- tests/test_frame.py | 6 +++++- tests/test_series.py | 5 +++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/tests/test_frame.py b/tests/test_frame.py index e947ceaa9..ae0c8bea0 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2574,8 +2574,10 @@ def test_types_replace() -> None: def test_dataframe_replace() -> None: - df = pd.DataFrame({"col1": ["a", "ab", "ba"]}) + df = pd.DataFrame({"col1": ["a", "ab", "ba"], "col2": [0, 1, 2]}) pattern = re.compile(r"^a.*") + replace_dict_scalar = {0: 1} + replace_dict_per_column = {"col2": {0: 1}} check(assert_type(df.replace("a", "x"), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace("a", "x", regex=True), pd.DataFrame), pd.DataFrame) @@ -2592,6 +2594,7 @@ def test_dataframe_replace() -> None: ) check(assert_type(df.replace({"a": "x"}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(replace_dict_scalar), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace({pattern: "x"}), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace(pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame) @@ -2670,6 +2673,7 @@ def test_dataframe_replace() -> None: ) check(assert_type(df.replace({"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(replace_dict_per_column), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace({"col1": {pattern: "x"}}), pd.DataFrame), pd.DataFrame) check( assert_type(df.replace({"col1": pd.Series({"a": "x"})}), pd.DataFrame), diff --git a/tests/test_series.py b/tests/test_series.py index 61e1939a7..752450dc6 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1448,12 +1448,17 @@ def test_types_replace() -> None: def test_series_replace() -> None: s: pd.Series[str] = pd.DataFrame({"col1": ["a", "ab", "ba"]})["col1"] pattern = re.compile(r"^a.*") + replace_dict = {"a": "b"} check(assert_type(s.replace("a", "x"), "pd.Series[str]"), pd.Series) check(assert_type(s.replace(pattern, "x"), "pd.Series[str]"), pd.Series) check( assert_type(s.replace({"a": "z"}), "pd.Series[str]"), pd.Series, ) + check( + assert_type(s.replace(replace_dict), "pd.Series[str]"), + pd.Series, + ) check( assert_type(s.replace(pd.Series({"a": "z"})), "pd.Series[str]"), pd.Series, From 694d40d8f3802211e0b798005ef7fb3b5437ceff Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 12 Mar 2025 15:34:45 +0000 Subject: [PATCH 2/3] fix --- pandas-stubs/_typing.pyi | 2 +- pandas-stubs/core/frame.pyi | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 46c2c6438..635a418f5 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -740,7 +740,7 @@ ReplaceValue: TypeAlias = ( | Pattern | NAType | Sequence[Scalar | Pattern] - | Mapping[Hashable, Scalar] + | Mapping[HashableT, ScalarT] | Series[Any] | None ) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 0fd4a673e..ebcdbc7c4 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -797,20 +797,20 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): @overload def replace( self, - to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., - value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + value: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., *, inplace: Literal[True], - regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + regex: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., ) -> None: ... @overload def replace( self, - to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., - value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + value: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., *, inplace: Literal[False] = ..., - regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + regex: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., ) -> Self: ... def shift( self, From ed1aceef2996fa06e3525cef3339274c912fb1c1 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 12 Mar 2025 17:14:24 +0000 Subject: [PATCH 3/3] use HashableT3 for value/regex --- pandas-stubs/core/frame.pyi | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index ebcdbc7c4..a508c1c1a 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -92,6 +92,7 @@ from pandas._typing import ( HashableT, HashableT1, HashableT2, + HashableT3, IgnoreRaise, IndexingInt, IndexLabel, @@ -798,19 +799,19 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): def replace( self, to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., - value: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + value: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., *, inplace: Literal[True], - regex: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + regex: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., ) -> None: ... @overload def replace( self, to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., - value: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + value: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., *, inplace: Literal[False] = ..., - regex: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + regex: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., ) -> Self: ... def shift( self,