@@ -9919,6 +9919,36 @@ def test_where_axis(self):
9919
9919
result.where(mask, s, axis='columns', inplace=True)
9920
9920
assert_frame_equal(result, expected)
9921
9921
9922
+ # Multiple dtypes (=> multiple Blocks)
9923
+ df = pd.concat([DataFrame(np.random.randn(10, 2)),
9924
+ DataFrame(np.random.randint(0, 10, size=(10, 2)))],
9925
+ ignore_index=True, axis=1)
9926
+ mask = DataFrame(False, columns=df.columns, index=df.index)
9927
+ s1 = Series(1, index=df.columns)
9928
+ s2 = Series(2, index=df.index)
9929
+
9930
+ result = df.where(mask, s1, axis='columns')
9931
+ expected = DataFrame(1.0, columns=df.columns, index=df.index)
9932
+ expected[2] = expected[2].astype(int)
9933
+ expected[3] = expected[3].astype(int)
9934
+ assert_frame_equal(result, expected)
9935
+
9936
+ result = df.copy()
9937
+ result.where(mask, s1, axis='columns', inplace=True)
9938
+ assert_frame_equal(result, expected)
9939
+
9940
+ result = df.where(mask, s2, axis='index')
9941
+ expected = DataFrame(2.0, columns=df.columns, index=df.index)
9942
+ expected[2] = expected[2].astype(int)
9943
+ expected[3] = expected[3].astype(int)
9944
+ assert_frame_equal(result, expected)
9945
+
9946
+ result = df.copy()
9947
+ result.where(mask, s2, axis='index', inplace=True)
9948
+ assert_frame_equal(result, expected)
9949
+
9950
+
9951
+
9922
9952
9923
9953
def test_mask(self):
9924
9954
df = DataFrame(np.random.randn(5, 3))
0 commit comments