Skip to content

Commit fdb5566

Browse files
committed
Splt into separate tests and add use fixture for input df
1 parent dfcaeda commit fdb5566

File tree

1 file changed

+51
-21
lines changed

1 file changed

+51
-21
lines changed

pandas/tests/groupby/test_transform.py

+51-21
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@
2323
from pandas.core.groupby.groupby import DataError
2424

2525

26+
@pytest.fixture
27+
def df_for_transformation_func():
28+
return DataFrame(
29+
{
30+
"A": [121, 121, 121, 121, 231, 231, 676],
31+
"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],
32+
}
33+
)
34+
35+
2636
def assert_fp_equal(a, b):
2737
assert (np.abs(a - b) < 1e-12).all()
2838

@@ -1198,46 +1208,66 @@ def test_transform_lambda_indexing():
11981208
tm.assert_frame_equal(result, expected)
11991209

12001210

1201-
def test_transform_nan_tshift_corrwith(transformation_func):
1211+
def test_groupby_corrwith(transformation_func, df_for_transformation_func):
12021212

1203-
df1 = DataFrame(
1204-
{
1205-
"A": [121, 121, 121, 121, 231, 231, 676],
1206-
"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],
1207-
}
1208-
)
1209-
g1 = df1.groupby("A")
1213+
# GH 27905
1214+
df = df_for_transformation_func.copy()
1215+
g = df.groupby("A")
12101216

12111217
if transformation_func == "corrwith":
1212-
result = g1.corrwith(df1)
1218+
op = lambda x: getattr(x, transformation_func)(df)
1219+
result = op(g)
12131220
expected = pd.DataFrame(dict(B=[1, np.nan, np.nan], A=[np.nan] * 3))
12141221
expected.index = pd.Index([121, 231, 676], name="A")
12151222
tm.assert_frame_equal(result, expected)
12161223

1224+
1225+
def test_groupby_transform_nan(transformation_func, df_for_transformation_func):
1226+
1227+
# GH 27905
1228+
df = df_for_transformation_func.copy()
1229+
g = df.groupby("A")
1230+
12171231
if transformation_func == "fillna":
1218-
df3 = df1.copy()
1219-
df3["B"] = [1, np.nan, np.nan, 3, np.nan, 3, 4]
1220-
result = df3.groupby("A").transform(lambda x: x.fillna(x.mean()))
1221-
expected = pd.DataFrame({"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0]})
1222-
tm.assert_frame_equal(result, expected)
12231232

1224-
result = df3.groupby("A").transform(transformation_func, value=1)
1233+
df["B"] = [1, np.nan, np.nan, 3, np.nan, 3, 4]
1234+
result = g.transform(transformation_func, value=1)
12251235
expected = pd.DataFrame({"B": [1.0, 1.0, 1.0, 3.0, 1.0, 3.0, 4.0]})
12261236
tm.assert_frame_equal(result, expected)
1237+
op = lambda x: getattr(x, transformation_func)(1)
1238+
result = op(g)
1239+
tm.assert_frame_equal(result, expected)
1240+
1241+
1242+
def test_groupby_transform_tshift(transformation_func, df_for_transformation_func):
1243+
1244+
# GH 27905
1245+
df = df_for_transformation_func.copy()
1246+
dt_periods = pd.date_range("2013-11-03", periods=7, freq="D")
1247+
df["C"] = dt_periods
1248+
g = df.set_index("C").groupby("A")
12271249

12281250
if transformation_func == "tshift":
1229-
df2 = df1.copy()
1230-
dt_periods = pd.date_range("2013-11-03", periods=7, freq="D")
1231-
df2["C"] = dt_periods
1232-
result = df2.set_index("C").groupby("A").tshift(2, "D")
1233-
df2["C"] = dt_periods + dt_periods.freq * 2
1234-
expected = df2
1251+
1252+
op = lambda x: getattr(x, transformation_func)(2, "D")
1253+
result = op(g)
1254+
df["C"] = dt_periods + dt_periods.freq * 2
1255+
expected = df
12351256
tm.assert_frame_equal(
12361257
result.reset_index().reindex(columns=["A", "B", "C"]), expected
12371258
)
12381259

12391260

12401261
def test_check_original_and_transformed_index(transformation_func):
1262+
1263+
# GH 27905
1264+
df = DataFrame(
1265+
{
1266+
"A": [121, 121, 121, 121, 231, 231, 676],
1267+
"B": [1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0],
1268+
}
1269+
)
1270+
12411271
df = DataFrame({"A": [0, 0, 0, 1, 1, 1], "B": [0, 1, 2, 3, 4, 5]})
12421272
g = df.groupby("A")
12431273

0 commit comments

Comments
 (0)