diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 46c2c643..635a418f 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -740,7 +740,7 @@ ReplaceValue: TypeAlias = ( | Pattern | NAType | Sequence[Scalar | Pattern] - | Mapping[Hashable, Scalar] + | Mapping[HashableT, ScalarT] | Series[Any] | None ) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 0fd4a673..a508c1c1 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -92,6 +92,7 @@ from pandas._typing import ( HashableT, HashableT1, HashableT2, + HashableT3, IgnoreRaise, IndexingInt, IndexLabel, @@ -797,20 +798,20 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack): @overload def replace( self, - to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., - value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + value: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., *, inplace: Literal[True], - regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + regex: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., ) -> None: ... @overload def replace( self, - to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., - value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ..., + value: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., *, inplace: Literal[False] = ..., - regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ..., + regex: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ..., ) -> Self: ... def shift( self, diff --git a/tests/test_frame.py b/tests/test_frame.py index e947ceaa..ae0c8bea 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2574,8 +2574,10 @@ def test_types_replace() -> None: def test_dataframe_replace() -> None: - df = pd.DataFrame({"col1": ["a", "ab", "ba"]}) + df = pd.DataFrame({"col1": ["a", "ab", "ba"], "col2": [0, 1, 2]}) pattern = re.compile(r"^a.*") + replace_dict_scalar = {0: 1} + replace_dict_per_column = {"col2": {0: 1}} check(assert_type(df.replace("a", "x"), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace("a", "x", regex=True), pd.DataFrame), pd.DataFrame) @@ -2592,6 +2594,7 @@ def test_dataframe_replace() -> None: ) check(assert_type(df.replace({"a": "x"}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(replace_dict_scalar), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace({pattern: "x"}), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace(pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame) @@ -2670,6 +2673,7 @@ def test_dataframe_replace() -> None: ) check(assert_type(df.replace({"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame) + check(assert_type(df.replace(replace_dict_per_column), pd.DataFrame), pd.DataFrame) check(assert_type(df.replace({"col1": {pattern: "x"}}), pd.DataFrame), pd.DataFrame) check( assert_type(df.replace({"col1": pd.Series({"a": "x"})}), pd.DataFrame), diff --git a/tests/test_series.py b/tests/test_series.py index 61e1939a..752450dc 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1448,12 +1448,17 @@ def test_types_replace() -> None: def test_series_replace() -> None: s: pd.Series[str] = pd.DataFrame({"col1": ["a", "ab", "ba"]})["col1"] pattern = re.compile(r"^a.*") + replace_dict = {"a": "b"} check(assert_type(s.replace("a", "x"), "pd.Series[str]"), pd.Series) check(assert_type(s.replace(pattern, "x"), "pd.Series[str]"), pd.Series) check( assert_type(s.replace({"a": "z"}), "pd.Series[str]"), pd.Series, ) + check( + assert_type(s.replace(replace_dict), "pd.Series[str]"), + pd.Series, + ) check( assert_type(s.replace(pd.Series({"a": "z"})), "pd.Series[str]"), pd.Series,