Skip to content

Commit 25fe8aa

Browse files
authored
type dataframe.replace and series.replace (#1129)
* type dataframe.replace * the typing never stops * remove unused ReplaceMethod * finish dataframe.replace typing * use typealias * mypy fixup * comment
1 parent e35a729 commit 25fe8aa

File tree

5 files changed

+180
-25
lines changed

5 files changed

+180
-25
lines changed

pandas-stubs/_typing.pyi

+13-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ from collections.abc import (
1010
import datetime
1111
from datetime import tzinfo
1212
from os import PathLike
13+
from re import Pattern
1314
import sys
1415
from typing import (
1516
Any,
@@ -36,6 +37,7 @@ from typing_extensions import (
3637
)
3738

3839
from pandas._libs.interval import Interval
40+
from pandas._libs.missing import NAType
3941
from pandas._libs.tslibs import (
4042
BaseOffset,
4143
Period,
@@ -731,7 +733,17 @@ InterpolateOptions: TypeAlias = Literal[
731733
"cubicspline",
732734
"from_derivatives",
733735
]
734-
ReplaceMethod: TypeAlias = Literal["pad", "ffill", "bfill"]
736+
# Can be passed to `to_replace`, `value`, or `regex` in `Series.replace`.
737+
# `DataFrame.replace` also accepts mappings of these.
738+
ReplaceValue: TypeAlias = (
739+
Scalar
740+
| Pattern
741+
| NAType
742+
| Sequence[Scalar | Pattern]
743+
| Mapping[Hashable, Scalar]
744+
| Series[Any]
745+
| None
746+
)
735747
SortKind: TypeAlias = Literal["quicksort", "mergesort", "heapsort", "stable"]
736748
NaPosition: TypeAlias = Literal["first", "last"]
737749
JoinHow: TypeAlias = Literal["left", "right", "outer", "inner"]

pandas-stubs/core/frame.pyi

+7-12
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ from collections.abc import (
88
Sequence,
99
)
1010
import datetime as dt
11-
from re import Pattern
1211
import sys
1312
from typing import (
1413
Any,
@@ -113,7 +112,7 @@ from pandas._typing import (
113112
RandomState,
114113
ReadBuffer,
115114
Renamer,
116-
ReplaceMethod,
115+
ReplaceValue,
117116
Scalar,
118117
ScalarT,
119118
SequenceNotStr,
@@ -799,24 +798,20 @@ class DataFrame(NDFrame, OpsMixin, _GetItemHack):
799798
@overload
800799
def replace(
801800
self,
802-
to_replace=...,
803-
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
801+
to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
802+
value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
804803
*,
805804
inplace: Literal[True],
806-
limit: int | None = ...,
807-
regex=...,
808-
method: ReplaceMethod = ...,
805+
regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
809806
) -> None: ...
810807
@overload
811808
def replace(
812809
self,
813-
to_replace=...,
814-
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
810+
to_replace: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
811+
value: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
815812
*,
816813
inplace: Literal[False] = ...,
817-
limit: int | None = ...,
818-
regex=...,
819-
method: ReplaceMethod = ...,
814+
regex: ReplaceValue | Mapping[Hashable, ReplaceValue] = ...,
820815
) -> Self: ...
821816
def shift(
822817
self,

pandas-stubs/core/series.pyi

+10-12
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@ from typing import (
2424
overload,
2525
)
2626

27-
from _typing import TimeZones
27+
from _typing import (
28+
ReplaceValue,
29+
TimeZones,
30+
)
2831
from matplotlib.axes import (
2932
Axes as PlotAxes,
3033
SubplotBase,
@@ -141,7 +144,6 @@ from pandas._typing import (
141144
QuantileInterpolation,
142145
RandomState,
143146
Renamer,
144-
ReplaceMethod,
145147
Scalar,
146148
ScalarT,
147149
SequenceNotStr,
@@ -1089,24 +1091,20 @@ class Series(IndexOpsMixin[S1], NDFrame):
10891091
@overload
10901092
def replace(
10911093
self,
1092-
to_replace: _str | list | dict | Series[S1] | float | None = ...,
1093-
value: Scalar | NAType | dict | list | _str | None = ...,
1094+
to_replace: ReplaceValue = ...,
1095+
value: ReplaceValue = ...,
10941096
*,
1095-
limit: int | None = ...,
1096-
regex=...,
1097-
method: ReplaceMethod = ...,
1097+
regex: ReplaceValue = ...,
10981098
inplace: Literal[True],
10991099
) -> None: ...
11001100
@overload
11011101
def replace(
11021102
self,
1103-
to_replace: _str | list | dict | Series[S1] | float | None = ...,
1104-
value: Scalar | NAType | dict | list | _str | None = ...,
1103+
to_replace: ReplaceValue = ...,
1104+
value: ReplaceValue = ...,
11051105
*,
1106+
regex: ReplaceValue = ...,
11061107
inplace: Literal[False] = ...,
1107-
limit: int | None = ...,
1108-
regex=...,
1109-
method: ReplaceMethod = ...,
11101108
) -> Series[S1]: ...
11111109
def shift(
11121110
self,

tests/test_frame.py

+116
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import io
1515
import itertools
1616
from pathlib import Path
17+
import re
1718
import string
1819
import sys
1920
from typing import (
@@ -2570,6 +2571,121 @@ def test_types_replace() -> None:
25702571
assert assert_type(df.replace(1, 2, inplace=True), None) is None
25712572

25722573

2574+
def test_dataframe_replace() -> None:
2575+
df = pd.DataFrame({"col1": ["a", "ab", "ba"]})
2576+
pattern = re.compile(r"^a.*")
2577+
check(assert_type(df.replace("a", "x"), pd.DataFrame), pd.DataFrame)
2578+
check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame)
2579+
check(assert_type(df.replace("a", "x", regex=True), pd.DataFrame), pd.DataFrame)
2580+
check(assert_type(df.replace(pattern, "x"), pd.DataFrame), pd.DataFrame)
2581+
check(assert_type(df.replace(regex="a", value="x"), pd.DataFrame), pd.DataFrame)
2582+
check(assert_type(df.replace(regex=pattern, value="x"), pd.DataFrame), pd.DataFrame)
2583+
2584+
check(assert_type(df.replace(["a"], ["x"]), pd.DataFrame), pd.DataFrame)
2585+
check(assert_type(df.replace([pattern], ["x"]), pd.DataFrame), pd.DataFrame)
2586+
check(assert_type(df.replace(regex=["a"], value=["x"]), pd.DataFrame), pd.DataFrame)
2587+
check(
2588+
assert_type(df.replace(regex=[pattern], value=["x"]), pd.DataFrame),
2589+
pd.DataFrame,
2590+
)
2591+
2592+
check(assert_type(df.replace({"a": "x"}), pd.DataFrame), pd.DataFrame)
2593+
check(assert_type(df.replace({pattern: "x"}), pd.DataFrame), pd.DataFrame)
2594+
check(assert_type(df.replace(pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame)
2595+
check(assert_type(df.replace(regex={"a": "x"}), pd.DataFrame), pd.DataFrame)
2596+
check(assert_type(df.replace(regex={pattern: "x"}), pd.DataFrame), pd.DataFrame)
2597+
check(
2598+
assert_type(df.replace(regex=pd.Series({"a": "x"})), pd.DataFrame), pd.DataFrame
2599+
)
2600+
2601+
check(
2602+
assert_type(df.replace({"col1": "a"}, {"col1": "x"}), pd.DataFrame),
2603+
pd.DataFrame,
2604+
)
2605+
check(
2606+
assert_type(df.replace({"col1": pattern}, {"col1": "x"}), pd.DataFrame),
2607+
pd.DataFrame,
2608+
)
2609+
check(
2610+
assert_type(
2611+
df.replace(pd.Series({"col1": "a"}), pd.Series({"col1": "x"})), pd.DataFrame
2612+
),
2613+
pd.DataFrame,
2614+
)
2615+
check(
2616+
assert_type(df.replace(regex={"col1": "a"}, value={"col1": "x"}), pd.DataFrame),
2617+
pd.DataFrame,
2618+
)
2619+
check(
2620+
assert_type(
2621+
df.replace(regex={"col1": pattern}, value={"col1": "x"}), pd.DataFrame
2622+
),
2623+
pd.DataFrame,
2624+
)
2625+
check(
2626+
assert_type(
2627+
df.replace(regex=pd.Series({"col1": "a"}), value=pd.Series({"col1": "x"})),
2628+
pd.DataFrame,
2629+
),
2630+
pd.DataFrame,
2631+
)
2632+
2633+
check(
2634+
assert_type(df.replace({"col1": ["a"]}, {"col1": ["x"]}), pd.DataFrame),
2635+
pd.DataFrame,
2636+
)
2637+
check(
2638+
assert_type(df.replace({"col1": [pattern]}, {"col1": ["x"]}), pd.DataFrame),
2639+
pd.DataFrame,
2640+
)
2641+
check(
2642+
assert_type(
2643+
df.replace(pd.Series({"col1": ["a"]}), pd.Series({"col1": ["x"]})),
2644+
pd.DataFrame,
2645+
),
2646+
pd.DataFrame,
2647+
)
2648+
check(
2649+
assert_type(
2650+
df.replace(regex={"col1": ["a"]}, value={"col1": ["x"]}), pd.DataFrame
2651+
),
2652+
pd.DataFrame,
2653+
)
2654+
check(
2655+
assert_type(
2656+
df.replace(regex={"col1": [pattern]}, value={"col1": ["x"]}), pd.DataFrame
2657+
),
2658+
pd.DataFrame,
2659+
)
2660+
check(
2661+
assert_type(
2662+
df.replace(
2663+
regex=pd.Series({"col1": ["a"]}), value=pd.Series({"col1": ["x"]})
2664+
),
2665+
pd.DataFrame,
2666+
),
2667+
pd.DataFrame,
2668+
)
2669+
2670+
check(assert_type(df.replace({"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame)
2671+
check(assert_type(df.replace({"col1": {pattern: "x"}}), pd.DataFrame), pd.DataFrame)
2672+
check(
2673+
assert_type(df.replace({"col1": pd.Series({"a": "x"})}), pd.DataFrame),
2674+
pd.DataFrame,
2675+
)
2676+
check(
2677+
assert_type(df.replace(regex={"col1": {"a": "x"}}), pd.DataFrame), pd.DataFrame
2678+
)
2679+
check(
2680+
assert_type(df.replace(regex={"col1": {pattern: "x"}}), pd.DataFrame),
2681+
pd.DataFrame,
2682+
)
2683+
check(
2684+
assert_type(df.replace(regex={"col1": pd.Series({"a": "x"})}), pd.DataFrame),
2685+
pd.DataFrame,
2686+
)
2687+
2688+
25732689
def test_loop_dataframe() -> None:
25742690
# GH 70
25752691
df = pd.DataFrame({"x": [1, 2, 3]})

tests/test_series.py

+34
Original file line numberDiff line numberDiff line change
@@ -1410,6 +1410,40 @@ def test_types_replace() -> None:
14101410
assert assert_type(s.replace(1, 2, inplace=True), None) is None
14111411

14121412

1413+
def test_series_replace() -> None:
1414+
s: pd.Series[str] = pd.DataFrame({"col1": ["a", "ab", "ba"]})["col1"]
1415+
pattern = re.compile(r"^a.*")
1416+
check(assert_type(s.replace("a", "x"), "pd.Series[str]"), pd.Series)
1417+
check(assert_type(s.replace(pattern, "x"), "pd.Series[str]"), pd.Series)
1418+
check(
1419+
assert_type(s.replace({"a": "z"}), "pd.Series[str]"),
1420+
pd.Series,
1421+
)
1422+
check(
1423+
assert_type(s.replace(pd.Series({"a": "z"})), "pd.Series[str]"),
1424+
pd.Series,
1425+
)
1426+
check(
1427+
assert_type(s.replace({pattern: "z"}), "pd.Series[str]"),
1428+
pd.Series,
1429+
)
1430+
check(assert_type(s.replace(["a"], ["x"]), "pd.Series[str]"), pd.Series)
1431+
check(assert_type(s.replace([pattern], ["x"]), "pd.Series[str]"), pd.Series)
1432+
check(assert_type(s.replace(r"^a.*", "x", regex=True), "pd.Series[str]"), pd.Series)
1433+
check(assert_type(s.replace(value="x", regex=r"^a.*"), "pd.Series[str]"), pd.Series)
1434+
check(
1435+
assert_type(s.replace(value="x", regex=[r"^a.*"]), "pd.Series[str]"), pd.Series
1436+
)
1437+
check(assert_type(s.replace(value="x", regex=pattern), "pd.Series[str]"), pd.Series)
1438+
check(
1439+
assert_type(s.replace(value="x", regex=[pattern]), "pd.Series[str]"), pd.Series
1440+
)
1441+
check(assert_type(s.replace(regex={"a": "x"}), "pd.Series[str]"), pd.Series)
1442+
check(
1443+
assert_type(s.replace(regex=pd.Series({"a": "x"})), "pd.Series[str]"), pd.Series
1444+
)
1445+
1446+
14131447
def test_cat_accessor() -> None:
14141448
# GH 43
14151449
s: pd.Series[str] = pd.Series(

0 commit comments

Comments
 (0)