@@ -2479,6 +2479,95 @@ def _check_set(df, cond, check_dtypes=True):
2479
2479
expected = df [df ['a' ] == 1 ].reindex (df .index )
2480
2480
assert_frame_equal (result , expected )
2481
2481
2482
+ def test_where_array_like (self ):
2483
+ # see gh-15414
2484
+ klasses = [list , tuple , np .array ]
2485
+
2486
+ df = DataFrame ({'a' : [1 , 2 , 3 ]})
2487
+ cond = [[False ], [True ], [True ]]
2488
+ expected = DataFrame ({'a' : [np .nan , 2 , 3 ]})
2489
+
2490
+ for klass in klasses :
2491
+ result = df .where (klass (cond ))
2492
+ assert_frame_equal (result , expected )
2493
+
2494
+ df ['b' ] = 2
2495
+ expected ['b' ] = [2 , np .nan , 2 ]
2496
+ cond = [[False , True ], [True , False ], [True , True ]]
2497
+
2498
+ for klass in klasses :
2499
+ result = df .where (klass (cond ))
2500
+ assert_frame_equal (result , expected )
2501
+
2502
def test_where_invalid_input(self):
    """Check that ``DataFrame.where`` rejects non-boolean conditions.

    Integer, string and timestamp conditions (as nested lists, Series or
    DataFrame) must each raise ``ValueError`` (see gh-15414: only boolean
    arrays accepted).
    """
    frame = DataFrame({'a': [1, 2, 3]})
    msg = "Boolean array expected for the condition"

    def check_all_rejected(bad_conds):
        # Every candidate must raise ValueError with the expected message.
        for bad in bad_conds:
            with tm.assertRaisesRegexp(ValueError, msg):
                frame.where(bad)

    # Conditions for the single-column frame.
    check_all_rejected([
        [[1], [0], [1]],
        Series([[2], [5], [7]]),
        DataFrame({'a': [2, 5, 7]}),
        [["True"], ["False"], ["True"]],
        [[Timestamp("2017-01-01")],
         [pd.NaT], [Timestamp("2017-01-02")]],
    ])

    # Conditions for the frame once a second column has been added.
    frame['b'] = 2
    check_all_rejected([
        [[0, 1], [1, 0], [1, 1]],
        Series([[0, 2], [5, 0], [4, 7]]),
        [["False", "True"], ["True", "False"],
         ["True", "True"]],
        DataFrame({'a': [2, 5, 7], 'b': [4, 8, 9]}),
        [[pd.NaT, Timestamp("2017-01-01")],
         [Timestamp("2017-01-02"), pd.NaT],
         [Timestamp("2017-01-03"), Timestamp("2017-01-03")]],
    ])
2535
+
2536
def test_where_dataframe_col_match(self):
    """Check ``DataFrame.where`` with a DataFrame condition.

    A boolean condition whose columns equal the frame's masks as
    expected; after the condition's columns are relabelled the same call
    must raise ``ValueError``.
    """
    frame = DataFrame([[1, 2, 3], [4, 5, 6]])
    mask = DataFrame([[True, False, True], [False, False, True]])

    want = DataFrame([[1.0, np.nan, 3], [np.nan, np.nan, 6]])
    tm.assert_frame_equal(frame.where(mask), want)

    # Columns no longer match.
    mask.columns = ["a", "b", "c"]
    msg = "Boolean array expected for the condition"
    with tm.assertRaisesRegexp(ValueError, msg):
        frame.where(mask)
2548
+
2549
def test_where_ndframe_align(self):
    """Check raw versus Series conditions whose length differs from the frame.

    For each scenario the raw condition (list or ndarray) must raise
    ``ValueError`` with the shape-mismatch message, while the same values
    wrapped in a ``Series`` are accepted and give the expected frame.
    """
    msg = "Array conditional must be same shape as self"
    frame = DataFrame([[1, 2, 3], [4, 5, 6]])

    # (raw condition, frame expected once the condition is a Series)
    scenarios = [
        ([True],
         DataFrame([[1, 2, 3], [np.nan, np.nan, np.nan]])),
        (np.array([False, True, False, True]),
         DataFrame([[np.nan, np.nan, np.nan], [4, 5, 6]])),
    ]

    for raw, want in scenarios:
        with tm.assertRaisesRegexp(ValueError, msg):
            frame.where(raw)

        tm.assert_frame_equal(frame.where(Series(raw)), want)
2570
+
2482
2571
def test_where_bug (self ):
2483
2572
2484
2573
# GH 2793
0 commit comments