Skip to content

Commit 94d3cee

Browse files
committed
BUG: Accept generic array-like in .where
1 parent f65a641 commit 94d3cee

File tree

9 files changed

+150
-19
lines changed

9 files changed

+150
-19
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ Bug Fixes
511511

512512

513513

514+
- Bug in ``Series.where()`` and ``DataFrame.where()`` in which array-like conditionals (e.g. lists or tuples) were being rejected (:issue:`15414`)
514515
- Bug in ``Series`` construction with a datetimetz (:issue:`14928`)
515516

516517
- Bug in compat for passing long integers to ``Timestamp.replace`` (:issue:`15030`)

pandas/core/generic.py

+26-16
Original file line numberDiff line numberDiff line change
@@ -4724,17 +4724,33 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
47244724
cond, _ = cond.align(self, join='right', broadcast_axis=1)
47254725
else:
47264726
if not hasattr(cond, 'shape'):
4727-
raise ValueError('where requires an ndarray like object for '
4728-
'its condition')
4727+
cond = np.asanyarray(cond)
47294728
if cond.shape != self.shape:
47304729
raise ValueError('Array conditional must be same shape as '
47314730
'self')
47324731
cond = self._constructor(cond, **self._construct_axes_dict())
47334732

4734-
if inplace:
4735-
cond = -(cond.fillna(True).astype(bool))
4733+
# If 'inplace' is True, we want to fill with True
4734+
# before inverting. If 'inplace' is False, we will
4735+
# fill with False and do nothing else.
4736+
#
4737+
# Conveniently, 'inplace' matches the boolean with
4738+
# which we want to fill.
4739+
cond = cond.fillna(bool(inplace))
4740+
4741+
msg = "Boolean array expected for the condition, not {dtype}"
4742+
4743+
if not isinstance(cond, pd.DataFrame):
4744+
# This is a single-dimensional object.
4745+
if not is_bool_dtype(cond):
4746+
raise ValueError(msg.format(dtype=cond.dtype))
47364747
else:
4737-
cond = cond.fillna(False).astype(bool)
4748+
for dt in cond.dtypes:
4749+
if not is_bool_dtype(dt):
4750+
raise ValueError(msg.format(dtype=dt))
4751+
4752+
cond = cond.astype(bool)
4753+
cond = -cond if inplace else cond
47384754

