Skip to content

Commit ad97507

Browse files
committed
ENH: column-wise DataFrame.fillna with Series and dict (pandas-dev#4514)
1 parent 9a68635 commit ad97507

File tree

3 files changed

+161
-12
lines changed

3 files changed

+161
-12
lines changed

doc/source/whatsnew/v1.1.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ Enhancements
1818
Other enhancements
1919
^^^^^^^^^^^^^^^^^^
2020

21-
-
21+
- :meth:`DataFrame.fillna` can fill NA values column-wise with a dictionary or :class:`Series` (:issue:`4514`)
2222
-
2323

2424

pandas/core/generic.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -5960,19 +5960,39 @@ def fillna(
59605960
)
59615961

59625962
elif isinstance(value, (dict, ABCSeries)):
5963+
result = self if inplace else self.copy()
5964+
59635965
if axis == 1:
5964-
raise NotImplementedError(
5965-
"Currently only can fill "
5966-
"with dict/Series column "
5967-
"by column"
5968-
)
5966+
# Transpose so that each row can be filled through column access
5967+
result = result.T
59695968

5970-
result = self if inplace else self.copy()
5969+
columns = result.columns
59715970
for k, v in value.items():
5972-
if k not in result:
5971+
if k not in columns.array:
59735972
continue
5974-
obj = result[k]
5975-
obj.fillna(v, limit=limit, inplace=True, downcast=downcast)
5973+
5974+
column_locs = columns.get_loc(k)
5975+
5976+
if is_scalar(column_locs):
5977+
# if label is not duplicated
5978+
result.loc[:, k] = result.loc[:, k].fillna(
5979+
v, limit=limit, inplace=False, downcast=downcast
5980+
)
5981+
5982+
else:
5983+
if isinstance(column_locs, slice):
5984+
locs = range(column_locs.start, column_locs.stop)
5985+
else:
5986+
locs = np.where(column_locs)[0]
5987+
5988+
for i in locs:
5989+
result.iloc[:, i] = result.iloc[:, i].fillna(
5990+
v, limit=limit, inplace=False, downcast=downcast
5991+
)
5992+
5993+
if axis == 1:
5994+
result = result.T
5995+
59765996
return result if not inplace else None
59775997

59785998
elif not is_list_like(value):

pandas/tests/frame/test_missing.py

+131-2
Original file line numberDiff line numberDiff line change
@@ -605,8 +605,15 @@ def test_fillna_dict_series(self):
605605
tm.assert_frame_equal(result, expected)
606606

607607
# fill along axis=1 with a Series (previously raised NotImplementedError)
608-
with pytest.raises(NotImplementedError, match="column by column"):
609-
df.fillna(df.max(1), axis=1)
608+
expected = DataFrame(
609+
{
610+
"a": [1.0, 1.0, 2.0, 3.0, 4.0],
611+
"b": [1.0, 2.0, 3.0, 3.0, 4.0],
612+
"c": [1.0, 1.0, 2.0, 3.0, 4.0],
613+
}
614+
)
615+
result = df.fillna(df.max(1), axis=1)
616+
tm.assert_frame_equal(expected, result)
610617

611618
def test_fillna_dataframe(self):
612619
# GH 8377
@@ -983,3 +990,125 @@ def test_interp_time_inplace_axis(self, axis):
983990
result = expected.interpolate(axis=0, method="time")
984991
expected.interpolate(axis=0, method="time", inplace=True)
985992
tm.assert_frame_equal(result, expected)
993+
994+
@pytest.mark.parametrize(
995+
"expected,fill_value",
996+
[
997+
(
998+
DataFrame(
999+
[[100, 100], [200, 4], [5, 6]], columns=list("AB"), dtype="float64"
1000+
),
1001+
Series([100, 200, 300]),
1002+
),
1003+
(
1004+
DataFrame(
1005+
[[100, 100], [np.nan, 4], [5, 6]],
1006+
columns=list("AB"),
1007+
dtype="float64",
1008+
),
1009+
{0: 100, 2: 300, 3: 400},
1010+
),
1011+
],
1012+
)
1013+
def test_fillna_column_wise(self, expected, fill_value):
1014+
# GH 4514
1015+
df = DataFrame([[np.nan, np.nan], [np.nan, 4], [5, 6]], columns=list("AB"))
1016+
result = df.fillna(fill_value, axis=1)
1017+
tm.assert_frame_equal(expected, result)
1018+
1019+
df.fillna(fill_value, axis=1, inplace=True)
1020+
tm.assert_frame_equal(expected, df)
1021+
1022+
def test_fillna_column_wise_downcast(self):
1023+
df = DataFrame([[np.nan, 2], [3, np.nan], [np.nan, np.nan]], columns=list("AB"))
1024+
s = Series([100, 200, 300])
1025+
1026+
expected = DataFrame(
1027+
[[100, 2], [3, 200], [300, 300]], columns=list("AB"), dtype="int64"
1028+
)
1029+
result = df.fillna(s, axis=1, downcast="infer")
1030+
tm.assert_frame_equal(expected, result)
1031+
1032+
@pytest.mark.parametrize(
1033+
"target,expected",
1034+
[
1035+
(
1036+
DataFrame(
1037+
[[np.nan, np.nan, 3], [np.nan, 5, np.nan], [7, np.nan, np.nan]],
1038+
columns=list("ABB"),
1039+
index=[0, 0, 1],
1040+
),
1041+
DataFrame(
1042+
[[100, 100, 3], [100, 5, 100], [7, 200, 200]],
1043+
columns=list("ABB"),
1044+
index=[0, 0, 1],
1045+
dtype="float64",
1046+
),
1047+
),
1048+
(
1049+
DataFrame(
1050+
[[np.nan, np.nan, 3], [np.nan, 5, np.nan], [7, np.nan, np.nan]],
1051+
columns=list("ABA"),
1052+
index=[0, 1, 0],
1053+
),
1054+
DataFrame(
1055+
[[100, 100, 3], [200, 5, 200], [7, 100, 100]],
1056+
columns=list("ABA"),
1057+
index=[0, 1, 0],
1058+
dtype="float64",
1059+
),
1060+
),
1061+
],
1062+
)
1063+
def test_fillna_column_wise_duplicated_with_series_dict(self, target, expected):
1064+
# GH 4514
1065+
s = pd.Series([100, 200, 300], index=[0, 1, 2])
1066+
d = {0: 100, 1: 200, 2: 300}
1067+
1068+
result = target.fillna(s, axis=1)
1069+
tm.assert_frame_equal(result, expected)
1070+
1071+
result = target.fillna(d, axis=1)
1072+
tm.assert_frame_equal(result, expected)
1073+
1074+
@pytest.mark.parametrize(
1075+
"target,expected",
1076+
[
1077+
(
1078+
DataFrame(
1079+
[[np.nan, np.nan, 3], [np.nan, 5, np.nan], [7, np.nan, np.nan]],
1080+
columns=list("ABB"),
1081+
index=[0, 0, 1],
1082+
),
1083+
DataFrame(
1084+
[[100, 200, 3], [100, 5, 200], [7, 200, 200]],
1085+
columns=list("ABB"),
1086+
index=[0, 0, 1],
1087+
dtype="float64",
1088+
),
1089+
),
1090+
(
1091+
DataFrame(
1092+
[[np.nan, np.nan, 3], [np.nan, 5, np.nan], [7, np.nan, np.nan]],
1093+
columns=list("ABA"),
1094+
index=[0, 1, 0],
1095+
),
1096+
DataFrame(
1097+
[[100, 200, 3], [100, 5, 100], [7, 200, 100]],
1098+
columns=list("ABA"),
1099+
index=[0, 1, 0],
1100+
dtype="float64",
1101+
),
1102+
),
1103+
],
1104+
)
1105+
def test_fillna_duplicated_with_series_dict(self, target, expected):
1106+
# GH 4514
1107+
s = pd.Series([100, 200, 300], index=["A", "B", "C"])
1108+
d = {"A": 100, "B": 200, "C": 300}
1109+
1110+
result = target.fillna(s)
1111+
tm.assert_frame_equal(result, expected)
1112+
1113+
result = target.fillna(d)
1114+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)