Skip to content

Commit d2b951d

Browse files
committed
BUG: Accept generic array-like in .where
1 parent b94186d commit d2b951d

File tree

9 files changed

+208
-19
lines changed

9 files changed

+208
-19
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line number | Diff line number | Diff line change
@@ -548,6 +548,7 @@ Bug Fixes
548548

549549

550550

551+
- Bug in ``Series.where()`` and ``DataFrame.where()`` where array-like conditionals were being rejected (:issue:`15414`)
551552
- Bug in ``Series`` construction with a datetimetz (:issue:`14928`)
552553

553554
- Bug in compat for passing long integers to ``Timestamp.replace`` (:issue:`15030`)

pandas/core/generic.py

+21-16
Original file line number | Diff line number | Diff line change
@@ -4732,17 +4732,28 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
47324732
cond, _ = cond.align(self, join='right', broadcast_axis=1)
47334733
else:
47344734
if not hasattr(cond, 'shape'):
4735-
raise ValueError('where requires an ndarray like object for '
4736-
'its condition')
4735+
cond = np.asanyarray(cond)
47374736
if cond.shape != self.shape:
47384737
raise ValueError('Array conditional must be same shape as '
47394738
'self')
47404739
cond = self._constructor(cond, **self._construct_axes_dict())
47414740

4742-
if inplace:
4743-
cond = -(cond.fillna(True).astype(bool))
4741+
fill_value = True if inplace else False
4742+
cond = cond.fillna(fill_value)
4743+
4744+
msg = "Boolean array expected for the condition, not {dtype}"
4745+
4746+
if not isinstance(cond, pd.DataFrame):
4747+
# This is a single-dimensional object.
4748+
if not is_bool_dtype(cond):
4749+
raise ValueError(msg.format(dtype=cond.dtype))
47444750
else:
4745-
cond = cond.fillna(False).astype(bool)
4751+
for dt in cond.dtypes:
4752+
if not is_bool_dtype(dt):
4753+
raise ValueError(msg.format(dtype=dt))
4754+
4755+
cond = cond.astype(bool, copy=False)
4756+
cond = -cond if inplace else cond
47464757

