@@ -555,6 +555,81 @@ def test_idxmin_idxmax_axis1():
555
555
gb2 .idxmax (axis = 1 )
556
556
557
557
558
+ @pytest .mark .parametrize ("numeric_only" , [True , False , None ])
559
+ def test_axis1_numeric_only (request , groupby_func , numeric_only ):
560
+ if groupby_func in ("idxmax" , "idxmin" ):
561
+ pytest .skip ("idxmax and idx_min tested in test_idxmin_idxmax_axis1" )
562
+ if groupby_func in ("mad" , "tshift" ):
563
+ pytest .skip ("mad and tshift are deprecated" )
564
+ if groupby_func in ("corrwith" , "skew" ):
565
+ msg = "GH#47723 groupby.corrwith and skew do not correctly implement axis=1"
566
+ request .node .add_marker (pytest .mark .xfail (reason = msg ))
567
+
568
+ df = DataFrame (np .random .randn (10 , 4 ), columns = ["A" , "B" , "C" , "D" ])
569
+ df ["E" ] = "x"
570
+ groups = [1 , 2 , 3 , 1 , 2 , 3 , 1 , 2 , 3 , 4 ]
571
+ gb = df .groupby (groups )
572
+ method = getattr (gb , groupby_func )
573
+ args = (0 ,) if groupby_func == "fillna" else ()
574
+ kwargs = {"axis" : 1 }
575
+ if numeric_only is not None :
576
+ # when numeric_only is None we don't pass any argument
577
+ kwargs ["numeric_only" ] = numeric_only
578
+
579
+ # Functions without numeric_only and axis args
580
+ no_args = ("cumprod" , "cumsum" , "diff" , "fillna" , "pct_change" , "rank" , "shift" )
581
+ # Functions with axis args
582
+ has_axis = (
583
+ "cumprod" ,
584
+ "cumsum" ,
585
+ "diff" ,
586
+ "pct_change" ,
587
+ "rank" ,
588
+ "shift" ,
589
+ "cummax" ,
590
+ "cummin" ,
591
+ "idxmin" ,
592
+ "idxmax" ,
593
+ "fillna" ,
594
+ )
595
+ if numeric_only is not None and groupby_func in no_args :
596
+ msg = "got an unexpected keyword argument 'numeric_only'"
597
+ with pytest .raises (TypeError , match = msg ):
598
+ method (* args , ** kwargs )
599
+ elif groupby_func not in has_axis :
600
+ msg = "got an unexpected keyword argument 'axis'"
601
+ warn = FutureWarning if groupby_func == "skew" and not numeric_only else None
602
+ with tm .assert_produces_warning (warn , match = "Dropping of nuisance columns" ):
603
+ with pytest .raises (TypeError , match = msg ):
604
+ method (* args , ** kwargs )
605
+ # fillna and shift are successful even on object dtypes
606
+ elif (numeric_only is None or not numeric_only ) and groupby_func not in (
607
+ "fillna" ,
608
+ "shift" ,
609
+ ):
610
+ msgs = (
611
+ # cummax, cummin, rank
612
+ "not supported between instances of" ,
613
+ # cumprod
614
+ "can't multiply sequence by non-int of type 'float'" ,
615
+ # cumsum, diff, pct_change
616
+ "unsupported operand type" ,
617
+ )
618
+ with pytest .raises (TypeError , match = f"({ '|' .join (msgs )} )" ):
619
+ method (* args , ** kwargs )
620
+ else :
621
+ result = method (* args , ** kwargs )
622
+
623
+ df_expected = df .drop (columns = "E" ).T if numeric_only else df .T
624
+ expected = getattr (df_expected , groupby_func )(* args ).T
625
+ if groupby_func == "shift" and not numeric_only :
626
+ # shift with axis=1 leaves the leftmost column as numeric
627
+ # but transposing for expected gives us object dtype
628
+ expected = expected .astype (float )
629
+
630
+ tm .assert_equal (result , expected )
631
+
632
+
558
633
def test_groupby_cumprod ():
559
634
# GH 4095
560
635
df = DataFrame ({"key" : ["b" ] * 10 , "value" : 2 })
@@ -1321,7 +1396,7 @@ def test_deprecate_numeric_only(
1321
1396
assert "b" not in result .columns
1322
1397
elif (
1323
1398
# kernels that work on any dtype and have numeric_only arg
1324
- kernel in ("first" , "last" , "corrwith" )
1399
+ kernel in ("first" , "last" )
1325
1400
or (
1326
1401
# kernels that work on any dtype and don't have numeric_only arg
1327
1402
kernel in ("any" , "all" , "bfill" , "ffill" , "fillna" , "nth" , "nunique" )
@@ -1339,7 +1414,8 @@ def test_deprecate_numeric_only(
1339
1414
"(not allowed for this dtype"
1340
1415
"|must be a string or a number"
1341
1416
"|cannot be performed against 'object' dtypes"
1342
- "|must be a string or a real number)"
1417
+ "|must be a string or a real number"
1418
+ "|unsupported operand type)"
1343
1419
)
1344
1420
with pytest .raises (TypeError , match = msg ):
1345
1421
method (* args , ** kwargs )
0 commit comments