@@ -2555,36 +2555,43 @@ def test_pivot_table_index_and_column_keys_with_nan(self, dropna):
2555
2555
tm .assert_frame_equal (left = result , right = expected )
2556
2556
2557
2557
@pytest .mark .parametrize (
2558
- "index, columns" ,
2559
- [("Category" , "Value" ), ("Value" , "Category" )],
2558
+ "index, columns, e_data, e_index, e_cols" ,
2559
+ [
2560
+ (
2561
+ "Category" ,
2562
+ "Value" ,
2563
+ [
2564
+ [1.0 , (nan := np .nan ), 1.0 , nan ],
2565
+ [nan , 1.0 , nan , 1.0 ],
2566
+ ],
2567
+ (cat_index := Index (data = ["A" , "B" ], name = "Category" )),
2568
+ (val_index := Index (data = [10 , 20 , 40 , 50 ], name = "Value" )),
2569
+ ),
2570
+ (
2571
+ "Value" ,
2572
+ "Category" ,
2573
+ [
2574
+ [1.0 , nan ],
2575
+ [nan , 1.0 ],
2576
+ [1.0 , nan ],
2577
+ [nan , 1.0 ],
2578
+ ],
2579
+ val_index ,
2580
+ cat_index ,
2581
+ ),
2582
+ ],
2560
2583
ids = ["values-and-columns" , "values-and-index" ],
2561
2584
)
2562
- def test_pivot_table_values_as_two_params (self , index , columns , request ):
2585
+ def test_pivot_table_values_as_two_params (
2586
+ self , index , columns , e_data , e_index , e_cols
2587
+ ):
2563
2588
# GH#57876
2564
2589
data = {"Category" : ["A" , "B" , "A" , "B" ], "Value" : [10 , 20 , 40 , 50 ]}
2565
2590
df = DataFrame (data )
2566
2591
result = df .pivot_table (
2567
2592
index = index , columns = columns , values = "Value" , aggfunc = "count"
2568
2593
)
2569
- nan = np .nan
2570
- cat_index = Index (data = ["A" , "B" ], name = "Category" )
2571
- val_index = Index (data = [10 , 20 , 40 , 50 ], name = "Value" )
2572
- if request .node .callspec .id == "values-and-columns" :
2573
- e_data = [
2574
- [1.0 , nan , 1.0 , nan ],
2575
- [nan , 1.0 , nan , 1.0 ],
2576
- ]
2577
- expected = DataFrame (data = e_data , index = cat_index , columns = val_index )
2578
-
2579
- else :
2580
- e_data = [
2581
- [1.0 , nan ],
2582
- [nan , 1.0 ],
2583
- [1.0 , nan ],
2584
- [nan , 1.0 ],
2585
- ]
2586
- expected = DataFrame (data = e_data , index = val_index , columns = cat_index )
2587
-
2594
+ expected = DataFrame (data = e_data , index = e_index , columns = e_cols )
2588
2595
tm .assert_frame_equal (result , expected )
2589
2596
2590
2597
0 commit comments