Skip to content

Commit 52c60d6

Browse files
committed
BUG: Accept generic array-like in .where
1 parent 9b5d848 commit 52c60d6

File tree

9 files changed

+102
-19
lines changed

9 files changed

+102
-19
lines changed

doc/source/whatsnew/v0.20.0.txt

+1
Original file line numberDiff line numberDiff line change
@@ -510,6 +510,7 @@ Bug Fixes
510510

511511

512512

513+
- Bug in ``Series.where()`` and ``DataFrame.where()`` in which array-like conditionals were being rejected (:issue:`15414`)
513514
- Bug in ``Series`` construction with a datetimetz (:issue:`14928`)
514515

515516
- Bug in compat for passing long integers to ``Timestamp.replace`` (:issue:`15030`)

pandas/core/generic.py

+27-16
Original file line numberDiff line numberDiff line change
@@ -4717,24 +4717,41 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
47174717
applied as a function even if callable. Used in __setitem__.
47184718
"""
47194719
inplace = validate_bool_kwarg(inplace, 'inplace')
4720+
inplace = True if inplace else False
47204721

47214722
cond = com._apply_if_callable(cond, self)
47224723

47234724
if isinstance(cond, NDFrame):
47244725
cond, _ = cond.align(self, join='right', broadcast_axis=1)
47254726
else:
47264727
if not hasattr(cond, 'shape'):
4727-
raise ValueError('where requires an ndarray like object for '
4728-
'its condition')
4728+
cond = np.asanyarray(cond)
47294729
if cond.shape != self.shape:
47304730
raise ValueError('Array conditional must be same shape as '
47314731
'self')
47324732
cond = self._constructor(cond, **self._construct_axes_dict())
47334733

4734-
if inplace:
4735-
cond = -(cond.fillna(True).astype(bool))
4734+
# If 'inplace' is True, we want to fill with True
4735+
# before inverting. If 'inplace' is False, we will
4736+
# fill with False and do nothing else.
4737+
#
4738+
# Conveniently, 'inplace' matches the boolean with
4739+
# which we want to fill.
4740+
cond = cond.fillna(inplace)
4741+
4742+
msg = "Boolean array expected for the condition, not {dtype}"
4743+
4744+
if not isinstance(cond, pd.DataFrame):
4745+
# This is a single-dimensional object.
4746+
if not is_bool_dtype(cond):
4747+
raise ValueError(msg.format(dtype=cond.dtype))
47364748
else:
4737-
cond = cond.fillna(False).astype(bool)
4749+
for dt in cond.dtypes:
4750+
if not is_bool_dtype(dt):
4751+
raise ValueError(msg.format(dtype=dt))
4752+
4753+
cond = cond.astype(bool)
4754+
cond = -cond if inplace else cond
47384755

47394756
# try to align
47404757
try_quick = True
@@ -4883,26 +4900,20 @@ def _where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
48834900
48844901
Parameters
48854902
----------
4886-
cond : boolean %(klass)s, array or callable
4903+
cond : boolean %(klass)s, array-like, or callable
48874904
If cond is callable, it is computed on the %(klass)s and
4888-
should return boolean %(klass)s or array.
4889-
The callable must not change input %(klass)s
4890-
(though pandas doesn't check it).
4905+
should return boolean %(klass)s or array. The callable must
4906+
not change input %(klass)s (though pandas doesn't check it).
48914907
48924908
.. versionadded:: 0.18.1
48934909
4894-
A callable can be used as cond.
4895-
48964910
other : scalar, %(klass)s, or callable
48974911
If other is callable, it is computed on the %(klass)s and
4898-
should return scalar or %(klass)s.
4899-
The callable must not change input %(klass)s
4900-
(though pandas doesn't check it).
4912+
should return scalar or %(klass)s. The callable must not
4913+
change input %(klass)s (though pandas doesn't check it).
49014914
49024915
.. versionadded:: 0.18.1
49034916
4904-
A callable can be used as other.
4905-
49064917
inplace : boolean, default False
49074918
Whether to perform the operation in place on the data
49084919
axis : alignment axis if needed, default None

pandas/indexes/base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def repeat(self, repeats, *args, **kwargs):
574574
575575
Parameters
576576
----------
577-
cond : boolean same length as self
577+
cond : boolean array-like with the same length as self
578578
other : scalar, or array-like
579579
"""
580580

pandas/tests/frame/test_indexing.py

+20
Original file line numberDiff line numberDiff line change
@@ -2479,6 +2479,26 @@ def _check_set(df, cond, check_dtypes=True):
24792479
expected = df[df['a'] == 1].reindex(df.index)
24802480
assert_frame_equal(result, expected)
24812481

