Skip to content

Commit dfcaeda

Browse files
committed
Add tests with transformation_func fixture
1 parent b8c622d commit dfcaeda

File tree

1 file changed

+56
-22
lines changed

1 file changed

+56
-22
lines changed

pandas/tests/groupby/test_transform.py

+56-22
Original file line numberDiff line numberDiff line change
@@ -1198,25 +1198,59 @@ def test_transform_lambda_indexing():
11981198
tm.assert_frame_equal(result, expected)
11991199

12001200

1201-
@pytest.mark.parametrize(
1202-
"input_df",
1203-
[
1204-
DataFrame(
1205-
{
1206-
"A": [121, 121, 121, 121, 231, 231, 676],
1207-
"B": [1, 2, np.nan, 3, 3, np.nan, 4],
1208-
}
1209-
),
1210-
DataFrame(
1211-
{
1212-
"A": [121, 121, 121, 121, 231, 231, 676],
1213-
"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],
1214-
}
1215-
),
1216-
],
1217-
)
1218-
def test_groupby_transform_fillna(input_df):
1219-
# GH 27905
1220-
result = input_df.groupby("A").transform(lambda x: x.fillna(x.mean()))
1221-
expected = pd.DataFrame({"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0]})
1222-
tm.assert_frame_equal(result, expected)
1201+
def test_transform_nan_tshift_corrwith(transformation_func):
1202+
1203+
df1 = DataFrame(
1204+
{
1205+
"A": [121, 121, 121, 121, 231, 231, 676],
1206+
"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],
1207+
}
1208+
)
1209+
g1 = df1.groupby("A")
1210+
1211+
if transformation_func == "corrwith":
1212+
result = g1.corrwith(df1)
1213+
expected = pd.DataFrame(dict(B=[1, np.nan, np.nan], A=[np.nan] * 3))
1214+
expected.index = pd.Index([121, 231, 676], name="A")
1215+
tm.assert_frame_equal(result, expected)
1216+
1217+
if transformation_func == "fillna":
1218+
df3 = df1.copy()
1219+
df3["B"] = [1, np.nan, np.nan, 3, np.nan, 3, 4]
1220+
result = df3.groupby("A").transform(lambda x: x.fillna(x.mean()))
1221+
expected = pd.DataFrame({"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0]})
1222+
tm.assert_frame_equal(result, expected)
1223+
1224+
result = df3.groupby("A").transform(transformation_func, value=1)
1225+
expected = pd.DataFrame({"B": [1.0, 1.0, 1.0, 3.0, 1.0, 3.0, 4.0]})
1226+
tm.assert_frame_equal(result, expected)
1227+
1228+
if transformation_func == "tshift":
1229+
df2 = df1.copy()
1230+
dt_periods = pd.date_range("2013-11-03", periods=7, freq="D")
1231+
df2["C"] = dt_periods
1232+
result = df2.set_index("C").groupby("A").tshift(2, "D")
1233+
df2["C"] = dt_periods + dt_periods.freq * 2
1234+
expected = df2
1235+
tm.assert_frame_equal(
1236+
result.reset_index().reindex(columns=["A", "B", "C"]), expected
1237+
)
1238+
1239+
1240+
def test_check_original_and_transformed_index(transformation_func):
1241+
df = DataFrame({"A": [0, 0, 0, 1, 1, 1], "B": [0, 1, 2, 3, 4, 5]})
1242+
g = df.groupby("A")
1243+
1244+
if transformation_func in [
1245+
"cummax",
1246+
"cummin",
1247+
"cumprod",
1248+
"cumsum",
1249+
"diff",
1250+
"ffill",
1251+
"pct_change",
1252+
"rank",
1253+
"shift",
1254+
]:
1255+
result = g.transform(transformation_func)
1256+
tm.assert_index_equal(result.index, df.index)

0 commit comments

Comments
 (0)