diff --git a/doc/source/release.rst b/doc/source/release.rst index 4d628fac78cf0..f628917238e47 100644 --- a/doc/source/release.rst +++ b/doc/source/release.rst @@ -205,6 +205,8 @@ Improvements to existing features wrapper is updated inplace, a copy is still made internally. (:issue:`1960`, :issue:`5247`, and related :issue:`2325` [still not closed]) + - Fixed bug in `tools.plotting.andrews_curvres` so that lines are drawn grouped + by color as expected. API Changes ~~~~~~~~~~~ diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index 7de5840384974..c4255e706b19f 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -454,13 +454,14 @@ def f(x): n = len(data) class_col = data[class_column] + uniq_class = class_col.drop_duplicates() columns = [data[col] for col in data.columns if (col != class_column)] x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)] used_legends = set([]) - colors = _get_standard_colors(num_colors=n, colormap=colormap, + colors = _get_standard_colors(num_colors=len(uniq_class), colormap=colormap, color_type='random', color=kwds.get('color')) - + col_dict = dict([(klass, col) for klass, col in zip(uniq_class, colors)]) if ax is None: ax = plt.gca(xlim=(-pi, pi)) for i in range(n): @@ -471,9 +472,9 @@ def f(x): if com.pprint_thing(class_col[i]) not in used_legends: label = com.pprint_thing(class_col[i]) used_legends.add(label) - ax.plot(x, y, color=colors[i], label=label, **kwds) + ax.plot(x, y, color=col_dict[class_col[i]], label=label, **kwds) else: - ax.plot(x, y, color=colors[i], **kwds) + ax.plot(x, y, color=col_dict[class_col[i]], **kwds) ax.legend(loc='upper right') ax.grid() @@ -656,10 +657,10 @@ def lag_plot(series, lag=1, ax=None, **kwds): ax: Matplotlib axis object """ import matplotlib.pyplot as plt - + # workaround because `c='b'` is hardcoded in matplotlibs scatter method kwds.setdefault('c', plt.rcParams['patch.facecolor']) - + data = series.values y1 = data[:-lag] y2 = data[lag:] @@ -1212,20 +1213,20 @@ def __init__(self, data, x, y, **kwargs): y = self.data.columns[y] self.x = x self.y = y - - + + def _make_plot(self): x, y, data = self.x, self.y, self.data ax = self.axes[0] ax.scatter(data[x].values, data[y].values, **self.kwds) - + def _post_plot_logic(self): ax = self.axes[0] - x, y = self.x, self.y + x, y = self.x, self.y ax.set_ylabel(com.pprint_thing(y)) ax.set_xlabel(com.pprint_thing(x)) - - + + class LinePlot(MPLPlot): def __init__(self, data, **kwargs): @@ -1658,25 +1659,25 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, elif kind == 'kde': klass = KdePlot elif kind == 'scatter': - klass = ScatterPlot + klass = ScatterPlot else: raise ValueError('Invalid chart type given %s' % kind) if kind == 'scatter': - plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, - rot=rot,legend=legend, ax=ax, style=style, + plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, + rot=rot,legend=legend, ax=ax, style=style, fontsize=fontsize, use_index=use_index, sharex=sharex, - sharey=sharey, xticks=xticks, yticks=yticks, - xlim=xlim, ylim=ylim, title=title, grid=grid, - figsize=figsize, logx=logx, logy=logy, - sort_columns=sort_columns, secondary_y=secondary_y, + sharey=sharey, xticks=xticks, yticks=yticks, + xlim=xlim, ylim=ylim, title=title, grid=grid, + figsize=figsize, logx=logx, logy=logy, + sort_columns=sort_columns, secondary_y=secondary_y, **kwds) else: if x is not None: if com.is_integer(x) and not frame.columns.holds_integer(): x = frame.columns[x] frame = frame.set_index(x) - + if y is not None: if com.is_integer(y) and not frame.columns.holds_integer(): y = frame.columns[y] @@ -1691,7 +1692,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, grid=grid, logx=logx, logy=logy, secondary_y=secondary_y, title=title, figsize=figsize, fontsize=fontsize, **kwds) - + else: plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, legend=legend, ax=ax, style=style, fontsize=fontsize, @@ -1700,7 +1701,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, title=title, grid=grid, figsize=figsize, logx=logx, logy=logy, sort_columns=sort_columns, secondary_y=secondary_y, **kwds) - + plot_obj.generate() plot_obj.draw() if subplots: