Skip to content

Commit 7c38e4d

Browse files
authored
Fix replace, fillna allowing pd.NA. Allow list of arguments for .loc (#312)
* Fix replace, fillna allowing pd.NA. Allow list of arguments for .loc * add check and assert_type
1 parent 55a1a0c commit 7c38e4d

File tree

4 files changed

+34
-13
lines changed

4 files changed

+34
-13
lines changed

pandas-stubs/core/frame.pyi

+9-7
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ from pandas.core.window.rolling import (
4040
)
4141
import xarray as xr
4242

43+
from pandas._libs.missing import NAType
4344
from pandas._typing import (
4445
S1,
4546
AggFuncType,
@@ -157,7 +158,8 @@ class _LocIndexerFrame(_LocIndexer):
157158
self,
158159
idx: MaskType
159160
| StrLike
160-
| tuple[MaskType | Index | Sequence[ScalarT] | Scalar | slice, ...],
161+
| tuple[MaskType | Index | Sequence[ScalarT] | Scalar | slice, ...]
162+
| list[ScalarT],
161163
value: S1 | ArrayLike | Series | DataFrame,
162164
) -> None: ...
163165
@overload
@@ -617,7 +619,7 @@ class DataFrame(NDFrame, OpsMixin):
617619
@overload
618620
def fillna(
619621
self,
620-
value: Scalar | dict | Series | DataFrame | None = ...,
622+
value: Scalar | NAType | dict | Series | DataFrame | None = ...,
621623
method: FillnaOptions | None = ...,
622624
axis: AxisType | None = ...,
623625
limit: int = ...,
@@ -628,7 +630,7 @@ class DataFrame(NDFrame, OpsMixin):
628630
@overload
629631
def fillna(
630632
self,
631-
value: Scalar | dict | Series | DataFrame | None = ...,
633+
value: Scalar | NAType | dict | Series | DataFrame | None = ...,
632634
method: FillnaOptions | None = ...,
633635
axis: AxisType | None = ...,
634636
limit: int = ...,
@@ -639,7 +641,7 @@ class DataFrame(NDFrame, OpsMixin):
639641
@overload
640642
def fillna(
641643
self,
642-
value: Scalar | dict | Series | DataFrame | None = ...,
644+
value: Scalar | NAType | dict | Series | DataFrame | None = ...,
643645
method: FillnaOptions | None = ...,
644646
axis: AxisType | None = ...,
645647
inplace: _bool | None = ...,
@@ -650,7 +652,7 @@ class DataFrame(NDFrame, OpsMixin):
650652
def replace(
651653
self,
652654
to_replace=...,
653-
value: Scalar | Sequence | Mapping | Pattern | None = ...,
655+
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
654656
limit: int | None = ...,
655657
regex=...,
656658
method: ReplaceMethod = ...,
@@ -661,7 +663,7 @@ class DataFrame(NDFrame, OpsMixin):
661663
def replace(
662664
self,
663665
to_replace=...,
664-
value: Scalar | Sequence | Mapping | Pattern | None = ...,
666+
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
665667
inplace: Literal[False] = ...,
666668
limit: int | None = ...,
667669
regex=...,
@@ -671,7 +673,7 @@ class DataFrame(NDFrame, OpsMixin):
671673
def replace(
672674
self,
673675
to_replace=...,
674-
value: Scalar | Sequence | Mapping | Pattern | None = ...,
676+
value: Scalar | NAType | Sequence | Mapping | Pattern | None = ...,
675677
inplace: _bool | None = ...,
676678
limit: int | None = ...,
677679
regex=...,

pandas-stubs/core/series.pyi

+7-6
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ from pandas.core.window.rolling import (
5656
)
5757
import xarray as xr
5858

59+
from pandas._libs.missing import NAType
5960
from pandas._typing import (
6061
S1,
6162
AggFuncTypeBase,
@@ -772,7 +773,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
772773
@overload
773774
def fillna(
774775
self,
775-
value: Scalar | dict | Series[S1] | DataFrame | None = ...,
776+
value: Scalar | NAType | dict | Series[S1] | DataFrame | None = ...,
776777
method: FillnaOptions | None = ...,
777778
axis: SeriesAxisType = ...,
778779
limit: int | None = ...,
@@ -783,7 +784,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
783784
@overload
784785
def fillna(
785786
self,
786-
value: Scalar | dict | Series[S1] | DataFrame | None = ...,
787+
value: Scalar | NAType | dict | Series[S1] | DataFrame | None = ...,
787788
method: FillnaOptions | None = ...,
788789
axis: SeriesAxisType = ...,
789790
*,
@@ -793,7 +794,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
793794
@overload
794795
def fillna(
795796
self,
796-
value: Scalar | dict | Series[S1] | DataFrame | None = ...,
797+
value: Scalar | NAType | dict | Series[S1] | DataFrame | None = ...,
797798
method: FillnaOptions | None = ...,
798799
axis: SeriesAxisType = ...,
799800
inplace: _bool = ...,
@@ -804,7 +805,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
804805
def replace(
805806
self,
806807
to_replace: _str | list | dict | Series[S1] | float | None = ...,
807-
value: Scalar | dict | list | _str | None = ...,
808+
value: Scalar | NAType | dict | list | _str | None = ...,
808809
inplace: Literal[False] = ...,
809810
limit: int | None = ...,
810811
regex=...,
@@ -814,7 +815,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
814815
def replace(
815816
self,
816817
to_replace: _str | list | dict | Series[S1] | float | None = ...,
817-
value: Scalar | dict | list | _str | None = ...,
818+
value: Scalar | NAType | dict | list | _str | None = ...,
818819
limit: int | None = ...,
819820
regex=...,
820821
method: ReplaceMethod = ...,
@@ -825,7 +826,7 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
825826
def replace(
826827
self,
827828
to_replace: _str | list | dict | Series[S1] | float | None = ...,
828-
value: Scalar | dict | list | _str | None = ...,
829+
value: Scalar | NAType | dict | list | _str | None = ...,
829830
inplace: _bool = ...,
830831
limit: int | None = ...,
831832
regex=...,

tests/test_frame.py

+16
Original file line numberDiff line numberDiff line change
@@ -1781,3 +1781,19 @@ def cond2(x: pd.DataFrame) -> pd.DataFrame:
17811781

17821782
cond3 = pd.DataFrame({"a": [True, True, False], "b": [False, False, False]})
17831783
check(assert_type(df.where(cond3), pd.DataFrame), pd.DataFrame)
1784+
1785+
1786+
def test_setitem_loc() -> None:
1787+
# GH 254
1788+
df = pd.DataFrame.from_dict(
1789+
{view: (True, True, True) for view in ["A", "B", "C"]}, orient="index"
1790+
)
1791+
df.loc[["A", "C"]] = False
1792+
my_arr = ["A", "C"]
1793+
df.loc[my_arr] = False
1794+
1795+
1796+
def test_replace_na() -> None:
1797+
# GH 262
1798+
frame = pd.DataFrame(["N/A", "foo", "bar"])
1799+
check(assert_type(frame.replace("N/A", pd.NA), pd.DataFrame), pd.DataFrame)

tests/test_series.py

+2
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ def test_types_fillna() -> None:
209209
assert assert_type(s.fillna(method="bfill", inplace=True), None) is None
210210
check(assert_type(s.fillna(method="pad"), pd.Series), pd.Series)
211211
check(assert_type(s.fillna(method="ffill", limit=1), pd.Series), pd.Series)
212+
# GH 263
213+
check(assert_type(s.fillna(pd.NA), pd.Series), pd.Series)
212214

213215

214216
def test_types_sort_index() -> None:

0 commit comments

Comments
 (0)