Skip to content

Commit a26065c

Browse files
committed
Handle DataFrames with multiple blocks correctly
1 parent 7e36845 commit a26065c

File tree

3 files changed

+43
-6
lines changed

3 files changed

+43
-6
lines changed

pandas/core/generic.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3413,7 +3413,8 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
34133413
axis = 0
34143414
align = True
34153415
else:
3416-
align = False
3416+
axis = self._get_axis_number(axis)
3417+
align = (axis == 1)
34173418

34183419
block_axis = self._get_block_manager_axis(axis)
34193420

pandas/core/internals.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ def putmask(self, mask, new, align=True, inplace=False,
676676
# direction, then explicitly repeat and reshape new instead
677677
if getattr(new, 'ndim', 0) >= 1:
678678
if self.ndim - 1 == new.ndim and axis == 1:
679-
new = np.repeat(new, self.shape[-1]).reshape(self.shape)
679+
new = np.repeat(new, new_values.shape[-1]).reshape(self.shape)
680680

681681
np.putmask(new_values, mask, new)
682682

@@ -2427,12 +2427,18 @@ def apply(self, f, axes=None, filter=None, do_integrity_check=False, **kwargs):
24272427
else:
24282428
kwargs['filter'] = filter_locs
24292429

2430-
if f == 'where' and kwargs.get('align', True):
2430+
if f == 'where':
24312431
align_copy = True
2432-
align_keys = ['other', 'cond']
2433-
elif f == 'putmask' and kwargs.get('align', True):
2432+
if kwargs.get('align', True):
2433+
align_keys = ['other', 'cond']
2434+
else:
2435+
align_keys = ['cond']
2436+
elif f == 'putmask':
24342437
align_copy = False
2435-
align_keys = ['new', 'mask']
2438+
if kwargs.get('align', True):
2439+
align_keys = ['new', 'mask']
2440+
else:
2441+
align_keys = ['mask']
24362442
elif f == 'eval':
24372443
align_copy = False
24382444
align_keys = ['other']

pandas/tests/test_frame.py

+30
Original file line numberDiff line numberDiff line change
@@ -9919,6 +9919,36 @@ def test_where_axis(self):
99199919
result.where(mask, s, axis='columns', inplace=True)
99209920
assert_frame_equal(result, expected)
99219921

9922+
# Multiple dtypes (=> multiple Blocks)
9923+
df = pd.concat([DataFrame(np.random.randn(10, 2)),
9924+
DataFrame(np.random.randint(0, 10, size=(10, 2)))],
9925+
ignore_index=True, axis=1)
9926+
mask = DataFrame(False, columns=df.columns, index=df.index)
9927+
s1 = Series(1, index=df.columns)
9928+
s2 = Series(2, index=df.index)
9929+
9930+
result = df.where(mask, s1, axis='columns')
9931+
expected = DataFrame(1.0, columns=df.columns, index=df.index)
9932+
expected[2] = expected[2].astype(int)
9933+
expected[3] = expected[3].astype(int)
9934+
assert_frame_equal(result, expected)
9935+
9936+
result = df.copy()
9937+
result.where(mask, s1, axis='columns', inplace=True)
9938+
assert_frame_equal(result, expected)
9939+
9940+
result = df.where(mask, s2, axis='index')
9941+
expected = DataFrame(2.0, columns=df.columns, index=df.index)
9942+
expected[2] = expected[2].astype(int)
9943+
expected[3] = expected[3].astype(int)
9944+
assert_frame_equal(result, expected)
9945+
9946+
result = df.copy()
9947+
result.where(mask, s2, axis='index', inplace=True)
9948+
assert_frame_equal(result, expected)
9949+
9950+
9951+
99229952

99239953
def test_mask(self):
99249954
df = DataFrame(np.random.randn(5, 3))

0 commit comments

Comments
 (0)