|
| 1 | +from collections import OrderedDict |
1 | 2 | from itertools import product
|
2 | 3 | import pytest
|
3 | 4 | import warnings
|
@@ -314,6 +315,53 @@ def test_preserve_metadata(self):
|
314 | 315 | assert s2.name == 'foo'
|
315 | 316 | assert s3.name == 'foo'
|
316 | 317 |
|
| 318 | + @pytest.mark.parametrize("func,window_size,expected_vals", [ |
| 319 | + ('rolling', 2, [[np.nan, np.nan, np.nan, np.nan], |
| 320 | + [15., 20., 25., 20.], |
| 321 | + [25., 30., 35., 30.], |
| 322 | + [np.nan, np.nan, np.nan, np.nan], |
| 323 | + [20., 30., 35., 30.], |
| 324 | + [35., 40., 60., 40.], |
| 325 | + [60., 80., 85., 80]]), |
| 326 | + ('expanding', None, [[10., 10., 20., 20.], |
| 327 | + [15., 20., 25., 20.], |
| 328 | + [20., 30., 30., 20.], |
| 329 | + [10., 10., 30., 30.], |
| 330 | + [20., 30., 35., 30.], |
| 331 | + [26.666667, 40., 50., 30.], |
| 332 | + [40., 80., 60., 30.]])]) |
| 333 | + def test_multiple_agg_funcs(self, func, window_size, expected_vals): |
| 334 | + # GH 15072 |
| 335 | + df = pd.DataFrame([ |
| 336 | + ['A', 10, 20], |
| 337 | + ['A', 20, 30], |
| 338 | + ['A', 30, 40], |
| 339 | + ['B', 10, 30], |
| 340 | + ['B', 30, 40], |
| 341 | + ['B', 40, 80], |
| 342 | + ['B', 80, 90]], columns=['stock', 'low', 'high']) |
| 343 | + |
| 344 | + f = getattr(df.groupby('stock'), func) |
| 345 | + if window_size: |
| 346 | + window = f(window_size) |
| 347 | + else: |
| 348 | + window = f() |
| 349 | + |
| 350 | + index = pd.MultiIndex.from_tuples([ |
| 351 | + ('A', 0), ('A', 1), ('A', 2), |
| 352 | + ('B', 3), ('B', 4), ('B', 5), ('B', 6)], names=['stock', None]) |
| 353 | + columns = pd.MultiIndex.from_tuples([ |
| 354 | + ('low', 'mean'), ('low', 'max'), ('high', 'mean'), |
| 355 | + ('high', 'min')]) |
| 356 | + expected = pd.DataFrame(expected_vals, index=index, columns=columns) |
| 357 | + |
| 358 | + result = window.agg(OrderedDict(( |
| 359 | + ('low', ['mean', 'max']), |
| 360 | + ('high', ['mean', 'min']), |
| 361 | + ))) |
| 362 | + |
| 363 | + tm.assert_frame_equal(result, expected) |
| 364 | + |
317 | 365 |
|
318 | 366 | @pytest.mark.filterwarnings("ignore:can't resolve package:ImportWarning")
|
319 | 367 | class TestWindow(Base):
|
|
0 commit comments