|
14 | 14 | qcut,
|
15 | 15 | )
|
16 | 16 | import pandas._testing as tm
|
| 17 | +from pandas.api.typing import SeriesGroupBy |
17 | 18 | from pandas.tests.groupby import get_groupby_method_args
|
18 | 19 |
|
19 | 20 |
|
@@ -2036,3 +2037,50 @@ def test_groupby_default_depr(cat_columns, keys):
|
2036 | 2037 | klass = FutureWarning if set(cat_columns) & set(keys) else None
|
2037 | 2038 | with tm.assert_produces_warning(klass, match=msg):
|
2038 | 2039 | df.groupby(keys)
|
| 2040 | + |
| 2041 | + |
| 2042 | +@pytest.mark.parametrize("test_series", [True, False]) |
| 2043 | +@pytest.mark.parametrize("keys", [["a1"], ["a1", "a2"]]) |
| 2044 | +def test_agg_list(request, as_index, observed, reduction_func, test_series, keys): |
| 2045 | + # GH#52760 |
| 2046 | + if test_series and reduction_func == "corrwith": |
| 2047 | + assert not hasattr(SeriesGroupBy, "corrwith") |
| 2048 | + pytest.skip("corrwith not implemented for SeriesGroupBy") |
| 2049 | + elif reduction_func == "corrwith": |
| 2050 | + msg = "GH#32293: attempts to call SeriesGroupBy.corrwith" |
| 2051 | + request.node.add_marker(pytest.mark.xfail(reason=msg)) |
| 2052 | + elif ( |
| 2053 | + reduction_func == "nunique" |
| 2054 | + and not test_series |
| 2055 | + and len(keys) != 1 |
| 2056 | + and not observed |
| 2057 | + and not as_index |
| 2058 | + ): |
| 2059 | + msg = "GH#52848 - raises a ValueError" |
| 2060 | + request.node.add_marker(pytest.mark.xfail(reason=msg)) |
| 2061 | + |
| 2062 | + df = DataFrame({"a1": [0, 0, 1], "a2": [2, 3, 3], "b": [4, 5, 6]}) |
| 2063 | + df = df.astype({"a1": "category", "a2": "category"}) |
| 2064 | + if "a2" not in keys: |
| 2065 | + df = df.drop(columns="a2") |
| 2066 | + gb = df.groupby(by=keys, as_index=as_index, observed=observed) |
| 2067 | + if test_series: |
| 2068 | + gb = gb["b"] |
| 2069 | + args = get_groupby_method_args(reduction_func, df) |
| 2070 | + |
| 2071 | + result = gb.agg([reduction_func], *args) |
| 2072 | + expected = getattr(gb, reduction_func)(*args) |
| 2073 | + |
| 2074 | + if as_index and (test_series or reduction_func == "size"): |
| 2075 | + expected = expected.to_frame(reduction_func) |
| 2076 | + if not test_series: |
| 2077 | + if not as_index: |
| 2078 | + # TODO: GH#52849 - as_index=False is not respected |
| 2079 | + expected = expected.set_index(keys) |
| 2080 | + expected.columns = MultiIndex( |
| 2081 | + levels=[["b"], [reduction_func]], codes=[[0], [0]] |
| 2082 | + ) |
| 2083 | + elif not as_index: |
| 2084 | + expected.columns = keys + [reduction_func] |
| 2085 | + |
| 2086 | + tm.assert_equal(result, expected) |
0 commit comments