@@ -426,3 +426,55 @@ def test_array_ufunc_series_defer():
426
426
427
427
tm .assert_series_equal (r1 , expected )
428
428
tm .assert_series_equal (r2 , expected )
429
+
430
+
431
+ def test_groupby_agg ():
432
+ # Ensure that the result of agg is inferred to be decimal dtype
433
+ # https://github.com/pandas-dev/pandas/issues/29141
434
+
435
+ data = make_data ()[:5 ]
436
+ df = pd .DataFrame (
437
+ {"id1" : [0 , 0 , 0 , 1 , 1 ], "id2" : [0 , 1 , 0 , 1 , 1 ], "decimals" : DecimalArray (data )}
438
+ )
439
+
440
+ # single key, selected column
441
+ expected = pd .Series (to_decimal ([data [0 ], data [3 ]]))
442
+ result = df .groupby ("id1" )["decimals" ].agg (lambda x : x .iloc [0 ])
443
+ tm .assert_series_equal (result , expected , check_names = False )
444
+ result = df ["decimals" ].groupby (df ["id1" ]).agg (lambda x : x .iloc [0 ])
445
+ tm .assert_series_equal (result , expected , check_names = False )
446
+
447
+ # multiple keys, selected column
448
+ expected = pd .Series (
449
+ to_decimal ([data [0 ], data [1 ], data [3 ]]),
450
+ index = pd .MultiIndex .from_tuples ([(0 , 0 ), (0 , 1 ), (1 , 1 )]),
451
+ )
452
+ result = df .groupby (["id1" , "id2" ])["decimals" ].agg (lambda x : x .iloc [0 ])
453
+ tm .assert_series_equal (result , expected , check_names = False )
454
+ result = df ["decimals" ].groupby ([df ["id1" ], df ["id2" ]]).agg (lambda x : x .iloc [0 ])
455
+ tm .assert_series_equal (result , expected , check_names = False )
456
+
457
+ # multiple columns
458
+ expected = pd .DataFrame ({"id2" : [0 , 1 ], "decimals" : to_decimal ([data [0 ], data [3 ]])})
459
+ result = df .groupby ("id1" ).agg (lambda x : x .iloc [0 ])
460
+ tm .assert_frame_equal (result , expected , check_names = False )
461
+
462
+
463
+ def test_groupby_agg_ea_method (monkeypatch ):
464
+ # Ensure that the result of agg is inferred to be decimal dtype
465
+ # https://github.com/pandas-dev/pandas/issues/29141
466
+
467
+ def DecimalArray__my_sum (self ):
468
+ return np .sum (np .array (self ))
469
+
470
+ monkeypatch .setattr (DecimalArray , "my_sum" , DecimalArray__my_sum , raising = False )
471
+
472
+ data = make_data ()[:5 ]
473
+ df = pd .DataFrame ({"id" : [0 , 0 , 0 , 1 , 1 ], "decimals" : DecimalArray (data )})
474
+ expected = pd .Series (to_decimal ([data [0 ] + data [1 ] + data [2 ], data [3 ] + data [4 ]]))
475
+
476
+ result = df .groupby ("id" )["decimals" ].agg (lambda x : x .values .my_sum ())
477
+ tm .assert_series_equal (result , expected , check_names = False )
478
+ s = pd .Series (DecimalArray (data ))
479
+ result = s .groupby (np .array ([0 , 0 , 0 , 1 , 1 ])).agg (lambda x : x .values .my_sum ())
480
+ tm .assert_series_equal (result , expected , check_names = False )
0 commit comments