Skip to content

Commit d53765e

Browse files
committed
Handle DataFrames with multiple blocks correctly
1 parent 1d70ea3 commit d53765e

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

pandas/core/generic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3401,7 +3401,8 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
34013401
axis = 0
34023402
align = True
34033403
else:
3404-
align = False
3404+
axis = self._get_axis_number(axis)
3405+
align = (axis == 1)
34053406

34063407
block_axis = self._get_block_manager_axis(axis)
34073408

pandas/core/internals.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -671,7 +671,7 @@ def putmask(self, mask, new, align=True, inplace=False,
671671
# direction, then explicitly repeat and reshape new instead
672672
if getattr(new, 'ndim', 0) >= 1:
673673
if self.ndim - 1 == new.ndim and axis == 1:
674-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
674+
new = np.repeat(new, new_values.shape[-1]).reshape(self.shape)
675675

676676
np.putmask(new_values, mask, new)
677677

@@ -2427,12 +2427,18 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False, **kwargs):
24272427
else:
24282428
kwargs['filter'] = filter_locs
24292429

2430-
if f == 'where' and kwargs.get('align', True):
2430+
if f == 'where':
24312431
align_copy = True
2432-
align_keys = ['other', 'cond']
2433-
elif f == 'putmask' and kwargs.get('align', True):
2432+
if kwargs.get('align', True):
2433+
align_keys = ['other', 'cond']
2434+
else:
2435+
align_keys = ['cond']
2436+
elif f == 'putmask':
24342437
align_copy = False
2435-
align_keys = ['new', 'mask']
2438+
if kwargs.get('align', True):
2439+
align_keys = ['new', 'mask']
2440+
else:
2441+
align_keys = ['mask']
24362442
elif f == 'eval':
24372443
align_copy = False
24382444
align_keys = ['other']

pandas/tests/test_frame.py

+30
Original file line numberDiff line numberDiff line change
@@ -9883,6 +9883,36 @@ def test_where_axis(self):
98839883
result.where(mask, s, axis='columns', inplace=True)
98849884
assert_frame_equal(result, expected)
98859885

9886+
# Multiple dtypes (=> multiple Blocks)
9887+
df = pd.concat([DataFrame(np.random.randn(10, 2)),
9888+
DataFrame(np.random.randint(0, 10, size=(10, 2)))],
9889+
ignore_index=True, axis=1)
9890+
mask = DataFrame(False, columns=df.columns, index=df.index)
9891+
s1 = Series(1, index=df.columns)
9892+
s2 = Series(2, index=df.index)
9893+
9894+
result = df.where(mask, s1, axis='columns')
9895+
expected = DataFrame(1.0, columns=df.columns, index=df.index)
9896+
expected[2] = expected[2].astype(int)
9897+
expected[3] = expected[3].astype(int)
9898+
assert_frame_equal(result, expected)
9899+
9900+
result = df.copy()
9901+
result.where(mask, s1, axis='columns', inplace=True)
9902+
assert_frame_equal(result, expected)
9903+
9904+
result = df.where(mask, s2, axis='index')
9905+
expected = DataFrame(2.0, columns=df.columns, index=df.index)
9906+
expected[2] = expected[2].astype(int)
9907+
expected[3] = expected[3].astype(int)
9908+
assert_frame_equal(result, expected)
9909+
9910+
result = df.copy()
9911+
result.where(mask, s2, axis='index', inplace=True)
9912+
assert_frame_equal(result, expected)
9913+
9914+
9915+
98869916

98879917
def test_mask(self):
98889918
df = DataFrame(np.random.randn(5, 3))

0 commit comments

Comments
 (0)