Skip to content

Commit 10ea36b

Browse files
committed
fix where
1 parent dae0096 commit 10ea36b

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

pandas/core/internals.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,8 @@ def _maybe_downcast(self, blocks, downcast=None):
476476
elif downcast is None and (self.is_timedelta or self.is_datetime):
477477
return blocks
478478

479+
if not isinstance(blocks, list):
480+
blocks = [blocks]
479481
return _extend_blocks([b.downcast(downcast) for b in blocks])
480482

481483
def downcast(self, dtypes=None, mgr=None):
@@ -1437,13 +1439,15 @@ def func(cond, values, other):
14371439
try:
14381440
result = func(cond, values, other)
14391441
except TypeError:
1442+
14401443
# we cannot coerce, return a compat dtype
14411444
# we are explicity ignoring raise_on_error here
14421445
block = self.coerce_to_target_dtype(other)
1443-
return block.where(orig_other, cond, align=align,
1444-
raise_on_error=raise_on_error,
1445-
try_cast=try_cast, axis=axis,
1446-
transpose=transpose)
1446+
blocks = block.where(orig_other, cond, align=align,
1447+
raise_on_error=raise_on_error,
1448+
try_cast=try_cast, axis=axis,
1449+
transpose=transpose)
1450+
return self._maybe_downcast(blocks, 'infer')
14471451

14481452
if self._can_hold_na or self.ndim == 1:
14491453

pandas/tests/frame/test_indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2712,7 +2712,7 @@ def test_where_axis(self):
27122712
result.where(mask, s, axis='index', inplace=True)
27132713
assert_frame_equal(result, expected)
27142714

2715-
expected = DataFrame([[0, np.nan], [0, np.nan]], dtype='float64')
2715+
expected = DataFrame([[0, np.nan], [0, np.nan]])
27162716
result = df.where(mask, s, axis='columns')
27172717
assert_frame_equal(result, expected)
27182718

pandas/tests/series/test_indexing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1377,7 +1377,7 @@ def test_where_dups(self):
13771377

13781378
def test_where_datetime_conversion(self):
13791379
s = Series(date_range('20130102', periods=2))
1380-
expected = Series([10, 10], dtype='object')
1380+
expected = Series([10, 10])
13811381
mask = np.array([False, False])
13821382

13831383
rs = s.where(mask, [10, 10])
@@ -1406,7 +1406,7 @@ def test_where_datetime_conversion(self):
14061406

14071407
def test_where_timedelta_coerce(self):
14081408
s = Series([1, 2], dtype='timedelta64[ns]')
1409-
expected = Series([10, 10], dtype='object')
1409+
expected = Series([10, 10])
14101410
mask = np.array([False, False])
14111411

14121412
rs = s.where(mask, [10, 10])

0 commit comments

Comments
 (0)