@@ -9883,6 +9883,36 @@ def test_where_axis(self):
9883
9883
result.where(mask, s, axis='columns', inplace=True)
9884
9884
assert_frame_equal(result, expected)
9885
9885
9886
+ # Multiple dtypes (=> multiple Blocks)
9887
+ df = pd.concat([DataFrame(np.random.randn(10, 2)),
9888
+ DataFrame(np.random.randint(0, 10, size=(10, 2)))],
9889
+ ignore_index=True, axis=1)
9890
+ mask = DataFrame(False, columns=df.columns, index=df.index)
9891
+ s1 = Series(1, index=df.columns)
9892
+ s2 = Series(2, index=df.index)
9893
+
9894
+ result = df.where(mask, s1, axis='columns')
9895
+ expected = DataFrame(1.0, columns=df.columns, index=df.index)
9896
+ expected[2] = expected[2].astype(int)
9897
+ expected[3] = expected[3].astype(int)
9898
+ assert_frame_equal(result, expected)
9899
+
9900
+ result = df.copy()
9901
+ result.where(mask, s1, axis='columns', inplace=True)
9902
+ assert_frame_equal(result, expected)
9903
+
9904
+ result = df.where(mask, s2, axis='index')
9905
+ expected = DataFrame(2.0, columns=df.columns, index=df.index)
9906
+ expected[2] = expected[2].astype(int)
9907
+ expected[3] = expected[3].astype(int)
9908
+ assert_frame_equal(result, expected)
9909
+
9910
+ result = df.copy()
9911
+ result.where(mask, s2, axis='index', inplace=True)
9912
+ assert_frame_equal(result, expected)
9913
+
9914
+
9915
+
9886
9916
9887
9917
def test_mask(self):
9888
9918
df = DataFrame(np.random.randn(5, 3))
0 commit comments