Skip to content

Commit f272dc1

Browse files
charlesdong1991proost
authored andcommitted
TST: Add tests for MultiIndex columns cases in aggregate relabelling (pandas-dev#29504)
1 parent f3a26d8 commit f272dc1

File tree

1 file changed

+74
-0
lines changed

1 file changed

+74
-0
lines changed

pandas/tests/groupby/aggregate/test_aggregate.py

+74
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,80 @@ def test_mangled(self):
495495
tm.assert_frame_equal(result, expected)
496496

497497

498+
@pytest.mark.parametrize(
499+
"agg_col1, agg_col2, agg_col3, agg_result1, agg_result2, agg_result3",
500+
[
501+
(
502+
(("y", "A"), "max"),
503+
(("y", "A"), np.min),
504+
(("y", "B"), "mean"),
505+
[1, 3],
506+
[0, 2],
507+
[5.5, 7.5],
508+
),
509+
(
510+
(("y", "A"), lambda x: max(x)),
511+
(("y", "A"), lambda x: 1),
512+
(("y", "B"), "mean"),
513+
[1, 3],
514+
[1, 1],
515+
[5.5, 7.5],
516+
),
517+
(
518+
pd.NamedAgg(("y", "A"), "max"),
519+
pd.NamedAgg(("y", "B"), np.mean),
520+
pd.NamedAgg(("y", "A"), lambda x: 1),
521+
[1, 3],
522+
[5.5, 7.5],
523+
[1, 1],
524+
),
525+
],
526+
)
527+
def test_agg_relabel_multiindex_column(
528+
agg_col1, agg_col2, agg_col3, agg_result1, agg_result2, agg_result3
529+
):
530+
# GH 29422, add tests for multiindex column cases
531+
df = DataFrame(
532+
{"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]}
533+
)
534+
df.columns = pd.MultiIndex.from_tuples([("x", "group"), ("y", "A"), ("y", "B")])
535+
idx = pd.Index(["a", "b"], name=("x", "group"))
536+
537+
result = df.groupby(("x", "group")).agg(a_max=(("y", "A"), "max"))
538+
expected = DataFrame({"a_max": [1, 3]}, index=idx)
539+
tm.assert_frame_equal(result, expected)
540+
541+
result = df.groupby(("x", "group")).agg(
542+
col_1=agg_col1, col_2=agg_col2, col_3=agg_col3
543+
)
544+
expected = DataFrame(
545+
{"col_1": agg_result1, "col_2": agg_result2, "col_3": agg_result3}, index=idx
546+
)
547+
tm.assert_frame_equal(result, expected)
548+
549+
550+
def test_agg_relabel_multiindex_raises_not_exist():
551+
# GH 29422, add test for raises senario when aggregate column does not exist
552+
df = DataFrame(
553+
{"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]}
554+
)
555+
df.columns = pd.MultiIndex.from_tuples([("x", "group"), ("y", "A"), ("y", "B")])
556+
557+
with pytest.raises(KeyError, match="does not exist"):
558+
df.groupby(("x", "group")).agg(a=(("Y", "a"), "max"))
559+
560+
561+
def test_agg_relabel_multiindex_raises_duplicate():
562+
# GH29422, add test for raises senario when getting duplicates
563+
df = DataFrame(
564+
{"group": ["a", "a", "b", "b"], "A": [0, 1, 2, 3], "B": [5, 6, 7, 8]}
565+
)
566+
df.columns = pd.MultiIndex.from_tuples([("x", "group"), ("y", "A"), ("y", "B")])
567+
568+
with pytest.raises(SpecificationError, match="Function names"):
569+
df.groupby(("x", "group")).agg(a=(("y", "A"), "min"), b=(("y", "A"), "min"))
570+
571+
498572
def myfunc(s):
499573
return np.percentile(s, q=0.90)
500574

0 commit comments

Comments
 (0)