Skip to content

Commit 2a2cfb8

Browse files
committed
ENH: Add filter method to SeriesGroupBy, DataFrameGroupBy
1 parent cf47a42 commit 2a2cfb8

File tree

5 files changed

+344
-28
lines changed

5 files changed

+344
-28
lines changed

RELEASE.rst

+2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ pandas 0.11.1
4848
- Add iterator to ``Series.str`` (GH3638_)
4949
- ``pd.set_option()`` now allows N option, value pairs (GH3667_).
5050
- Added keyword parameters for different types of scatter_matrix subplots
51+
- A ``filter`` method on grouped Series or DataFrames returns a subset of
52+
the original (GH3680_, GH919_)
5153

5254
**Improvements to existing features**
5355

doc/source/groupby.rst

+39
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ following:
4141
- Standardizing data (zscore) within group
4242
- Filling NAs within groups with a value derived from each group
4343

44+
- **Filtration**: discard some groups, according to a group-wise computation
45+
that evaluates True or False. Some examples:
46+
47+
- Discarding data that belongs to groups with only a few members
48+
- Filtering out data based on the group sum or mean
49+
4450
- Some combination of the above: GroupBy will examine the results of the apply
4551
step and try to return a sensibly combined result if it doesn't fit into
4652
either of the above two categories
@@ -489,6 +495,39 @@ and that the transformed data contains no NAs.
489495
grouped_trans.count() # counts after transformation
490496
grouped_trans.size() # Verify non-NA count equals group size
491497
498+
.. _groupby.filter:
499+
500+
Filtration
501+
----------
502+
503+
The ``filter`` method returns a subset of the original object. Suppose we
504+
want to take only elements that belong to groups with a group sum greater
505+
than 2.
506+
507+
.. ipython:: python
508+
509+
s = Series([1, 1, 2, 3, 3, 3])
510+
s.groupby(s).filter(lambda x: x.sum() > 2)
511+
512+
The argument of ``filter`` must a function that, applied to the group as a
513+
whole, returns ``True`` or ``False``.
514+
515+
Another useful operation is filtering out elements that belong to groups
516+
with only a couple members.
517+
518+
.. ipython:: python
519+
520+
df = DataFrame({'A': arange(8), 'B': list('aabbbbcc')})
521+
df.groupby('B').filter(lambda x: len(x) > 2)
522+
523+
Alternatively, instead of dropping the offending groups, we can return a
524+
like-indexed objects where the groups that do not pass the filter are filled
525+
with NaNs.
526+
527+
.. ipython:: python
528+
529+
df.groupby('B').filter(lambda x: len(x) > 2, dropna=False)
530+
492531
.. _groupby.dispatch:
493532

494533
Dispatching to instance methods

doc/source/v0.11.1.txt

+29
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,35 @@ Enhancements
237237
pd.get_option('a.b')
238238
pd.get_option('b.c')
239239

240+
- The ``filter`` method for group objects returns a subset of the original
241+
object. Suppose we want to take only elements that belong to groups with a
242+
group sum greater than 2.
243+
244+
.. ipython:: python
245+
246+
s = Series([1, 1, 2, 3, 3, 3])
247+
s.groupby(s).filter(lambda x: x.sum() > 2)
248+
249+
The argument of ``filter`` must a function that, applied to the group as a
250+
whole, returns ``True`` or ``False``.
251+
252+
Another useful operation is filtering out elements that belong to groups
253+
with only a couple members.
254+
255+
.. ipython:: python
256+
257+
df = DataFrame({'A': arange(8), 'B': list('aabbbbcc')})
258+
df.groupby('B').filter(lambda x: len(x) > 2)
259+
260+
Alternatively, instead of dropping the offending groups, we can return a
261+
like-indexed objects where the groups that do not pass the filter are
262+
filled with NaNs.
263+
264+
.. ipython:: python
265+
266+
df.groupby('B').filter(lambda x: len(x) > 2, dropna=False)
267+
268+
240269
Bug Fixes
241270
~~~~~~~~~
242271

pandas/core/groupby.py

+125-28
Original file line numberDiff line numberDiff line change
@@ -1558,6 +1558,42 @@ def transform(self, func, *args, **kwargs):
15581558
result = _possibly_downcast_to_dtype(result, dtype)
15591559
return self.obj.__class__(result,index=self.obj.index,name=self.obj.name)
15601560

1561+
def filter(self, func, dropna=True, *args, **kwargs):
1562+
"""
1563+
Return a copy of a Series excluding elements from groups that
1564+
do not satisfy the boolean criterion specified by func.
1565+
1566+
Parameters
1567+
----------
1568+
func : function
1569+
To apply to each group. Should return True or False.
1570+
dropna : Drop groups that do not pass the filter. True by default;
1571+
if False, groups that evaluate False are filled with NaNs.
1572+
1573+
Example
1574+
-------
1575+
>>> grouped.filter(lambda x: x.mean() > 0)
1576+
1577+
Returns
1578+
-------
1579+
filtered : Series
1580+
"""
1581+
if isinstance(func, basestring):
1582+
wrapper = lambda x: getattr(x, func)(*args, **kwargs)
1583+
else:
1584+
wrapper = lambda x: func(x, *args, **kwargs)
1585+
1586+
indexers = [self.obj.index.get_indexer(group.index) \
1587+
if wrapper(group) else [] for _ , group in self]
1588+
1589+
if len(indexers) == 0:
1590+
filtered = self.obj.take([]) # because np.concatenate would fail
1591+
else:
1592+
filtered = self.obj.take(np.concatenate(indexers))
1593+
if dropna:
1594+
return filtered
1595+
else:
1596+
return filtered.reindex(self.obj.index) # Fill with NaNs.
15611597

