|
9 | 9 | from pandas.core.index import Index, MultiIndex
|
10 | 10 | from pandas.core.common import rands
|
11 | 11 | from pandas.core.api import Categorical, DataFrame
|
12 |
| -from pandas.core.groupby import GroupByError, SpecificationError, DataError |
| 12 | +from pandas.core.groupby import (GroupByError, SpecificationError, DataError, |
| 13 | + _apply_whitelist) |
13 | 14 | from pandas.core.series import Series
|
14 | 15 | from pandas.util.testing import (assert_panel_equal, assert_frame_equal,
|
15 | 16 | assert_series_equal, assert_almost_equal,
|
@@ -2696,8 +2697,40 @@ def test_filter_against_workaround(self):
|
2696 | 2697 | new_way = grouped.filter(lambda x: x['ints'].mean() > N/20)
|
2697 | 2698 | assert_frame_equal(new_way.sort_index(), old_way.sort_index())
|
2698 | 2699 |
|
def test_groupby_whitelist(self):
    # Every blacklisted attribute must raise AttributeError when looked up
    # on a groupby object, for both the DataFrame and the Series groupby,
    # and the message must match one of the two patterns built below.
    from string import ascii_lowercase

    N = 10
    alphabet = np.array(list(ascii_lowercase))
    picks = alphabet.take(np.random.randint(0, 26, N))
    df = DataFrame({'floats': N / 10 * Series(np.random.random(N)),
                    'letters': Series(picks)})
    s = df.floats

    # NOTE(review): 'corr_with' does not look like a pandas method name
    # (DataFrame spells it 'corrwith') -- confirm whether this is intended.
    forbidden = ['eval', 'query', 'abs', 'shift', 'tshift', 'where',
                 'mask', 'align', 'groupby', 'clip', 'astype',
                 'at', 'combine', 'consolidate', 'convert_objects',
                 'corr', 'corr_with', 'cov']

    # All ``to_*`` attributes found on the frame are blacklisted as well.
    forbidden += [name for name in dir(df) if name.startswith('to_')]

    # Message pattern for an attribute that exists but is not allowed,
    # e.g., to_csv
    exists_but_blocked = ("(?:^Cannot.+{0!r}.+{1!r}.+try using the "
                          "'apply' method$)")

    # Message pattern for an attribute that is not defined,
    # e.g., query, eval
    missing = "(?:^{1!r} object has no attribute {0!r}$)"

    # Either message is acceptable, so the two patterns are alternated.
    pattern = '|'.join([exists_but_blocked, missing])

    for attr in forbidden:
        for obj in (df, s):
            grouped = obj.groupby(df.letters)
            expected = pattern.format(attr, type(grouped).__name__)
            with tm.assertRaisesRegexp(AttributeError, expected):
                getattr(grouped, attr)
| 2731 | + |
def assert_fp_equal(a, b, tol=1e-12):
    """Assert that *a* and *b* agree element-wise to within *tol*.

    Parameters
    ----------
    a, b : objects supporting ``a - b`` whose difference works with
        ``np.abs`` and yields something with an ``.all()`` method
        (e.g. NumPy arrays).
    tol : float, optional
        Exclusive upper bound on the absolute difference. Defaults to
        ``1e-12``, the tolerance this helper has always used.

    Raises
    ------
    AssertionError
        If any element differs by ``tol`` or more. NaN differences compare
        as False and therefore also fail.
    """
    within_tol = (np.abs(a - b) < tol).all()
    # The message is only formatted when the assertion actually fails.
    assert within_tol, "values differ by %r or more" % (tol,)
2701 | 2734 |
|
2702 | 2735 |
|
2703 | 2736 | def _check_groupby(df, result, keys, field, f=lambda x: x.sum()):
|
|
0 commit comments