2482+
def test_where_array_like(self):
2483+
# see gh-15414
2484+
klasses = [list, tuple, np.array]
2485+
2486+
df = DataFrame({'a': [1, 2, 3]})
2487+
cond = [[False], [True], [True]]
2488+
expected = DataFrame({'a': [np.nan, 2, 3]})
2489+
2490+
for klass in klasses:
2491+
result = df.where(klass(cond))
2492+
assert_frame_equal(result, expected)
2493+
2494+
df['b'] = 2
2495+
expected['b'] = [2, np.nan, 2]
2496+
cond = [[False, True], [True, False], [True, True]]
2497+
2498+
for klass in klasses:
2499+
result = df.where(klass(cond))
2500+
assert_frame_equal(result, expected)
2501+
24822502
def test_where_bug(self):
24832503

24842504
# GH 2793

pandas/tests/indexes/common.py

+12
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,18 @@ def test_where(self):
497497
result = i.where(cond)
498498
tm.assert_index_equal(result, expected)
499499

500+
def test_where_array_like(self):
501+
i = self.create_index()
502+
503+
_nan = i._na_value
504+
cond = [False] + [True] * (len(i) - 1)
505+
klasses = [list, tuple, np.array, pd.Series]
506+
expected = pd.Index([_nan] + i[1:].tolist(), dtype=i.dtype)
507+
508+
for klass in klasses:
509+
result = i.where(klass(cond))
510+
tm.assert_index_equal(result, expected)
511+
500512
def test_setops_errorcases(self):
501513
for name, idx in compat.iteritems(self.indices):
502514
# # non-iterable input

pandas/tests/indexes/period/test_period.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,22 @@ def test_where(self):
8989
expected = i
9090
tm.assert_index_equal(result, expected)
9191

92-
i2 = i.copy()
9392
i2 = pd.PeriodIndex([pd.NaT, pd.NaT] + i[2:].tolist(),
9493
freq='D')
9594
result = i.where(notnull(i2))
9695
expected = i2
9796
tm.assert_index_equal(result, expected)
9897

98+
def test_where_array_like(self):
99+
i = self.create_index()
100+
cond = [False] + [True] * (len(i) - 1)
101+
klasses = [list, tuple, np.array, Series]
102+
expected = pd.PeriodIndex([pd.NaT] + i[1:].tolist(), freq='D')
103+
104+
for klass in klasses:
105+
result = i.where(klass(cond))
106+
tm.assert_index_equal(result, expected)
107+
99108
def test_where_other(self):
100109

101110
i = self.create_index()

pandas/tests/indexes/test_category.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -240,13 +240,23 @@ def test_where(self):
240240
expected = i
241241
tm.assert_index_equal(result, expected)
242242

243-
i2 = i.copy()
244243
i2 = pd.CategoricalIndex([np.nan, np.nan] + i[2:].tolist(),
245244
categories=i.categories)
246245
result = i.where(notnull(i2))
247246
expected = i2
248247
tm.assert_index_equal(result, expected)
249248

249+
def test_where_array_like(self):
250+
i = self.create_index()
251+
cond = [False] + [True] * (len(i) - 1)
252+
klasses = [list, tuple, np.array, pd.Series]
253+
expected = pd.CategoricalIndex([np.nan] + i[1:].tolist(),
254+
categories=i.categories)
255+
256+
for klass in klasses:
257+
result = i.where(klass(cond))
258+
tm.assert_index_equal(result, expected)
259+
250260
def test_append(self):
251261

252262
ci = self.create_index()

pandas/tests/indexes/test_multi.py

+9
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ def f():
8888

8989
self.assertRaises(NotImplementedError, f)
9090

91+
def test_where_array_like(self):
92+
i = MultiIndex.from_tuples([('A', 1), ('A', 2)])
93+
klasses = [list, tuple, np.array, pd.Series]
94+
cond = [False, True]
95+
96+
for klass in klasses:
97+
f = lambda: i.where(klass(cond))
98+
self.assertRaises(NotImplementedError, f)
99+
91100
def test_repeat(self):
92101
reps = 2
93102
numbers = [1, 2, 3]

pandas/tests/series/test_indexing.py

+11
Original file line numberDiff line numberDiff line change
@@ -1193,6 +1193,17 @@ def f():
11931193
expected = Series(np.nan, index=[9])
11941194
assert_series_equal(result, expected)
11951195

1196+
def test_where_array_like(self):
1197+
# see gh-15414
1198+
s = Series([1, 2, 3])
1199+
cond = [False, True, True]
1200+
expected = Series([np.nan, 2, 3])
1201+
klasses = [list, tuple, np.array, Series]
1202+
1203+
for klass in klasses:
1204+
result = s.where(klass(cond))
1205+
assert_series_equal(result, expected)
1206+
11961207
def test_where_setitem_invalid(self):
11971208

11981209
# GH 2702

0 commit comments

Comments
 (0)