Skip to content

Commit 36fd857

Browse files
jrebackchanghiskhan
authored andcommitted
add where and mask methods to Series. where returns a series evaluated for the cond with a shape like the original
1 parent f29106d commit 36fd857

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

pandas/core/series.py

+38
Original file line numberDiff line numberDiff line change
@@ -562,6 +562,44 @@ def _get_values(self, indexer):
562562
except Exception:
563563
return self.values[indexer]
564564

565+
def where(self, cond, other=nan, inplace=False):
566+
"""
567+
Return a Series where cond is True; otherwise values are from other
568+
569+
Parameters
570+
----------
571+
cond: boolean Series or array
572+
other: scalar or Series
573+
574+
Returns
575+
-------
576+
wh: Series
577+
"""
578+
if not hasattr(cond, 'shape'):
579+
raise ValueError('where requires an ndarray like object for its '
580+
'condition')
581+
582+
if inplace:
583+
self._set_with(~cond, other)
584+
return self
585+
586+
return self._get_values(cond).reindex_like(self).fillna(other)
587+
588+
def mask(self, cond):
589+
"""
590+
Returns copy of self whose values are replaced with nan if the
591+
inverted condition is True
592+
593+
Parameters
594+
----------
595+
cond: boolean Series or array
596+
597+
Returns
598+
-------
599+
wh: Series
600+
"""
601+
return self.where(~cond, nan)
602+
565603
def __setitem__(self, key, value):
566604
try:
567605
try:

pandas/tests/test_series.py

+31
Original file line numberDiff line numberDiff line change
@@ -939,6 +939,37 @@ def test_ix_getitem_iterator(self):
939939
result = self.series.ix[idx]
940940
assert_series_equal(result, self.series[:10])
941941

942+
def test_where(self):
943+
s = Series(np.random.randn(5))
944+
cond = s > 0
945+
946+
rs = s.where(cond).dropna()
947+
rs2 = s[cond]
948+
assert_series_equal(rs, rs2)
949+
950+
rs = s.where(cond,-s)
951+
assert_series_equal(rs, s.abs())
952+
953+
rs = s.where(cond)
954+
assert(s.shape == rs.shape)
955+
956+
self.assertRaises(ValueError, s.where, 1)
957+
958+
def test_where_inplace(self):
959+
s = Series(np.random.randn(5))
960+
cond = s > 0
961+
962+
rs = s.copy()
963+
rs.where(cond,inplace=True)
964+
assert_series_equal(rs.dropna(), s[cond])
965+
966+
def test_mask(self):
967+
s = Series(np.random.randn(5))
968+
cond = s > 0
969+
970+
rs = s.where(cond, np.nan)
971+
assert_series_equal(rs, s.mask(~cond))
972+
942973
def test_ix_setitem(self):
943974
inds = self.series.index[[3,4,7]]
944975

0 commit comments

Comments
 (0)