47474758
# try to align
47484759
try_quick = True
@@ -4891,26 +4902,20 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
48914902
48924903
Parameters
48934904
----------
4894-
cond : boolean %(klass)s, array or callable
4905+
cond : boolean %(klass)s, array-like, or callable
48954906
If cond is callable, it is computed on the %(klass)s and
4896-
should return boolean %(klass)s or array.
4897-
The callable must not change input %(klass)s
4898-
(though pandas doesn't check it).
4907+
should return boolean %(klass)s or array. The callable must
4908+
not change input %(klass)s (though pandas doesn't check it).
48994909
49004910
.. versionadded:: 0.18.1
49014911
4902-
A callable can be used as cond.
4903-
49044912
other : scalar, %(klass)s, or callable
49054913
If other is callable, it is computed on the %(klass)s and
4906-
should return scalar or %(klass)s.
4907-
The callable must not change input %(klass)s
4908-
(though pandas doesn't check it).
4914+
should return scalar or %(klass)s. The callable must not
4915+
change input %(klass)s (though pandas doesn't check it).
49094916
49104917
.. versionadded:: 0.18.1
49114918
4912-
A callable can be used as other.
4913-
49144919
inplace : boolean, default False
49154920
Whether to perform the operation in place on the data
49164921
axis : alignment axis if needed, default None

pandas/indexes/base.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -574,7 +574,7 @@ def repeat(self, repeats, *args, **kwargs):
574574
575575
Parameters
576576
----------
577-
cond : boolean same length as self
577+
cond : boolean array-like with the same length as self
578578
other : scalar, or array-like
579579
"""
580580

pandas/tests/frame/test_indexing.py

+89
Original file line number | Diff line number | Diff line change
@@ -2479,6 +2479,95 @@ def _check_set(df, cond, check_dtypes=True):
24792479
expected = df[df['a'] == 1].reindex(df.index)
24802480
assert_frame_equal(result, expected)
24812481

2482+
def test_where_array_like(self):
2483+
# see gh-15414
2484+
klasses = [list, tuple, np.array]
2485+
2486+
df = DataFrame({'a': [1, 2, 3]})
2487+
cond = [[False], [True], [True]]
2488+
expected = DataFrame({'a': [np.nan, 2, 3]})
2489+
2490+
for klass in klasses:
2491+
result = df.where(klass(cond))
2492+
assert_frame_equal(result, expected)
2493+
2494+
df['b'] = 2
2495+
expected['b'] = [2, np.nan, 2]
2496+
cond = [[False, True], [True, False], [True, True]]
2497+
2498+
for klass in klasses:
2499+
result = df.where(klass(cond))
2500+
assert_frame_equal(result, expected)
2501+
2502+
def test_where_invalid_input(self):
2503+
# see gh-15414: only boolean arrays accepted
2504+
df = DataFrame({'a': [1, 2, 3]})
2505+
msg = "Boolean array expected for the condition"
2506+
2507+
conds = [
2508+
[[1], [0], [1]],
2509+
Series([[2], [5], [7]]),
2510+
DataFrame({'a': [2, 5, 7]}),
2511+
[["True"], ["False"], ["True"]],
2512+
[[Timestamp("2017-01-01")],
2513+
[pd.NaT], [Timestamp("2017-01-02")]]
2514+
]
2515+
2516+
for cond in conds:
2517+
with tm.assertRaisesRegexp(ValueError, msg):
2518+
df.where(cond)
2519+
2520+
df['b'] = 2
2521+
conds = [
2522+
[[0, 1], [1, 0], [1, 1]],
2523+
Series([[0, 2], [5, 0], [4, 7]]),
2524+
[["False", "True"], ["True", "False"],
2525+
["True", "True"]],
2526+
DataFrame({'a': [2, 5, 7], 'b': [4, 8, 9]}),
2527+
[[pd.NaT, Timestamp("2017-01-01")],
2528+
[Timestamp("2017-01-02"), pd.NaT],
2529+
[Timestamp("2017-01-03"), Timestamp("2017-01-03")]]
2530+
]
2531+
2532+
for cond in conds:
2533+
with tm.assertRaisesRegexp(ValueError, msg):
2534+
df.where(cond)
2535+
2536+
def test_where_dataframe_col_match(self):
2537+
df = DataFrame([[1, 2, 3], [4, 5, 6]])
2538+
cond = DataFrame([[True, False, True], [False, False, True]])
2539+
2540+
out = df.where(cond)
2541+
expected = DataFrame([[1.0, np.nan, 3], [np.nan, np.nan, 6]])
2542+
tm.assert_frame_equal(out, expected)
2543+
2544+
cond.columns = ["a", "b", "c"] # Columns no longer match.
2545+
msg = "Boolean array expected for the condition"
2546+
with tm.assertRaisesRegexp(ValueError, msg):
2547+
df.where(cond)
2548+
2549+
def test_where_ndframe_align(self):
2550+
msg = "Array conditional must be same shape as self"
2551+
df = DataFrame([[1, 2, 3], [4, 5, 6]])
2552+
2553+
cond = [True]
2554+
with tm.assertRaisesRegexp(ValueError, msg):
2555+
df.where(cond)
2556+
2557+
expected = DataFrame([[1, 2, 3], [np.nan, np.nan, np.nan]])
2558+
2559+
out = df.where(Series(cond))
2560+
tm.assert_frame_equal(out, expected)
2561+
2562+
cond = np.array([False, True, False, True])
2563+
with tm.assertRaisesRegexp(ValueError, msg):
2564+
df.where(cond)
2565+
2566+
expected = DataFrame([[np.nan, np.nan, np.nan], [4, 5, 6]])
2567+
2568+
out = df.where(Series(cond))
2569+
tm.assert_frame_equal(out, expected)
2570+
24822571
def test_where_bug(self):
24832572

24842573
# GH 2793

pandas/tests/indexes/common.py

+12
Original file line number | Diff line number | Diff line change
@@ -497,6 +497,18 @@ def test_where(self):
497497
result = i.where(cond)
498498
tm.assert_index_equal(result, expected)
499499

500+
def test_where_array_like(self):
501+
i = self.create_index()
502+
503+
_nan = i._na_value
504+
cond = [False] + [True] * (len(i) - 1)
505+
klasses = [list, tuple, np.array, pd.Series]
506+
expected = pd.Index([_nan] + i[1:].tolist(), dtype=i.dtype)
507+
508+
for klass in klasses:
509+
result = i.where(klass(cond))
510+
tm.assert_index_equal(result, expected)
511+
500512
def test_setops_errorcases(self):
501513
for name, idx in compat.iteritems(self.indices):
502514
# # non-iterable input

pandas/tests/indexes/period/test_period.py

+10-1
Original file line number | Diff line number | Diff line change
@@ -89,13 +89,22 @@ def test_where(self):
8989
expected = i
9090
tm.assert_index_equal(result, expected)
9191

92-
i2 = i.copy()
9392
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
9493
freq='D')
9594
result = i.where(notnull(i2))
9695
expected = i2
9796
tm.assert_index_equal(result, expected)
9897

98+
def test_where_array_like(self):
99+
i = self.create_index()
100+
cond = [False] + [True] * (len(i) - 1)
101+
klasses = [list, tuple, np.array, Series]
102+
expected = pd.PeriodIndex([pd.NaT] + i[1:].tolist(), freq='D')
103+
104+
for klass in klasses:
105+
result = i.where(klass(cond))
106+
tm.assert_index_equal(result, expected)
107+
99108
def test_where_other(self):
100109

101110
i = self.create_index()

pandas/tests/indexes/test_category.py

+11-1
Original file line number | Diff line number | Diff line change
@@ -240,13 +240,23 @@ def test_where(self):
240240
expected = i
241241
tm.assert_index_equal(result, expected)
242242

243-
i2 = i.copy()
244243
i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
245244
categories=i.categories)
246245
result = i.where(notnull(i2))
247246
expected = i2
248247
tm.assert_index_equal(result, expected)
249248

249+
def test_where_array_like(self):
250+
i = self.create_index()
251+
cond = [False] + [True] * (len(i) - 1)
252+
klasses = [list, tuple, np.array, pd.Series]
253+
expected = pd.CategoricalIndex([np.nan] + i[1:].tolist(),
254+
categories=i.categories)
255+
256+
for klass in klasses:
257+
result = i.where(klass(cond))
258+
tm.assert_index_equal(result, expected)
259+
250260
def test_append(self):
251261

252262
ci = self.create_index()

pandas/tests/indexes/test_multi.py

+9
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,15 @@ def f():
8888

8989
self.assertRaises(NotImplementedError, f)
9090

91+
def test_where_array_like(self):
92+
i = MultiIndex.from_tuples([('A', 1), ('A', 2)])
93+
klasses = [list, tuple, np.array, pd.Series]
94+
cond = [False, True]
95+
96+
for klass in klasses:
97+
f = lambda: i.where(klass(cond))
98+
self.assertRaises(NotImplementedError, f)
99+
91100
def test_repeat(self):
92101
reps = 2
93102
numbers = [1, 2, 3]

pandas/tests/series/test_indexing.py

+54
Original file line number | Diff line number | Diff line change
@@ -1193,6 +1193,60 @@ def f():
11931193
expected = Series(np.nan, index=[9])
11941194
assert_series_equal(result, expected)
11951195

1196+
def test_where_array_like(self):
1197+
# see gh-15414
1198+
s = Series([1, 2, 3])
1199+
cond = [False, True, True]
1200+
expected = Series([np.nan, 2, 3])
1201+
klasses = [list, tuple, np.array, Series]
1202+
1203+
for klass in klasses:
1204+
result = s.where(klass(cond))
1205+
assert_series_equal(result, expected)
1206+
1207+
def test_where_invalid_input(self):
1208+
# see gh-15414: only boolean arrays accepted
1209+
s = Series([1, 2, 3])
1210+
msg = "Boolean array expected for the condition"
1211+
1212+
conds = [
1213+
[1, 0, 1],
1214+
Series([2, 5, 7]),
1215+
["True", "False", "True"],
1216+
[Timestamp("2017-01-01"),
1217+
pd.NaT, Timestamp("2017-01-02")]
1218+
]
1219+
1220+
for cond in conds:
1221+
with tm.assertRaisesRegexp(ValueError, msg):
1222+
s.where(cond)
1223+
1224+
msg = "Array conditional must be same shape as self"
1225+
with tm.assertRaisesRegexp(ValueError, msg):
1226+
s.where([True])
1227+
1228+
def test_where_ndframe_align(self):
1229+
msg = "Array conditional must be same shape as self"
1230+
s = Series([1, 2, 3])
1231+
1232+
cond = [True]
1233+
with tm.assertRaisesRegexp(ValueError, msg):
1234+
s.where(cond)
1235+
1236+
expected = Series([1, np.nan, np.nan])
1237+
1238+
out = s.where(Series(cond))
1239+
tm.assert_series_equal(out, expected)
1240+
1241+
cond = np.array([False, True, False, True])
1242+
with tm.assertRaisesRegexp(ValueError, msg):
1243+
s.where(cond)
1244+
1245+
expected = Series([np.nan, 2, np.nan])
1246+
1247+
out = s.where(Series(cond))
1248+
tm.assert_series_equal(out, expected)
1249+
11961250
def test_where_setitem_invalid(self):
11971251

11981252
# GH 2702

0 commit comments

Comments
 (0)