diff --git a/pandas/tests/groupby/test_transform.py b/pandas/tests/groupby/test_transform.py index d3972e6ba9008..f5eb58112fda3 100644 --- a/pandas/tests/groupby/test_transform.py +++ b/pandas/tests/groupby/test_transform.py @@ -1074,3 +1074,34 @@ def test_transform_lambda_with_datetimetz(): name="time", ) assert_series_equal(result, expected) + + +# @pytest.mark.parametrize( +# "input_df, expected_df", +# [ +# ( +# DataFrame( +# { +# "A": [121, 121, 121, 121, 231, 231, 676], +# "B": [1, 2, np.nan, 3, 3, np.nan, 4], +# } +# ), +# 1, +# ), +# (DataFrame({"B": [6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 4.0]}), 1), +# ], +# ) +# def test_groupby_transform_sum(input_df, expected_df): +# # GH 27905 - Test sum in groupby.transform +# df_transform = input_df.groupby("A")["B"].transform("sum") +# df_transform = df_transform.to_frame() +# assert all(df_transform == expected_df) + + +def test_groupby_transform_sum(): + input_df = DataFrame( + {"A": [121, 121, 121, 121, 231, 231, 676], "B": [1, 2, np.nan, 3, 3, np.nan, 4]} + ) + expected = Series([6.0, 6.0, 6.0, 6.0, 3.0, 3.0, 4.0]) + result = input_df.groupby("A")["B"].transform("sum") + assert_series_equal(result, expected)