Skip to content

Commit 8093941

Browse files
Chang She authored and wesm committed
ENH: filter_func keyword to DataFrame.update, related to #1477
1 parent 6521f04 commit 8093941

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

pandas/core/frame.py

+11 -6
Original file line number | Diff line number | Diff line change
@@ -3143,7 +3143,7 @@ def combine_first(self, other):
31433143
combiner = lambda x, y: np.where(isnull(x), y, x)
31443144
return self.combine(other, combiner)
31453145

3146-
def update(self, other, join='left', overwrite=True):
3146+
def update(self, other, join='left', overwrite=True, filter_func=None):
31473147
"""
31483148
Modify DataFrame in place using non-NA values from passed
31493149
DataFrame. Aligns on indices
@@ -3153,8 +3153,10 @@ def update(self, other, join='left', overwrite=True):
31533153
other : DataFrame
31543154
join : {'left', 'right', 'outer', 'inner'}, default 'left'
31553155
overwrite : boolean, default True
3156-
If True then overwrite values for common keys in the calling
3157-
frame
3156+
If True then overwrite values for common keys in the calling frame
3157+
filter_func : callable(1d-array) -> 1d-array<boolean>, default None
3158+
Can choose to replace values other than NA. Return True for values
3159+
that should be updated
31583160
"""
31593161
if join != 'left':
31603162
raise NotImplementedError
@@ -3163,10 +3165,13 @@ def update(self, other, join='left', overwrite=True):
31633165
for col in self.columns:
31643166
this = self[col].values
31653167
that = other[col].values
3166-
if overwrite:
3167-
mask = isnull(that)
3168+
if filter_func is not None:
3169+
mask = -filter_func(this) | isnull(that)
31683170
else:
3169-
mask = notnull(this)
3171+
if overwrite:
3172+
mask = isnull(that)
3173+
else:
3174+
mask = notnull(this)
31703175
self[col] = np.where(mask, this, that)
31713176

31723177
#----------------------------------------------------------------------

pandas/tests/test_frame.py

+18
Original file line number | Diff line number | Diff line change
@@ -5252,6 +5252,24 @@ def test_update_nooverwrite(self):
52525252
[1.5, nan, 3.]])
52535253
assert_frame_equal(df, expected)
52545254

5255+
def test_update_filtered(self):
5256+
df = DataFrame([[1.5, nan, 3.],
5257+
[1.5, nan, 3.],
5258+
[1.5, nan, 3],
5259+
[1.5, nan, 3]])
5260+
5261+
other = DataFrame([[3.6, 2., np.nan],
5262+
[np.nan, np.nan, 7]], index=[1, 3])
5263+
5264+
df.update(other, filter_func=lambda x: x > 2)
5265+
5266+
expected = DataFrame([[1.5, nan, 3],
5267+
[1.5, nan, 3],
5268+
[1.5, nan, 3],
5269+
[1.5, nan, 7.]])
5270+
assert_frame_equal(df, expected)
5271+
5272+
52555273
def test_combineAdd(self):
52565274
# trivial
52575275
comb = self.frame.combineAdd(self.frame)

0 commit comments

Comments
 (0)