Skip to content

BUG: Fixes color selection in andrews_curve #5378

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 31, 2013
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ Improvements to existing features
wrapper is updated inplace, a copy is still made internally.
(:issue:`1960`, :issue:`5247`, and related :issue:`2325` [still not
closed])
- Fixed bug in `tools.plotting.andrews_curvres` so that lines are drawn grouped
by color as expected.

API Changes
~~~~~~~~~~~
Expand Down
45 changes: 23 additions & 22 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,13 +454,14 @@ def f(x):

n = len(data)
class_col = data[class_column]
uniq_class = class_col.drop_duplicates()
columns = [data[col] for col in data.columns if (col != class_column)]
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
used_legends = set([])

colors = _get_standard_colors(num_colors=n, colormap=colormap,
colors = _get_standard_colors(num_colors=len(uniq_class), colormap=colormap,
color_type='random', color=kwds.get('color'))

col_dict = dict([(klass, col) for klass, col in zip(uniq_class, colors)])
if ax is None:
ax = plt.gca(xlim=(-pi, pi))
for i in range(n):
Expand All @@ -471,9 +472,9 @@ def f(x):
if com.pprint_thing(class_col[i]) not in used_legends:
label = com.pprint_thing(class_col[i])
used_legends.add(label)
ax.plot(x, y, color=colors[i], label=label, **kwds)
ax.plot(x, y, color=col_dict[class_col[i]], label=label, **kwds)
else:
ax.plot(x, y, color=colors[i], **kwds)
ax.plot(x, y, color=col_dict[class_col[i]], **kwds)

ax.legend(loc='upper right')
ax.grid()
Expand Down Expand Up @@ -656,10 +657,10 @@ def lag_plot(series, lag=1, ax=None, **kwds):
ax: Matplotlib axis object
"""
import matplotlib.pyplot as plt

# workaround because `c='b'` is hardcoded in matplotlibs scatter method
kwds.setdefault('c', plt.rcParams['patch.facecolor'])

data = series.values
y1 = data[:-lag]
y2 = data[lag:]
Expand Down Expand Up @@ -1212,20 +1213,20 @@ def __init__(self, data, x, y, **kwargs):
y = self.data.columns[y]
self.x = x
self.y = y


def _make_plot(self):
x, y, data = self.x, self.y, self.data
ax = self.axes[0]
ax.scatter(data[x].values, data[y].values, **self.kwds)

def _post_plot_logic(self):
ax = self.axes[0]
x, y = self.x, self.y
x, y = self.x, self.y
ax.set_ylabel(com.pprint_thing(y))
ax.set_xlabel(com.pprint_thing(x))


class LinePlot(MPLPlot):

def __init__(self, data, **kwargs):
Expand Down Expand Up @@ -1658,25 +1659,25 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
elif kind == 'kde':
klass = KdePlot
elif kind == 'scatter':
klass = ScatterPlot
klass = ScatterPlot
else:
raise ValueError('Invalid chart type given %s' % kind)

if kind == 'scatter':
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
rot=rot,legend=legend, ax=ax, style=style,
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
rot=rot,legend=legend, ax=ax, style=style,
fontsize=fontsize, use_index=use_index, sharex=sharex,
sharey=sharey, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, title=title, grid=grid,
figsize=figsize, logx=logx, logy=logy,
sort_columns=sort_columns, secondary_y=secondary_y,
sharey=sharey, xticks=xticks, yticks=yticks,
xlim=xlim, ylim=ylim, title=title, grid=grid,
figsize=figsize, logx=logx, logy=logy,
sort_columns=sort_columns, secondary_y=secondary_y,
**kwds)
else:
if x is not None:
if com.is_integer(x) and not frame.columns.holds_integer():
x = frame.columns[x]
frame = frame.set_index(x)

if y is not None:
if com.is_integer(y) and not frame.columns.holds_integer():
y = frame.columns[y]
Expand All @@ -1691,7 +1692,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
grid=grid, logx=logx, logy=logy,
secondary_y=secondary_y, title=title,
figsize=figsize, fontsize=fontsize, **kwds)

else:
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
legend=legend, ax=ax, style=style, fontsize=fontsize,
Expand All @@ -1700,7 +1701,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
title=title, grid=grid, figsize=figsize, logx=logx,
logy=logy, sort_columns=sort_columns,
secondary_y=secondary_y, **kwds)

plot_obj.generate()
plot_obj.draw()
if subplots:
Expand Down