Skip to content

Commit 4b99ad8

Browse files
authored
Fix: DataFrame.replace incompatible with type "dict[str, dict[int, int]]" (pandas-dev#1164)
* add failing test * fix * use HashableT3 for value/regex
1 parent 69b833c commit 4b99ad8

File tree

4 files changed

+18
-8
lines changed

4 files changed

+18
-8
lines changed

pandas-stubs/_typing.pyi

+1-1
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@ ReplaceValue: TypeAlias = (
740740
| Pattern
741741
| NAType
742742
| Sequence[Scalar | Pattern]
743-
| Mapping[Hashable, Scalar]
743+
| Mapping[HashableT, ScalarT]
744744
| Series[Any]
745745
| None
746746
)

pandas-stubs/core/frame.pyi

+7-6
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ from pandas._typing import (
9292
HashableT,
9393
HashableT1,
9494
HashableT2,
95+
HashableT3,
9596
IgnoreRaise,
9697
IndexingInt,
9798
IndexLabel,
@@ -797,20 +798,20 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
797798
@overload
798799
def replace(
799800
self,
800-
to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
801-
value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
801+
to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ...,
802+
value: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ...,
802803
*,
803804
inplace: Literal[True],
804-
regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
805+
regex: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ...,
805806
) -> None: ...
806807
@overload
807808
def replace(
808809
self,
809-
to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
810-
value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
810+
to_replace: ReplaceValue | Mapping[HashableT2, ReplaceValue] = ...,
811+
value: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ...,
811812
*,
812813
inplace: Literal[False] = ...,
813-
regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
814+
regex: ReplaceValue | Mapping[HashableT3, ReplaceValue] = ...,
814815
) -> Self: ...
815816
def shift(
816817
self,

tests/test_frame.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -2574,8 +2574,10 @@ def test_types_replace() -> None:
25742574

25752575

25762576
def test_dataframe_replace() -> None:
2577-
df = pd.DataFrame({"col1": ["a", "ab", "ba"]})
2577+
df = pd.DataFrame({"col1": ["a", "ab", "ba"], "col2": [0, 1, 2]})
25782578
pattern = re.compile(r"^a.*")
2579+
replace_dict_scalar = {0: 1}
2580+
replace_dict_per_column = {"col2": {0: 1}}
25792581
check(assert_type(df.replace("a", "x"), pd.DataFrame), pd.DataFrame)
25802582
check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame)
25812583
check(assert_type(df.replace("a", "x", regex=True), pd.DataFrame), pd.DataFrame)
@@ -2592,6 +2594,7 @@ def test_dataframe_replace() -> None:
25922594
)
25932595

25942596
check(assert_type(df.replace({"a": "x"}), pd.DataFrame), pd.DataFrame)
2597+
check(assert_type(df.replace(replace_dict_scalar), pd.DataFrame), pd.DataFrame)
25952598
check(assert_type(df.replace({pattern: "x"}), pd.DataFrame), pd.DataFrame)
25962599
check(assert_type(df.replace(pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame)
25972600
check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame)
@@ -2670,6 +2673,7 @@ def test_dataframe_replace() -> None:
26702673
)
26712674

26722675
check(assert_type(df.replace({"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame)
2676+
check(assert_type(df.replace(replace_dict_per_column), pd.DataFrame), pd.DataFrame)
26732677
check(assert_type(df.replace({"col1": {pattern: "x"}}), pd.DataFrame), pd.DataFrame)
26742678
check(
26752679
assert_type(df.replace({"col1": pd.Series({"a": "x"})}), pd.DataFrame),

tests/test_series.py

+5
Original file line numberDiff line numberDiff line change
@@ -1448,12 +1448,17 @@ def test_types_replace() -> None:
14481448
def test_series_replace() -> None:
14491449
s: pd.Series[str] = pd.DataFrame({"col1": ["a", "ab", "ba"]})["col1"]
14501450
pattern = re.compile(r"^a.*")
1451+
replace_dict = {"a": "b"}
14511452
check(assert_type(s.replace("a", "x"), "pd.Series[str]"), pd.Series)
14521453
check(assert_type(s.replace(pattern, "x"), "pd.Series[str]"), pd.Series)
14531454
check(
14541455
assert_type(s.replace({"a": "z"}), "pd.Series[str]"),
14551456
pd.Series,
14561457
)
1458+
check(
1459+
assert_type(s.replace(replace_dict), "pd.Series[str]"),
1460+
pd.Series,
1461+
)
14571462
check(
14581463
assert_type(s.replace(pd.Series({"a": "z"})), "pd.Series[str]"),
14591464
pd.Series,

0 commit comments

Comments
 (0)