Skip to content

Commit a265c5c

Browse files
committed
Merge pull request #6593 from hayd/groupby_filter_filter
FIX filter selects selected columns
2 parents 26e8fa8 + 022ab40 commit a265c5c

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

pandas/core/groupby.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2529,7 +2529,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
25292529

25302530
indices = []
25312531

2532-
obj = self._obj_with_exclusions
2532+
obj = self._selected_obj
25332533
gen = self.grouper.get_iterator(obj, axis=self.axis)
25342534

25352535
fast_path, slow_path = self._define_paths(func, *args, **kwargs)

pandas/tests/test_groupby.py

+13-2
Original file line numberDiff line numberDiff line change
@@ -3438,6 +3438,13 @@ def test_filter_and_transform_with_non_unique_string_index(self):
34383438
actual = grouped_df.pid.transform(len)
34393439
assert_series_equal(actual, expected)
34403440

3441+
def test_filter_has_access_to_grouped_cols(self):
3442+
df = DataFrame([[1, 2], [1, 3], [5, 6]], columns=['A', 'B'])
3443+
g = df.groupby('A')
3444+
# previously didn't have access to col A (GH6593)
3445+
filt = g.filter(lambda x: x['A'].sum() == 2)
3446+
assert_frame_equal(filt, df.iloc[[0, 1]])
3447+
34413448
def test_index_label_overlaps_location(self):
34423449
# checking we don't have any label/location confusion in the
34433450
# wake of GH5375
@@ -3486,7 +3493,8 @@ def test_groupby_selection_with_methods(self):
34863493
'idxmin', 'idxmax',
34873494
'ffill', 'bfill',
34883495
'pct_change',
3489-
'tshift'
3496+
'tshift',
3497+
#'ohlc'
34903498
]
34913499

34923500
for m in methods:
@@ -3501,8 +3509,11 @@ def test_groupby_selection_with_methods(self):
35013509
g_exp.apply(lambda x: x.sum()))
35023510

35033511
assert_frame_equal(g.resample('D'), g_exp.resample('D'))
3512+
assert_frame_equal(g.resample('D', how='ohlc'),
3513+
g_exp.resample('D', how='ohlc'))
35043514

3505-
3515+
assert_frame_equal(g.filter(lambda x: len(x) == 3),
3516+
g_exp.filter(lambda x: len(x) == 3))
35063517

35073518
def test_groupby_whitelist(self):
35083519
from string import ascii_lowercase

0 commit comments

Comments
 (0)