Skip to content

Commit a58e322

Browse files
committed
BUG: Accept generic array-like in .where
[ci skip]
1 parent d6f8b46 commit a58e322

File tree

9 files changed

+81
-16
lines changed

9 files changed

+81
-16
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line number | Diff line number | Diff line change
@@ -509,6 +509,7 @@ Bug Fixes
509509

510510

511511

512+
- Bug in ``Series.where()`` and ``DataFrame.where()`` where array-like conditionals were being rejected (:issue:`15414`)
512513
- Bug in ``Series`` construction with a datetimetz (:issue:`14928`)
513514

514515
- Bug in compat for passing long integers to ``Timestamp.replace`` (:issue:`15030`)

pandas/core/generic.py

+6 -13
Original file line number | Diff line number | Diff line change
@@ -4719,8 +4719,7 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
47194719
cond, _ = cond.align(self, join='right', broadcast_axis=1)
47204720
else:
47214721
if not hasattr(cond, 'shape'):
4722-
raise ValueError('where requires an ndarray like object for '
4723-
'its condition')
4722+
cond = np.asanyarray(cond)
47244723
if cond.shape != self.shape:
47254724
raise ValueError('Array conditional must be same shape as '
47264725
'self')
@@ -4878,26 +4877,20 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
48784877
48794878
Parameters
48804879
----------
4881-
cond : boolean %(klass)s, array or callable
4880+
cond : boolean %(klass)s, array-like, or callable
48824881
If cond is callable, it is computed on the %(klass)s and
4883-
should return boolean %(klass)s or array.
4884-
The callable must not change input %(klass)s
4885-
(though pandas doesn't check it).
4882+
should return boolean %(klass)s or array. The callable must
4883+
not change input %(klass)s (though pandas doesn't check it).
48864884
48874885
.. versionadded:: 0.18.1
48884886
4889-
A callable can be used as cond.
4890-
48914887
other : scalar, %(klass)s, or callable
48924888
If other is callable, it is computed on the %(klass)s and
4893-
should return scalar or %(klass)s.
4894-
The callable must not change input %(klass)s
4895-
(though pandas doesn't check it).
4889+
should return scalar or %(klass)s. The callable must not
4890+
change input %(klass)s (though pandas doesn't check it).
48964891
48974892
.. versionadded:: 0.18.1
48984893
4899-
A callable can be used as other.
4900-
49014894
inplace : boolean, default False
49024895
Whether to perform the operation in place on the data
49034896
axis : alignment axis if needed, default None

pandas/indexes/base.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -574,7 +574,7 @@ def repeat(self, repeats, *args, **kwargs):
574574
575575
Parameters
576576
----------
577-
cond : boolean same length as self
577+
cond : boolean array-like with the same length as self
578578
other : scalar, or array-like
579579
"""
580580

pandas/tests/frame/test_indexing.py

+20
Original file line number | Diff line number | Diff line change
@@ -2479,6 +2479,26 @@ def _check_set(df, cond, check_dtypes=True):
24792479
expected = df[df['a'] == 1].reindex(df.index)
24802480
assert_frame_equal(result, expected)
24812481

2482+
def test_where_array_like(self):
2483+
# see gh-15414
2484+
klasses = [list, tuple, np.array]
2485+
2486+
df = DataFrame({'a': [1, 2, 3]})
2487+
cond = [[False], [True], [True]]
2488+
expected = DataFrame({'a': [np.nan, 2, 3]})
2489+
2490+
for klass in klasses:
2491+
result = df.where(klass(cond))
2492+
assert_frame_equal(result, expected)
2493+
2494+
df['b'] = 2
2495+
expected['b'] = [2, np.nan, 2]
2496+
cond = [[False, True], [True, False], [True, True]]
2497+
2498+
for klass in klasses:
2499+
result = df.where(klass(cond))
2500+
assert_frame_equal(result, expected)
2501+
24822502
def test_where_bug(self):
24832503

24842504
# GH 2793

pandas/tests/indexes/common.py

+12
Original file line number | Diff line number | Diff line change
@@ -497,6 +497,18 @@ def test_where(self):
497497
result = i.where(cond)
498498
tm.assert_index_equal(result, expected)
499499

500+
def test_where_array_like(self):
501+
i = self.create_index()
502+
503+
_nan = i._na_value
504+
cond = [False] + [True] * (len(i) - 1)
505+
klasses = [list, tuple, np.array, pd.Series]
506+
expected = pd.Index([_nan] + i[1:].tolist(), dtype=i.dtype)
507+
508+
for klass in klasses:
509+
result = i.where(klass(cond))
510+
tm.assert_index_equal(result, expected)
511+
500512
def test_setops_errorcases(self):
501513
for name, idx in compat.iteritems(self.indices):
502514
# # non-iterable input

pandas/tests/indexes/period/test_period.py

+10 -1
Original file line number | Diff line number | Diff line change
@@ -89,13 +89,22 @@ def test_where(self):
8989
expected = i
9090
tm.assert_index_equal(result, expected)
9191

92-
i2 = i.copy()
9392
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
9493
freq='D')
9594
result = i.where(notnull(i2))
9695
expected = i2
9796
tm.assert_index_equal(result, expected)
9897

98+
def test_where_array_like(self):
99+
i = self.create_index()
100+
cond = [False] + [True] * (len(i) - 1)
101+
klasses = [list, tuple, np.array, Series]
102+
expected = pd.PeriodIndex([pd.NaT] + i[1:].tolist(), freq='D')
103+
104+
for klass in klasses:
105+
result = i.where(klass(cond))
106+
tm.assert_index_equal(result, expected)
107+
99108
def test_where_other(self):
100109

101110
i = self.create_index()

pandas/tests/indexes/test_category.py

+11 -1
Original file line number | Diff line number | Diff line change
@@ -240,13 +240,23 @@ def test_where(self):
240240
expected = i
241241
tm.assert_index_equal(result, expected)
242242

243-
i2 = i.copy()
244243
i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
245244
categories=i.categories)
246245
result = i.where(notnull(i2))
247246
expected = i2
248247
tm.assert_index_equal(result, expected)
249248

249+
def test_where_array_like(self):
250+
i = self.create_index()
251+
cond = [False] + [True] * (len(i) - 1)
252+
klasses = [list, tuple, np.array, pd.Series]
253+
expected = pd.CategoricalIndex([np.nan] + i[1:].tolist(),
254+
categories=i.categories)
255+
256+
for klass in klasses:
257+
result = i.where(klass(cond))
258+
tm.assert_index_equal(result, expected)
259+
250260
def test_append(self):
251261

252262
ci = self.create_index()

pandas/tests/indexes/test_multi.py

+9
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,15 @@ def f():
8888

8989
self.assertRaises(NotImplementedError, f)
9090

91+
def test_where_array_like(self):
92+
i = MultiIndex.from_tuples([('A', 1), ('A', 2)])
93+
klasses = [list, tuple, np.array, pd.Series]
94+
cond = [False, True]
95+
96+
for klass in klasses:
97+
f = lambda: i.where(klass(cond))
98+
self.assertRaises(NotImplementedError, f)
99+
91100
def test_repeat(self):
92101
reps = 2
93102
numbers = [1, 2, 3]

pandas/tests/series/test_indexing.py

+11
Original file line number | Diff line number | Diff line change
@@ -1193,6 +1193,17 @@ def f():
11931193
expected = Series(np.nan, index=[9])
11941194
assert_series_equal(result, expected)
11951195

1196+
def test_where_array_like(self):
1197+
# see gh-15414
1198+
s = Series([1, 2, 3])
1199+
cond = [False, True, True]
1200+
expected = Series([np.nan, 2, 3])
1201+
klasses = [list, tuple, np.array, Series]
1202+
1203+
for klass in klasses:
1204+
result = s.where(klass(cond))
1205+
assert_series_equal(result, expected)
1206+
11961207
def test_where_setitem_invalid(self):
11971208

11981209
# GH 2702

0 commit comments

Comments
 (0)