@@ -1112,10 +1112,15 @@ def test_agg_function_input(self, cython_table_items, inputs, axis):
1112
1112
with pytest .raises (expected ):
1113
1113
# e.g. DataFrame(['a b'.split()]).cumprod() will raise
1114
1114
df .agg (np_func , axis = axis )
1115
+ with pytest .raises (expected ):
1115
1116
df .agg (str_func , axis = axis )
1116
- elif str_func in ('cumprod' , 'cumsum' ):
1117
- tm .assert_frame_equal (df .agg (np_func , axis = axis ), expected )
1118
- tm .assert_frame_equal (df .agg (str_func , axis = axis ), expected )
1117
+ return
1118
+
1119
+ result = df .agg (np_func , axis = axis )
1120
+ result_str_func = df .agg (str_func , axis = axis )
1121
+ if str_func in ('cumprod' , 'cumsum' ):
1122
+ tm .assert_frame_equal (result , expected )
1123
+ tm .assert_frame_equal (result_str_func , expected )
1119
1124
else :
1120
- tm .assert_series_equal (df . agg ( np_func , axis = axis ) , expected )
1121
- tm .assert_series_equal (df . agg ( str_func , axis = axis ) , expected )
1125
+ tm .assert_series_equal (result , expected )
1126
+ tm .assert_series_equal (result_str_func , expected )
0 commit comments