Skip to content

Commit 95c2a29

Browse files
committed
Refactor geom.draw_group to take a dataframe
Finally got rid of `geom._make_pinfos`. Had to create a wrapper `groupby_with_null` around `DataFrame.groupby` to allow grouping on columns with null values. At almost the same time, a PR [1] appeared that will probably solve this issue. --- [1] pandas-dev/pandas#12607
1 parent 499e61b commit 95c2a29

24 files changed

+648
-592
lines changed

ggplot/geoms/geom.py

+27-69
Original file line numberDiff line numberDiff line change
@@ -27,19 +27,6 @@ class geom(object):
2727
# not implemented
2828
legend_geom = 'point'
2929

30-
# A matplotlib plot function may require that an aethestic have a
31-
# single unique value. e.g. linestyle='dashed' and not
32-
# linestyle=['dashed', 'dotted', ...].
33-
# A single call to such a function can only plot lines with the
34-
# same linestyle. However, if the plot we want has more than one
35-
# line with different linestyles, we need to group the lines with
36-
# the same linestyle and plot them as one unit.
37-
#
38-
# geoms should fill out this set with such aesthetics so that the
39-
# plot information they receive can be plotted in a single call.
40-
# See: geom_point
41-
_units = set()
42-
4330
# Whether to divide the distance between any two points into
4431
# multiple segments. This is done during coord.transform time
4532
_munch = False
@@ -167,12 +154,35 @@ def draw_panel(self, data, panel_scales, coord, ax, **params):
167154
"""
168155
data = coord.transform(data, panel_scales, self._munch)
169156
for _, gdata in data.groupby('group'):
170-
pinfos = self._make_pinfos(gdata, params)
171-
for pinfo in pinfos:
172-
self.draw_group(pinfo, panel_scales, coord, ax, **params)
157+
gdata.reset_index(inplace=True, drop=True)
158+
gdata.is_copy = None
159+
self.draw_group(gdata, panel_scales, coord, ax, **params)
160+
161+
@staticmethod
162+
def draw_group(data, panel_scales, coord, ax, **params):
163+
"""
164+
Plot data
165+
"""
166+
msg = "The geom should implement this method."
167+
raise NotImplementedError(msg)
173168

174169
@staticmethod
175-
def draw_group(pinfo, panel_scales, coord, ax, **params):
170+
def draw_unit(data, panel_scales, coord, ax, **params):
171+
"""
172+
Plot data
173+
174+
A matplotlib plot function may require that an aesthetic
175+
have a single unique value. e.g. linestyle='dashed' and
176+
not linestyle=['dashed', 'dotted', ...].
177+
A single call to such a function can only plot lines with
178+
the same linestyle. However, if the plot we want has more
179+
than one line with different linestyles, we need to group
180+
the lines with the same linestyle and plot them as one
181+
unit. In this case, draw_group calls this function to do
182+
the plotting.
183+
184+
See: geom_point
185+
"""
176186
msg = "The geom should implement this method."
177187
raise NotImplementedError(msg)
178188

@@ -284,55 +294,3 @@ def verify_arguments(self, kwargs):
284294
if unknown:
285295
msg = 'Unknown parameters {}'
286296
raise GgplotError(msg.format(unknown))
287-
288-
def _make_pinfos(self, data, params):
289-
units = []
290-
for col in data.columns:
291-
if col in self._units:
292-
units.append(col)
293-
294-
shrinkable = {'alpha', 'fill', 'color', 'size', 'linetype',
295-
'shape'}
296-
297-
def prep(pinfo):
298-
"""
299-
Reduce shrinkable parameters & append zorder
300-
"""
301-
# If it is the same value in the list make it a scalar
302-
# This can help the matplotlib functions draw faster
303-
for ae in set(pinfo) & shrinkable:
304-
with suppress(TypeError, IndexError):
305-
if all(pinfo[ae][0] == v for v in pinfo[ae]):
306-
pinfo[ae] = pinfo[ae][0]
307-
pinfo['zorder'] = params['zorder']
308-
return pinfo
309-
310-
out = []
311-
if units:
312-
# Currently groupby does not like None values in any of
313-
# the columns that participate in the grouping. These
314-
# Nones come in when the default aesthetics are added to
315-
# the data. We drop these columns and after turning the
316-
# the dataframe into a dictionary insert a None for that
317-
# aesthetic
318-
_units = []
319-
_none_units = []
320-
for unit in units:
321-
if data[unit].iloc[0] is None:
322-
_none_units.append(unit)
323-
del data[unit]
324-
else:
325-
_units.append(unit)
326-
327-
for name, _data in data.groupby(_units):
328-
pinfo = _data.to_dict('list')
329-
for ae in _units:
330-
pinfo[ae] = pinfo[ae][0]
331-
for ae in _none_units:
332-
pinfo[ae] = None
333-
out.append(prep(pinfo))
334-
else:
335-
pinfo = data.to_dict('list')
336-
out.append(prep(pinfo))
337-
338-
return out

ggplot/geoms/geom_abline.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def draw_panel(self, data, panel_scales, coord, ax, **params):
5151
data = data.drop_duplicates()
5252

5353
for _, gdata in data.groupby('group'):
54-
pinfos = self._make_pinfos(gdata, params)
55-
for pinfo in pinfos:
56-
geom_segment.draw_group(pinfo, panel_scales,
57-
coord, ax, **params)
54+
gdata.reset_index(inplace=True)
55+
gdata.is_copy = None
56+
geom_segment.draw_group(gdata, panel_scales,
57+
coord, ax, **params)

ggplot/geoms/geom_boxplot.py

+42-36
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from __future__ import (absolute_import, division, print_function,
22
unicode_literals)
3-
from copy import deepcopy
43

4+
import numpy as np
5+
import pandas as pd
56
import matplotlib.lines as mlines
67
from matplotlib.patches import Rectangle
78

89
from ..scales.utils import resolution
9-
from ..utils import make_iterable_ntimes, to_rgba
10+
from ..utils import make_iterable_ntimes, to_rgba, copy_missing_columns
1011
from .geom_point import geom_point
1112
from .geom_segment import geom_segment
1213
from .geom_crossbar import geom_crossbar
@@ -22,7 +23,8 @@ class geom_boxplot(geom):
2223
'outlier_alpha': 1, 'outlier_color': None,
2324
'outlier_shape': 'o', 'outlier_size': 5,
2425
'outlier_stroke': 0, 'notch': False,
25-
'varwidth': False, 'notchwidth': 0.5}
26+
'varwidth': False, 'notchwidth': 0.5,
27+
'fatten': 2}
2628

2729
def setup_data(self, data):
2830
if 'width' not in data:
@@ -50,59 +52,63 @@ def setup_data(self, data):
5052
return data
5153

5254
@staticmethod
53-
def draw_group(pinfo, panel_scales, coord, ax, **params):
55+
def draw_group(data, panel_scales, coord, ax, **params):
56+
def flat(*args):
57+
"""Flatten list-likes"""
58+
return np.hstack(args)
5459

55-
def subdict(keys):
56-
d = {}
57-
for key in keys:
58-
d[key] = deepcopy(pinfo[key])
59-
return d
60-
61-
common = subdict(('color', 'size', 'linetype',
62-
'fill', 'group', 'alpha',
63-
'zorder'))
64-
65-
whiskers = subdict(('x',))
66-
whiskers.update(deepcopy(common))
67-
whiskers['x'] = whiskers['x'] * 2
60+
common_columns = ['color', 'size', 'linetype',
61+
'fill', 'group', 'alpha', 'shape']
62+
# whiskers
63+
whiskers = pd.DataFrame({
64+
'x': flat(data['x'], data['x']),
65+
'y': flat(data['upper'], data['lower']),
66+
'yend': flat(data['ymax'], data['ymin'])})
6867
whiskers['xend'] = whiskers['x']
69-
whiskers['y'] = pinfo['upper'] + pinfo['lower']
70-
whiskers['yend'] = pinfo['ymax'] + pinfo['ymin']
71-
72-
box = subdict(('xmin', 'xmax', 'lower', 'middle', 'upper'))
73-
box.update(deepcopy(common))
74-
box['ymin'] = box.pop('lower')
75-
box['y'] = box.pop('middle')
76-
box['ymax'] = box.pop('upper')
77-
box['notchwidth'] = params['notchwidth']
68+
copy_missing_columns(whiskers, data[common_columns])
69+
70+
# box
71+
box_columns = ['xmin', 'xmax', 'lower', 'middle', 'upper']
72+
box = data[common_columns + box_columns].copy()
73+
box.rename(columns={'lower': 'ymin',
74+
'middle': 'y',
75+
'upper': 'ymax'},
76+
inplace=True)
77+
78+
# notch
7879
if params['notch']:
79-
box['ynotchlower'] = pinfo['notchlower']
80-
box['ynotchupper'] = pinfo['notchupper']
80+
box['ynotchlower'] = data['notchlower']
81+
box['ynotchupper'] = data['notchupper']
8182

82-
if 'outliers' in pinfo and len(pinfo['outliers'][0]):
83-
outliers = subdict(('alpha', 'zorder'))
83+
# outliers
84+
try:
85+
num_outliers = len(data['outliers'].iloc[0])
86+
except KeyError:
87+
num_outliers = 0
8488

89+
if num_outliers:
8590
def outlier_value(param):
8691
oparam = 'outlier_{}'.format(param)
8792
if params[oparam] is not None:
8893
return params[oparam]
89-
return pinfo[param]
94+
return data[param].iloc[0]
9095

91-
outliers['y'] = pinfo['outliers'][0]
92-
outliers['x'] = make_iterable_ntimes(pinfo['x'][0],
93-
len(outliers['y']))
96+
outliers = pd.DataFrame({
97+
'y': data['outliers'].iloc[0],
98+
'x': make_iterable_ntimes(data['x'][0],
99+
num_outliers),
100+
'fill': None})
94101
outliers['alpha'] = outlier_value('alpha')
95102
outliers['color'] = outlier_value('color')
96-
outliers['fill'] = None
97103
outliers['shape'] = outlier_value('shape')
98104
outliers['size'] = outlier_value('size')
99105
outliers['stroke'] = outlier_value('stroke')
100106
geom_point.draw_group(outliers, panel_scales,
101107
coord, ax, **params)
102108

109+
# plot
103110
geom_segment.draw_group(whiskers, panel_scales,
104111
coord, ax, **params)
105-
params['fatten'] = geom_crossbar.DEFAULT_PARAMS['fatten']
106112
geom_crossbar.draw_group(box, panel_scales,
107113
coord, ax, **params)
108114

ggplot/geoms/geom_crossbar.py

+29-34
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
unicode_literals)
33

44
import numpy as np
5+
import pandas as pd
56
import matplotlib.lines as mlines
67
from matplotlib.patches import Rectangle
78

89
from ..scales.utils import resolution
910
from ..utils.exceptions import gg_warn
10-
from ..utils import to_rgba
11+
from ..utils import copy_missing_columns
1112
from .geom import geom
1213
from .geom_polygon import geom_polygon
1314
from .geom_segment import geom_segment
@@ -33,64 +34,58 @@ def setup_data(self, data):
3334
return data
3435

3536
@staticmethod
36-
def draw_group(pinfo, panel_scales, coord, ax, **params):
37-
y = pinfo['y']
38-
xmin = np.array(pinfo['xmin'])
39-
xmax = np.array(pinfo['xmax'])
40-
ymin = np.array(pinfo['ymin'])
41-
ymax = np.array(pinfo['ymax'])
42-
notchwidth = pinfo.get('notchwidth')
43-
ynotchupper = pinfo.get('ynotchupper')
44-
ynotchlower = pinfo.get('ynotchlower')
45-
46-
keys = ['alpha', 'color', 'fill', 'size',
47-
'linetype', 'zorder']
48-
49-
def copy_keys(d):
50-
for k in keys:
51-
d[k] = pinfo[k]
37+
def draw_group(data, panel_scales, coord, ax, **params):
38+
y = data['y']
39+
xmin = data['xmin']
40+
xmax = data['xmax']
41+
ymin = data['ymin']
42+
ymax = data['ymax']
43+
group = data['group']
44+
45+
# Optional notch columns (geom_boxplot adds ynotchlower/ynotchupper when notch is set)
46+
notchwidth = data.get('notchwidth')
47+
ynotchupper = data.get('ynotchupper')
48+
ynotchlower = data.get('ynotchlower')
5249

5350
def flat(*args):
5451
"""Flatten list-likes"""
55-
return [i for arg in args for i in arg]
52+
return np.hstack(args)
5653

57-
middle = {'x': xmin,
58-
'y': y,
59-
'xend': xmax,
60-
'yend': y,
61-
'group': pinfo['group']}
62-
copy_keys(middle)
63-
middle['size'] = np.asarray(middle['size'])*params['fatten'],
54+
middle = pd.DataFrame({'x': xmin,
55+
'y': y,
56+
'xend': xmax,
57+
'yend': y,
58+
'group': group})
59+
copy_missing_columns(middle, data)
60+
middle['size'] *= params['fatten']
6461

6562
has_notch = ynotchlower is not None and ynotchupper is not None
6663
if has_notch: # 10 points + 1 closing
67-
ynotchlower = np.array(ynotchlower)
68-
ynotchupper = np.array(ynotchupper)
6964
if (any(ynotchlower < ymin) or any(ynotchupper > ymax)):
7065
msg = ("Notch went outside hinges."
7166
" Try setting notch=False.")
7267
gg_warn(msg)
7368

7469
notchindent = (1 - notchwidth) * (xmax-xmin)/2
7570

76-
middle['x'] = np.array(middle['x']) + notchindent
77-
middle['xend'] = np.array(middle['xend']) - notchindent
78-
box = {
71+
middle['x'] += notchindent
72+
middle['xend'] -= notchindent
73+
box = pd.DataFrame({
7974
'x': flat(xmin, xmin, xmin+notchindent, xmin, xmin,
8075
xmax, xmax, xmax-notchindent, xmax, xmax,
8176
xmin),
8277
'y': flat(ymax, ynotchupper, y, ynotchlower, ymin,
8378
ymin, ynotchlower, y, ynotchupper, ymax,
8479
ymax),
85-
'group': np.tile(np.arange(1, len(pinfo['group'])+1), 11)}
80+
'group': np.tile(np.arange(1, len(group)+1), 11)})
8681
else:
8782
# No notch, 4 points + 1 closing
88-
box = {
83+
box = pd.DataFrame({
8984
'x': flat(xmin, xmin, xmax, xmax, xmin),
9085
'y': flat(ymax, ymax, ymax, ymin, ymin),
91-
'group': np.tile(np.arange(1, len(pinfo['group'])+1), 5)}
92-
copy_keys(box)
86+
'group': np.tile(np.arange(1, len(group)+1), 5)})
9387

88+
copy_missing_columns(box, data)
9489
geom_polygon.draw_group(box, panel_scales, coord, ax, **params)
9590
geom_segment.draw_group(middle, panel_scales, coord, ax, **params)
9691

0 commit comments

Comments
 (0)