@@ -2479,6 +2479,93 @@ def _check_set(df, cond, check_dtypes=True):
2479
2479
expected = df [df ['a' ] == 1 ].reindex (df .index )
2480
2480
assert_frame_equal (result , expected )
2481
2481
2482
+ def test_where_array_like (self ):
2483
+ # see gh-15414
2484
+ klasses = [list , tuple , np .array ]
2485
+
2486
+ df = DataFrame ({'a' : [1 , 2 , 3 ]})
2487
+ cond = [[False ], [True ], [True ]]
2488
+ expected = DataFrame ({'a' : [np .nan , 2 , 3 ]})
2489
+
2490
+ for klass in klasses :
2491
+ result = df .where (klass (cond ))
2492
+ assert_frame_equal (result , expected )
2493
+
2494
+ df ['b' ] = 2
2495
+ expected ['b' ] = [2 , np .nan , 2 ]
2496
+ cond = [[False , True ], [True , False ], [True , True ]]
2497
+
2498
+ for klass in klasses :
2499
+ result = df .where (klass (cond ))
2500
+ assert_frame_equal (result , expected )
2501
+
2502
+ def test_where_invalid_input (self ):
2503
+ # see gh-15414: only boolean arrays accepted
2504
+ df = DataFrame ({'a' : [1 , 2 , 3 ]})
2505
+ msg = "Boolean array expected for the condition"
2506
+
2507
+ conds = [
2508
+ [[1 ], [0 ], [1 ]],
2509
+ Series ([[2 ], [5 ], [7 ]]),
2510
+ [["True" ], ["False" ], ["True" ]],
2511
+ [[Timestamp ("2017-01-01" )],
2512
+ [pd .NaT ], [Timestamp ("2017-01-02" )]]
2513
+ ]
2514
+
2515
+ for cond in conds :
2516
+ with tm .assertRaisesRegexp (ValueError , msg ):
2517
+ df .where (cond )
2518
+
2519
+ df ['b' ] = 2
2520
+ conds = [
2521
+ [[0 , 1 ], [1 , 0 ], [1 , 1 ]],
2522
+ Series ([[0 , 2 ], [5 , 0 ], [4 , 7 ]]),
2523
+ [["False" , "True" ], ["True" , "False" ],
2524
+ ["True" , "True" ]],
2525
+ [[pd .NaT , Timestamp ("2017-01-01" )],
2526
+ [Timestamp ("2017-01-02" ), pd .NaT ],
2527
+ [Timestamp ("2017-01-03" ), Timestamp ("2017-01-03" )]]
2528
+ ]
2529
+
2530
+ for cond in conds :
2531
+ with tm .assertRaisesRegexp (ValueError , msg ):
2532
+ df .where (cond )
2533
+
2534
+ def test_where_dataframe_col_match (self ):
2535
+ df = DataFrame ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
2536
+ cond = DataFrame ([[True , False , True ], [False , False , True ]])
2537
+
2538
+ out = df .where (cond )
2539
+ expected = DataFrame ([[1.0 , np .nan , 3 ], [np .nan , np .nan , 6 ]])
2540
+ tm .assert_frame_equal (out , expected )
2541
+
2542
+ cond .columns = ["a" , "b" , "c" ] # Columns no longer match.
2543
+ msg = "Boolean array expected for the condition"
2544
+ with tm .assertRaisesRegexp (ValueError , msg ):
2545
+ df .where (cond )
2546
+
2547
+ def test_where_ndframe_align (self ):
2548
+ msg = "Array conditional must be same shape as self"
2549
+ df = DataFrame ([[1 , 2 , 3 ], [4 , 5 , 6 ]])
2550
+
2551
+ cond = [True ]
2552
+ with tm .assertRaisesRegexp (ValueError , msg ):
2553
+ df .where (cond )
2554
+
2555
+ expected = DataFrame ([[1 , 2 , 3 ], [np .nan , np .nan , np .nan ]])
2556
+
2557
+ out = df .where (Series (cond ))
2558
+ tm .assert_frame_equal (out , expected )
2559
+
2560
+ cond = np .array ([False , True , False , True ])
2561
+ with tm .assertRaisesRegexp (ValueError , msg ):
2562
+ df .where (cond )
2563
+
2564
+ expected = DataFrame ([[np .nan , np .nan , np .nan ], [4 , 5 , 6 ]])
2565
+
2566
+ out = df .where (Series (cond ))
2567
+ tm .assert_frame_equal (out , expected )
2568
+
2482
2569
def test_where_bug (self ):
2483
2570
2484
2571
# GH 2793
0 commit comments