Skip to content

Commit 43f42e8

Browse files
committed
Fix alignment in the case DataFrame vs DataFrame case. Add tests
1 parent a26065c commit 43f42e8

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

pandas/core/generic.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -3411,10 +3411,11 @@ def where(self, cond, other=np.nan, inplace=False, axis=None, level=None,
34113411

34123412
if axis is None:
34133413
axis = 0
3414+
3415+
if self.ndim == getattr(other, 'ndim', 0):
34143416
align = True
34153417
else:
3416-
axis = self._get_axis_number(axis)
3417-
align = (axis == 1)
3418+
align = (self._get_axis_number(axis) == 1)
34183419

34193420
block_axis = self._get_block_manager_axis(axis)
34203421

pandas/tests/test_frame.py

+28
Original file line numberDiff line numberDiff line change
@@ -9947,8 +9947,36 @@ def test_where_axis(self):
99479947
result.where(mask, s2, axis='index', inplace=True)
99489948
assert_frame_equal(result, expected)
99499949

9950+
# DataFrame vs DataFrame
9951+
d1 = df.copy().drop(1, axis=0)
9952+
expected = df.copy()
9953+
expected.loc[1, :] = np.nan
99509954

9955+
result = df.where(mask, d1)
9956+
assert_frame_equal(result, expected)
9957+
result = df.where(mask, d1, axis='index')
9958+
assert_frame_equal(result, expected)
9959+
result = df.copy()
9960+
result.where(mask, d1, inplace=True)
9961+
assert_frame_equal(result, expected)
9962+
result = df.copy()
9963+
result.where(mask, d1, inplace=True, axis='index')
9964+
assert_frame_equal(result, expected)
99519965

9966+
d2 = df.copy().drop(1, axis=1)
9967+
expected = df.copy()
9968+
expected.loc[:, 1] = np.nan
9969+
9970+
result = df.where(mask, d2)
9971+
assert_frame_equal(result, expected)
9972+
result = df.where(mask, d2, axis='columns')
9973+
assert_frame_equal(result, expected)
9974+
result = df.copy()
9975+
result.where(mask, d2, inplace=True)
9976+
assert_frame_equal(result, expected)
9977+
result = df.copy()
9978+
result.where(mask, d2, inplace=True, axis='columns')
9979+
assert_frame_equal(result, expected)
99529980

99539981
def test_mask(self):
99549982
df = DataFrame(np.random.randn(5, 3))

0 commit comments

Comments
 (0)