Skip to content

Commit 6876c17

Browse files
committed
ENH: tweaks to Series.where #2337
1 parent 36fd857 commit 6876c17

File tree

3 files changed

+27
-3
lines changed

3 files changed

+27
-3
lines changed

RELEASE.rst

+1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ pandas 0.10.0
3030
**New features**
3131

3232
- Add error handling to Series.str.encode/decode (#2276)
33+
- Add ``where`` and ``mask`` to Series (#2337)
3334

3435
**API Changes**
3536

pandas/core/series.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -575,15 +575,27 @@ def where(self, cond, other=nan, inplace=False):
575575
-------
576576
wh: Series
577577
"""
578+
if isinstance(cond, Series):
579+
cond = cond.reindex(self.index, fill_value=True)
578580
if not hasattr(cond, 'shape'):
579581
raise ValueError('where requires an ndarray like object for its '
580582
'condition')
583+
if len(cond) != len(self):
584+
raise ValueError('condition must have same length as series')
581585

582-
if inplace:
586+
ser = self if inplace else self.copy()
587+
if not isinstance(other, (list, tuple, np.ndarray)):
583588
self._set_with(~cond, other)
584589
return self
585590

586-
return self._get_values(cond).reindex_like(self).fillna(other)
591+
if isinstance(other, Series):
592+
other = other.reindex(ser.index)
593+
if len(other) != len(ser):
594+
raise ValueError('Length of replacements must equal series length')
595+
596+
np.putmask(ser, ~cond, other)
597+
598+
return ser
587599

588600
def mask(self, cond):
589601
"""

pandas/tests/test_series.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -953,15 +953,26 @@ def test_where(self):
953953
rs = s.where(cond)
954954
assert(s.shape == rs.shape)
955955

956+
rs = s.where(cond[:3], -s)
957+
assert_series_equal(rs, s.abs()[:3].append(s[3:]))
958+
956959
self.assertRaises(ValueError, s.where, 1)
960+
self.assertRaises(ValueError, s.where, cond[:3].values, -s)
961+
self.assertRaises(ValueError, s.where, cond, s[:3].values)
957962

958963
def test_where_inplace(self):
959964
s = Series(np.random.randn(5))
960965
cond = s > 0
961966

962967
rs = s.copy()
963-
rs.where(cond,inplace=True)
968+
rs.where(cond, inplace=True)
964969
assert_series_equal(rs.dropna(), s[cond])
970+
assert_series_equal(rs, s.where(cond))
971+
972+
rs = s.copy()
973+
rs.where(cond, -s, inplace=True)
974+
assert_series_equal(rs, s.where(cond, -s))
975+
965976

966977
def test_mask(self):
967978
s = Series(np.random.randn(5))

0 commit comments

Comments
 (0)