47394755
# try to align
47404756
try_quick = True
@@ -4883,26 +4899,20 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
48834899
48844900
Parameters
48854901
----------
4886-
cond : boolean %(klass)s, array or callable
4902+
cond : boolean %(klass)s, array-like, or callable
48874903
If cond is callable, it is computed on the %(klass)s and
4888-
should return boolean %(klass)s or array.
4889-
The callable must not change input %(klass)s
4890-
(though pandas doesn't check it).
4904+
should return boolean %(klass)s or array. The callable must
4905+
not change input %(klass)s (though pandas doesn't check it).
48914906
48924907
.. versionadded:: 0.18.1
48934908
4894-
A callable can be used as cond.
4895-
48964909
other : scalar, %(klass)s, or callable
48974910
If other is callable, it is computed on the %(klass)s and
4898-
should return scalar or %(klass)s.
4899-
The callable must not change input %(klass)s
4900-
(though pandas doesn't check it).
4911+
should return scalar or %(klass)s. The callable must not
4912+
change input %(klass)s (though pandas doesn't check it).
49014913
49024914
.. versionadded:: 0.18.1
49034915
4904-
A callable can be used as other.
4905-
49064916
inplace : boolean, default False
49074917
Whether to perform the operation in place on the data
49084918
axis : alignment axis if needed, default None

pandas/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def repeat(self, repeats, *args, **kwargs):
574574
575575
Parameters
576576
----------
577-
cond : boolean same length as self
577+
cond : boolean array-like with the same length as self
578578
other : scalar, or array-like
579579
"""
580580

pandas/tests/frame/test_indexing.py

+52
Original file line numberDiff line numberDiff line change
@@ -2479,6 +2479,58 @@ def _check_set(df, cond, check_dtypes=True):
24792479
expected = df[df['a'] == 1].reindex(df.index)
24802480
assert_frame_equal(result, expected)
24812481

2482+
def test_where_array_like(self):
2483+
# see gh-15414
2484+
klasses = [list, tuple, np.array]
2485+
2486+
df = DataFrame({'a': [1, 2, 3]})
2487+
cond = [[False], [True], [True]]
2488+
expected = DataFrame({'a': [np.nan, 2, 3]})
2489+
2490+
for klass in klasses:
2491+
result = df.where(klass(cond))
2492+
assert_frame_equal(result, expected)
2493+
2494+
df['b'] = 2
2495+
expected['b'] = [2, np.nan, 2]
2496+
cond = [[False, True], [True, False], [True, True]]
2497+
2498+
for klass in klasses:
2499+
result = df.where(klass(cond))
2500+
assert_frame_equal(result, expected)
2501+
2502+
def test_where_invalid_input(self):
2503+
# see gh-15414: only boolean arrays accepted
2504+
df = DataFrame({'a': [1, 2, 3]})
2505+
msg = "Boolean array expected for the condition"
2506+
2507+
conds = [
2508+
[[1], [0], [1]],
2509+
Series([[2], [5], [7]]),
2510+
[["True"], ["False"], ["True"]],
2511+
[[Timestamp("2017-01-01")],
2512+
[pd.NaT], [Timestamp("2017-01-02")]]
2513+
]
2514+
2515+
for cond in conds:
2516+
with tm.assertRaisesRegexp(ValueError, msg):
2517+
df.where(cond)
2518+
2519+
df['b'] = 2
2520+
conds = [
2521+
[[0, 1], [1, 0], [1, 1]],
2522+
Series([[0, 2], [5, 0], [4, 7]]),
2523+
[["False", "True"], ["True", "False"],
2524+
["True", "True"]],
2525+
[[pd.NaT, Timestamp("2017-01-01")],
2526+
[Timestamp("2017-01-02"), pd.NaT],
2527+
[Timestamp("2017-01-03"), Timestamp("2017-01-03")]]
2528+
]
2529+
2530+
for cond in conds:
2531+
with tm.assertRaisesRegexp(ValueError, msg):
2532+
df.where(cond)
2533+
24822534
def test_where_bug(self):
24832535

24842536
# GH 2793

pandas/tests/indexes/common.py

+12
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,18 @@ def test_where(self):
497497
result = i.where(cond)
498498
tm.assert_index_equal(result, expected)
499499

500+
def test_where_array_like(self):
501+
i = self.create_index()
502+
503+
_nan = i._na_value
504+
cond = [False] + [True] * (len(i) - 1)
505+
klasses = [list, tuple, np.array, pd.Series]
506+
expected = pd.Index([_nan] + i[1:].tolist(), dtype=i.dtype)
507+
508+
for klass in klasses:
509+
result = i.where(klass(cond))
510+
tm.assert_index_equal(result, expected)
511+
500512
def test_setops_errorcases(self):
501513
for name, idx in compat.iteritems(self.indices):
502514
# # non-iterable input

pandas/tests/indexes/period/test_period.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,22 @@ def test_where(self):
8989
expected = i
9090
tm.assert_index_equal(result, expected)
9191

92-
i2 = i.copy()
9392
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
9493
freq='D')
9594
result = i.where(notnull(i2))
9695
expected = i2
9796
tm.assert_index_equal(result, expected)
9897

98+
def test_where_array_like(self):
99+
i = self.create_index()
100+
cond = [False] + [True] * (len(i) - 1)
101+
klasses = [list, tuple, np.array, Series]
102+
expected = pd.PeriodIndex([pd.NaT] + i[1:].tolist(), freq='D')
103+
104+
for klass in klasses:
105+
result = i.where(klass(cond))
106+
tm.assert_index_equal(result, expected)
107+
99108
def test_where_other(self):
100109

101110
i = self.create_index()

pandas/tests/indexes/test_category.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,23 @@ def test_where(self):
240240
expected = i
241241
tm.assert_index_equal(result, expected)
242242

243-
i2 = i.copy()
244243
i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
245244
categories=i.categories)
246245
result = i.where(notnull(i2))
247246
expected = i2
248247
tm.assert_index_equal(result, expected)
249248

249+
def test_where_array_like(self):
250+
i = self.create_index()
251+
cond = [False] + [True] * (len(i) - 1)
252+
klasses = [list, tuple, np.array, pd.Series]
253+
expected = pd.CategoricalIndex([np.nan] + i[1:].tolist(),
254+
categories=i.categories)
255+
256+
for klass in klasses:
257+
result = i.where(klass(cond))
258+
tm.assert_index_equal(result, expected)
259+
250260
def test_append(self):
251261

252262
ci = self.create_index()

pandas/tests/indexes/test_multi.py

+9
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ def f():
8888

8989
self.assertRaises(NotImplementedError, f)
9090

91+
def test_where_array_like(self):
92+
i = MultiIndex.from_tuples([('A', 1), ('A', 2)])
93+
klasses = [list, tuple, np.array, pd.Series]
94+
cond = [False, True]
95+
96+
for klass in klasses:
97+
f = lambda: i.where(klass(cond))
98+
self.assertRaises(NotImplementedError, f)
99+
91100
def test_repeat(self):
92101
reps = 2
93102
numbers = [1, 2, 3]

pandas/tests/series/test_indexing.py

+28
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,34 @@ def f():
11931193
expected = Series(np.nan, index=[9])
11941194
assert_series_equal(result, expected)
11951195

1196+
def test_where_array_like(self):
1197+
# see gh-15414
1198+
s = Series([1, 2, 3])
1199+
cond = [False, True, True]
1200+
expected = Series([np.nan, 2, 3])
1201+
klasses = [list, tuple, np.array, Series]
1202+
1203+
for klass in klasses:
1204+
result = s.where(klass(cond))
1205+
assert_series_equal(result, expected)
1206+
1207+
def test_where_invalid_input(self):
1208+
# see gh-15414: only boolean arrays accepted
1209+
s = Series([1, 2, 3])
1210+
msg = "Boolean array expected for the condition"
1211+
1212+
conds = [
1213+
[1, 0, 1],
1214+
Series([2, 5, 7]),
1215+
["True", "False", "True"],
1216+
[Timestamp("2017-01-01"),
1217+
pd.NaT, Timestamp("2017-01-02")]
1218+
]
1219+
1220+
for cond in conds:
1221+
with tm.assertRaisesRegexp(ValueError, msg):
1222+
s.where(cond)
1223+
11961224
def test_where_setitem_invalid(self):
11971225

11981226
# GH 2702

0 commit comments

Comments
 (0)