|
8 | 8 | import numpy as np
|
9 | 9 | import pytest
|
10 | 10 |
|
11 |
| -from pandas import DataFrame, Index, MultiIndex, Series, date_range |
| 11 | +from pandas import DataFrame, Index, MultiIndex, RangeIndex, Series, date_range |
12 | 12 | from pandas.util import testing as tm
|
13 | 13 |
|
14 | 14 | AGG_FUNCTIONS = ['sum', 'prod', 'min', 'max', 'median', 'mean', 'skew',
|
@@ -164,33 +164,39 @@ def raw_frame():
|
164 | 164 | @pytest.mark.parametrize('axis', [0, 1])
|
165 | 165 | @pytest.mark.parametrize('skipna', [True, False])
|
166 | 166 | @pytest.mark.parametrize('sort', [True, False])
|
| 167 | +@pytest.mark.parametrize('as_index', [True, False]) |
167 | 168 | def test_regression_whitelist_methods(
|
168 | 169 | raw_frame, op, level,
|
169 |
| - axis, skipna, sort): |
| 170 | + axis, skipna, sort, as_index): |
170 | 171 | # GH6944
|
171 | 172 | # GH 17537
|
172 | 173 | # explicitly test the whitelist methods
|
173 | 174 |
|
| 175 | + if not as_index and axis == 1: |
| 176 | + pytest.skip('as_index=False only valid for axis=0') |
| 177 | + |
174 | 178 | if axis == 0:
|
175 | 179 | frame = raw_frame
|
176 | 180 | else:
|
177 | 181 | frame = raw_frame.T
|
178 | 182 |
|
| 183 | + groupby_kwargs = {'level': level, 'axis': axis, 'sort': sort} #, 'as_index': as_index} |
| 184 | + group_op_kwargs = {} |
| 185 | + frame_op_kwargs = {'level': level, 'axis': axis} |
179 | 186 | if op in AGG_FUNCTIONS_WITH_SKIPNA:
|
180 |
| - grouped = frame.groupby(level=level, axis=axis, sort=sort) |
181 |
| - result = getattr(grouped, op)(skipna=skipna) |
182 |
| - expected = getattr(frame, op)(level=level, axis=axis, |
183 |
| - skipna=skipna) |
184 |
| - if sort: |
185 |
| - expected = expected.sort_index(axis=axis, level=level) |
186 |
| - tm.assert_frame_equal(result, expected) |
187 |
| - else: |
188 |
| - grouped = frame.groupby(level=level, axis=axis, sort=sort) |
189 |
| - result = getattr(grouped, op)() |
190 |
| - expected = getattr(frame, op)(level=level, axis=axis) |
191 |
| - if sort: |
192 |
| - expected = expected.sort_index(axis=axis, level=level) |
193 |
| - tm.assert_frame_equal(result, expected) |
| 187 | + group_op_kwargs['skipna'] = skipna |
| 188 | + frame_op_kwargs['skipna'] = skipna |
| 189 | + |
| 190 | + grouped = frame.groupby(**groupby_kwargs) |
| 191 | + result = getattr(grouped, op)(**group_op_kwargs) |
| 192 | + expected = getattr(frame, op)(**frame_op_kwargs) |
| 193 | + if as_index: |
| 194 | + pass |
| 195 | + |
| 196 | + if sort: |
| 197 | + expected = expected.sort_index(axis=axis, level=level) |
| 198 | + |
| 199 | + tm.assert_frame_equal(result, expected) |
194 | 200 |
|
195 | 201 |
|
196 | 202 | def test_groupby_blacklist(df_letters):
|
|
0 commit comments