Skip to content

Commit 2086ae2

Browse files
changhiskhanwesm
authored andcommitted
ENH: DataFrame where and mask #2109
1 parent 5b5653d commit 2086ae2

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

pandas/core/frame.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4835,6 +4835,49 @@ def combineMult(self, other):
48354835
"""
48364836
return self.mul(other, fill_value=1.)
48374837

4838+
def where(self, cond, other):
4839+
"""
4840+
Return a DataFrame with the same shape as self and whose corresponding
4841+
entries are from self where cond is True and otherwise are from other.
4842+
4843+
4844+
Parameters
4845+
----------
4846+
cond: boolean DataFrame or array
4847+
other: scalar or DataFrame
4848+
4849+
Returns
4850+
-------
4851+
wh: DataFrame
4852+
"""
4853+
if isinstance(cond, np.ndarray):
4854+
if cond.shape != self.shape:
4855+
raise ValueError('Array onditional must be same shape as self')
4856+
cond = self._constructor(cond, index=self.index, columns=self.columns)
4857+
if cond.shape != self.shape:
4858+
cond = cond.reindex(self.index, columns=self.columns)
4859+
cond = cond.fillna(False)
4860+
4861+
if isinstance(other, DataFrame):
4862+
_, other = self.align(other, join='left', fill_value=np.nan)
4863+
4864+
rs = np.where(cond, self, other)
4865+
return self._constructor(rs, self.index, self.columns)
4866+
4867+
def mask(self, cond):
4868+
"""
4869+
Returns copy of self whose values are replaced with nan if the
4870+
corresponding entry in cond is False
4871+
4872+
Parameters
4873+
----------
4874+
cond: boolean DataFrame or array
4875+
4876+
Returns
4877+
-------
4878+
wh: DataFrame
4879+
"""
4880+
return self.where(cond, np.nan)
48384881

48394882
_EMPTY_SERIES = Series([])
48404883

pandas/tests/test_frame.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5078,6 +5078,34 @@ def test_align_int_fill_bug(self):
50785078
expected = df2 - df2.mean()
50795079
assert_frame_equal(result, expected)
50805080

5081+
def test_where(self):
5082+
df = DataFrame(np.random.randn(5, 3))
5083+
cond = df > 0
5084+
5085+
other1 = df + 1
5086+
rs = df.where(cond, other1)
5087+
for k, v in rs.iteritems():
5088+
assert_series_equal(v, np.where(cond[k], df[k], other1[k]))
5089+
5090+
other2 = (df + 1).values
5091+
rs = df.where(cond, other2)
5092+
for k, v in rs.iteritems():
5093+
assert_series_equal(v, np.where(cond[k], df[k], other2[:, k]))
5094+
5095+
other5 = np.nan
5096+
rs = df.where(cond, other5)
5097+
for k, v in rs.iteritems():
5098+
assert_series_equal(v, np.where(cond[k], df[k], other5))
5099+
5100+
assert_frame_equal(rs, df.mask(cond))
5101+
5102+
err1 = (df + 1).values[0:2, :]
5103+
self.assertRaises(ValueError, df.where, cond, err1)
5104+
5105+
err2 = cond.ix[:2, :].values
5106+
self.assertRaises(ValueError, df.where, err2, other1)
5107+
5108+
50815109
#----------------------------------------------------------------------
50825110
# Transposing
50835111

0 commit comments

Comments
 (0)