15621598
class NDFrameGroupBy(GroupBy):
15631599

@@ -1928,47 +1964,22 @@ def transform(self, func, *args, **kwargs):
19281964

19291965
obj = self._obj_with_exclusions
19301966
gen = self.grouper.get_iterator(obj, axis=self.axis)
1931-
1932-
if isinstance(func, basestring):
1933-
fast_path = lambda group: getattr(group, func)(*args, **kwargs)
1934-
slow_path = lambda group: group.apply(lambda x: getattr(x, func)(*args, **kwargs), axis=self.axis)
1935-
else:
1936-
fast_path = lambda group: func(group, *args, **kwargs)
1937-
slow_path = lambda group: group.apply(lambda x: func(x, *args, **kwargs), axis=self.axis)
1967+
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
19381968

19391969
path = None
19401970
for name, group in gen:
19411971
object.__setattr__(group, 'name', name)
19421972

1943-
# decide on a fast path
19441973
if path is None:
1945-
1946-
path = slow_path
1974+
# Try slow path and fast path.
19471975
try:
1948-
res = slow_path(group)
1949-
1950-
# if we make it here, test if we can use the fast path
1951-
try:
1952-
res_fast = fast_path(group)
1953-
1954-
# compare that we get the same results
1955-
if res.shape == res_fast.shape:
1956-
res_r = res.values.ravel()
1957-
res_fast_r = res_fast.values.ravel()
1958-
mask = notnull(res_r)
1959-
if (res_r[mask] == res_fast_r[mask]).all():
1960-
path = fast_path
1961-
1962-
except:
1963-
pass
1976+
path, res = self._choose_path(fast_path, slow_path, group)
19641977
except TypeError:
19651978
return self._transform_item_by_item(obj, fast_path)
19661979
except Exception: # pragma: no cover
19671980
res = fast_path(group)
19681981
path = fast_path
1969-
19701982
else:
1971-
19721983
res = path(group)
19731984

19741985
# broadcasting
@@ -1988,6 +1999,35 @@ def transform(self, func, *args, **kwargs):
19881999
concatenated.sort_index(inplace=True)
19892000
return concatenated
19902001

2002+
def _define_paths(self, func, *args, **kwargs):
2003+
if isinstance(func, basestring):
2004+
fast_path = lambda group: getattr(group, func)(*args, **kwargs)
2005+
slow_path = lambda group: group.apply(lambda x: getattr(x, func)(*args, **kwargs), axis=self.axis)
2006+
else:
2007+
fast_path = lambda group: func(group, *args, **kwargs)
2008+
slow_path = lambda group: group.apply(lambda x: func(x, *args, **kwargs), axis=self.axis)
2009+
return fast_path, slow_path
2010+
2011+
def _choose_path(self, fast_path, slow_path, group):
2012+
path = slow_path
2013+
res = slow_path(group)
2014+
2015+
# if we make it here, test if we can use the fast path
2016+
try:
2017+
res_fast = fast_path(group)
2018+
2019+
# compare that we get the same results
2020+
if res.shape == res_fast.shape:
2021+
res_r = res.values.ravel()
2022+
res_fast_r = res_fast.values.ravel()
2023+
mask = notnull(res_r)
2024+
if (res_r[mask] == res_fast_r[mask]).all():
2025+
path = fast_path
2026+
2027+
except:
2028+
pass
2029+
return path, res
2030+
19912031
def _transform_item_by_item(self, obj, wrapper):
19922032
# iterate through columns
19932033
output = {}
@@ -2008,6 +2048,63 @@ def _transform_item_by_item(self, obj, wrapper):
20082048

20092049
return DataFrame(output, index=obj.index, columns=columns)
20102050

2051+
def filter(self, func, dropna=True, *args, **kwargs):
2052+
"""
2053+
Return a copy of a DataFrame excluding elements from groups that
2054+
do not satisfy the boolean criterion specified by func.
2055+
2056+
Parameters
2057+
----------
2058+
f : function
2059+
Function to apply to each subframe. Should return True or False.
2060+
dropna : Drop groups that do not pass the filter. True by default;
2061+
if False, groups that evaluate False are filled with NaNs.
2062+
2063+
Note
2064+
----
2065+
Each subframe is endowed the attribute 'name' in case you need to know
2066+
which group you are working on.
2067+
2068+
Example
2069+
--------
2070+
>>> grouped = df.groupby(lambda x: mapping[x])
2071+
>>> grouped.filter(lambda x: x['A'].sum() + x['B'].sum() > 0)
2072+
"""
2073+
from pandas.tools.merge import concat
2074+
2075+
indexers = []
2076+
2077+
obj = self._obj_with_exclusions
2078+
gen = self.grouper.get_iterator(obj, axis=self.axis)
2079+
2080+
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
2081+
2082+
path = None
2083+
for name, group in gen:
2084+
object.__setattr__(group, 'name', name)
2085+
2086+
if path is None:
2087+
# Try slow path and fast path.
2088+
try:
2089+
path, res = self._choose_path(fast_path, slow_path, group)
2090+
except Exception: # pragma: no cover
2091+
res = fast_path(group)
2092+
path = fast_path
2093+
else:
2094+
res = path(group)
2095+
2096+
if res:
2097+
indexers.append(self.obj.index.get_indexer(group.index))
2098+
2099+
if len(indexers) == 0:
2100+
filtered = self.obj.take([]) # because np.concatenate would fail
2101+
else:
2102+
filtered = self.obj.take(np.concatenate(indexers))
2103+
if dropna:
2104+
return filtered
2105+
else:
2106+
return filtered.reindex(self.obj.index) # Fill with NaNs.
2107+
20112108

20122109
class DataFrameGroupBy(NDFrameGroupBy):
20132110

0 commit comments

Comments
 (0)