Skip to content

Commit 20a8441

Browse files
committed
BUG: allow plot, boxplot, hist and completion on GroupBy objects
1 parent bb61f23 commit 20a8441

File tree

2 files changed

+112
-25
lines changed

2 files changed

+112
-25
lines changed

pandas/core/groupby.py

+40-25
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
import types
2+
from functools import wraps
23
import numpy as np
34

45
from pandas.compat import(
@@ -45,14 +46,19 @@
4546
"""
4647

4748

49+
# special-cased to avoid drawing duplicate plots when an exception is caught while
50+
# forwarding methods from NDFrames
51+
_plotting_methods = frozenset(['plot', 'boxplot', 'hist'])
52+
4853
_apply_whitelist = frozenset(['last', 'first',
4954
'mean', 'sum', 'min', 'max',
5055
'head', 'tail',
5156
'cumsum', 'cumprod', 'cummin', 'cummax',
5257
'resample',
5358
'describe',
5459
'rank', 'quantile', 'count',
55-
'fillna', 'dtype'])
60+
'fillna', 'dtype']) | _plotting_methods
61+
5662

5763

5864
class GroupByError(Exception):
@@ -180,7 +186,6 @@ class GroupBy(PandasObject):
180186
len(grouped) : int
181187
Number of groups
182188
"""
183-
184189
def __init__(self, obj, keys=None, axis=0, level=None,
185190
grouper=None, exclusions=None, selection=None, as_index=True,
186191
sort=True, group_keys=True, squeeze=False):
@@ -244,6 +249,9 @@ def _selection_list(self):
244249
return [self._selection]
245250
return self._selection
246251

252+
def _local_dir(self):
253+
return sorted(set(self.obj._local_dir() + list(_apply_whitelist)))
254+
247255
def __getattr__(self, attr):
248256
if attr in self.obj:
249257
return self[attr]
@@ -281,9 +289,16 @@ def wrapper(*args, **kwargs):
281289

282290
def curried_with_axis(x):
283291
return f(x, *args, **kwargs_with_axis)
292+
curried_with_axis.__name__ = name
284293

285294
def curried(x):
286295
return f(x, *args, **kwargs)
296+
curried.__name__ = name
297+
298+
# special case: otherwise extra plots are created when catching the
299+
# exception below
300+
if name in _plotting_methods:
301+
return self.apply(curried)
287302

288303
try:
289304
return self.apply(curried_with_axis)
@@ -348,7 +363,11 @@ def apply(self, func, *args, **kwargs):
348363
applied : type depending on grouped object and function
349364
"""
350365
func = _intercept_function(func)
351-
f = lambda g: func(g, *args, **kwargs)
366+
367+
@wraps(func)
368+
def f(g):
369+
return func(g, *args, **kwargs)
370+
352371
return self._python_apply_general(f)
353372

354373
def _python_apply_general(self, f):
@@ -598,7 +617,7 @@ def __iter__(self):
598617
def nkeys(self):
599618
return len(self.groupings)
600619

601-
def get_iterator(self, data, axis=0, keep_internal=True):
620+
def get_iterator(self, data, axis=0):
602621
"""
603622
Groupby iterator
604623
@@ -607,16 +626,14 @@ def get_iterator(self, data, axis=0, keep_internal=True):
607626
Generator yielding sequence of (name, subsetted object)
608627
for each group
609628
"""
610-
splitter = self._get_splitter(data, axis=axis,
611-
keep_internal=keep_internal)
629+
splitter = self._get_splitter(data, axis=axis)
612630
keys = self._get_group_keys()
613631
for key, (i, group) in zip(keys, splitter):
614632
yield key, group
615633

616-
def _get_splitter(self, data, axis=0, keep_internal=True):
634+
def _get_splitter(self, data, axis=0):
617635
comp_ids, _, ngroups = self.group_info
618-
return get_splitter(data, comp_ids, ngroups, axis=axis,
619-
keep_internal=keep_internal)
636+
return get_splitter(data, comp_ids, ngroups, axis=axis)
620637

621638
def _get_group_keys(self):
622639
if len(self.groupings) == 1:
@@ -627,19 +644,19 @@ def _get_group_keys(self):
627644
mapper = _KeyMapper(comp_ids, ngroups, self.labels, self.levels)
628645
return [mapper.get_key(i) for i in range(ngroups)]
629646

630-
def apply(self, f, data, axis=0, keep_internal=False):
647+
def apply(self, f, data, axis=0):
631648
mutated = False
632-
splitter = self._get_splitter(data, axis=axis,
633-
keep_internal=keep_internal)
649+
splitter = self._get_splitter(data, axis=axis)
634650
group_keys = self._get_group_keys()
635651

636652
# oh boy
637-
if hasattr(splitter, 'fast_apply') and axis == 0:
653+
if (f.__name__ not in _plotting_methods and
654+
hasattr(splitter, 'fast_apply') and axis == 0):
638655
try:
639656
values, mutated = splitter.fast_apply(f, group_keys)
640657
return group_keys, values, mutated
641-
except (Exception) as detail:
642-
# we detect a mutatation of some kind
658+
except Exception:
659+
# we detect a mutation of some kind
643660
# so take slow path
644661
pass
645662

@@ -1043,7 +1060,7 @@ def get_iterator(self, data, axis=0):
10431060
inds = lrange(start, n)
10441061
yield self.binlabels[-1], data.take(inds, axis=axis)
10451062

1046-
def apply(self, f, data, axis=0, keep_internal=False):
1063+
def apply(self, f, data, axis=0):
10471064
result_keys = []
10481065
result_values = []
10491066
mutated = False
@@ -1617,6 +1634,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
16171634
else:
16181635
return filtered.reindex(self.obj.index) # Fill with NaNs.
16191636

1637+
16201638
class NDFrameGroupBy(GroupBy):
16211639

16221640
def _iterate_slices(self):
@@ -2268,6 +2286,7 @@ def ohlc(self):
22682286
"""
22692287
return self._apply_to_column_groupbys(lambda x: x._cython_agg_general('ohlc'))
22702288

2289+
22712290
from pandas.tools.plotting import boxplot_frame_groupby
22722291
DataFrameGroupBy.boxplot = boxplot_frame_groupby
22732292

@@ -2364,7 +2383,7 @@ class NDArrayGroupBy(GroupBy):
23642383

23652384
class DataSplitter(object):
23662385

2367-
def __init__(self, data, labels, ngroups, axis=0, keep_internal=False):
2386+
def __init__(self, data, labels, ngroups, axis=0):
23682387
self.data = data
23692388
self.labels = com._ensure_int64(labels)
23702389
self.ngroups = ngroups
@@ -2419,10 +2438,8 @@ def _chop(self, sdata, slice_obj):
24192438

24202439

24212440
class FrameSplitter(DataSplitter):
2422-
2423-
def __init__(self, data, labels, ngroups, axis=0, keep_internal=False):
2424-
DataSplitter.__init__(self, data, labels, ngroups, axis=axis,
2425-
keep_internal=keep_internal)
2441+
def __init__(self, data, labels, ngroups, axis=0):
2442+
super(FrameSplitter, self).__init__(data, labels, ngroups, axis=axis)
24262443

24272444
def fast_apply(self, f, names):
24282445
# must return keys::list, values::list, mutated::bool
@@ -2445,10 +2462,8 @@ def _chop(self, sdata, slice_obj):
24452462

24462463

24472464
class NDFrameSplitter(DataSplitter):
2448-
2449-
def __init__(self, data, labels, ngroups, axis=0, keep_internal=False):
2450-
DataSplitter.__init__(self, data, labels, ngroups, axis=axis,
2451-
keep_internal=keep_internal)
2465+
def __init__(self, data, labels, ngroups, axis=0):
2466+
super(NDFrameSplitter, self).__init__(data, labels, ngroups, axis=axis)
24522467

24532468
self.factory = data._constructor
24542469

pandas/tests/test_groupby.py

+72
Original file line number | Diff line number | Diff line change
@@ -2728,6 +2728,78 @@ def test_groupby_whitelist(self):
27282728
with tm.assertRaisesRegexp(AttributeError, msg):
27292729
getattr(gb, bl)
27302730

2731+
def test_series_groupby_plotting_nominally_works(self):
2732+
try:
2733+
import matplotlib.pyplot as plt
2734+
except ImportError:
2735+
raise nose.SkipTest("matplotlib not installed")
2736+
n = 10
2737+
weight = Series(np.random.normal(166, 20, size=n))
2738+
height = Series(np.random.normal(60, 10, size=n))
2739+
gender = tm.choice(['male', 'female'], size=n)
2740+
2741+
weight.groupby(gender).plot()
2742+
tm.close()
2743+
height.groupby(gender).hist()
2744+
tm.close()
2745+
2746+
def test_frame_groupby_plot_boxplot(self):
2747+
try:
2748+
import matplotlib.pyplot as plt
2749+
except ImportError:
2750+
raise nose.SkipTest("matplotlib not installed")
2751+
tm.close()
2752+
2753+
n = 10
2754+
weight = Series(np.random.normal(166, 20, size=n))
2755+
height = Series(np.random.normal(60, 10, size=n))
2756+
gender = tm.choice(['male', 'female'], size=n)
2757+
df = DataFrame({'height': height, 'weight': weight, 'gender': gender})
2758+
gb = df.groupby('gender')
2759+
2760+
res = gb.plot()
2761+
self.assertEqual(len(plt.get_fignums()), 2)
2762+
self.assertEqual(len(res), 2)
2763+
tm.close()
2764+
2765+
res = gb.boxplot()
2766+
self.assertEqual(len(plt.get_fignums()), 1)
2767+
self.assertEqual(len(res), 2)
2768+
tm.close()
2769+
2770+
with tm.assertRaises(TypeError, '.*str.+float'):
2771+
gb.hist()
2772+
2773+
def test_frame_groupby_hist(self):
2774+
try:
2775+
import matplotlib.pyplot as plt
2776+
except ImportError:
2777+
raise nose.SkipTest("matplotlib not installed")
2778+
tm.close()
2779+
2780+
n = 10
2781+
weight = Series(np.random.normal(166, 20, size=n))
2782+
height = Series(np.random.normal(60, 10, size=n))
2783+
gender_int = tm.choice([0, 1], size=n)
2784+
df_int = DataFrame({'height': height, 'weight': weight,
2785+
'gender': gender_int})
2786+
gb = df_int.groupby('gender')
2787+
axes = gb.hist()
2788+
self.assertEqual(len(axes), 2)
2789+
self.assertEqual(len(plt.get_fignums()), 2)
2790+
tm.close()
2791+
2792+
def test_tab_completion(self):
2793+
grp = self.mframe.groupby(level='second')
2794+
results = set([v for v in grp.__dir__() if not v.startswith('_')])
2795+
expected = set(['A','B','C',
2796+
'agg','aggregate','apply','boxplot','filter','first','get_group',
2797+
'groups','hist','indices','last','max','mean','median',
2798+
'min','name','ngroups','nth','ohlc','plot', 'prod',
2799+
'size','std','sum','transform','var', 'count', 'head', 'describe',
2800+
'cummax', 'dtype', 'quantile', 'rank',
2801+
'cumprod', 'tail', 'resample', 'cummin', 'fillna', 'cumsum'])
2802+
self.assertEqual(results, expected)
27312803

27322804
def assert_fp_equal(a, b):
27332805
assert (np.abs(a - b) < 1e-12).all()

0 commit comments

Comments
 (0)