Skip to content

Support for customizing parallel_plot() x axis tickmarks #2287

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Dec 11, 2012
9 changes: 9 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,15 @@ def test_parallel_coordinates(self):
path = os.path.join(curpath(), 'data/iris.csv')
df = read_csv(path)
_check_plot_works(parallel_coordinates, df, 'Name')
_check_plot_works(parallel_coordinates, df, 'Name',
colors=('#556270', '#4ECDC4', '#C7F464'))
_check_plot_works(parallel_coordinates, df, 'Name',
colors=['dodgerblue', 'aquamarine', 'seagreen'])

df = read_csv(path, header=None, skiprows=1, names=[1,2,4,8, 'Name'])
_check_plot_works(parallel_coordinates, df, 'Name', use_columns=True)
_check_plot_works(parallel_coordinates, df, 'Name',
xticks=[1, 5, 25, 125])

@slow
def test_radviz(self):
Expand Down
71 changes: 57 additions & 14 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,20 +411,41 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
return fig


def parallel_coordinates(data, class_column, cols=None, ax=None, **kwds):
def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
use_columns=False, xticks=None, **kwds):
"""Parallel coordinates plotting.

Parameters:
-----------
data: A DataFrame containing data to be plotted
class_column: Column name containing class names
cols: A list of column names to use, optional
ax: matplotlib axis object, optional
kwds: A list of keywords for matplotlib plot method
Parameters
----------
data: DataFrame
A DataFrame containing data to be plotted
class_column: str
Column name containing class names
cols: list, optional
A list of column names to use
ax: matplotlib.axis, optional
matplotlib axis object
colors: list or tuple, optional
Colors to use for the different classes
use_columns: bool, optional
If true, columns will be used as xticks
xticks: list or tuple, optional
A list of values to use for xticks
kwds: list, optional
A list of keywords for matplotlib plot method

Returns:
--------
Returns
-------
ax: matplotlib axis object

Examples
--------
>>> from pandas import read_csv
>>> from pandas.tools.plotting import parallel_coordinates
>>> from matplotlib import pyplot as plt
>>> df = read_csv('https://raw.github.com/pydata/pandas/master/pandas/tests/data/iris.csv')
>>> parallel_coordinates(df, 'Name', colors=('#556270', '#4ECDC4', '#C7F464'))
>>> plt.show()
"""
import matplotlib.pyplot as plt
import random
Expand All @@ -444,28 +465,50 @@ def random_color(column):
used_legends = set([])

ncols = len(df.columns)
x = range(ncols)

# determine values to use for xticks
if use_columns is True:
if not np.all(np.isreal(list(df.columns))):
raise ValueError('Columns must be numeric to be used as xticks')
x = df.columns
elif xticks is not None:
if not np.all(np.isreal(xticks)):
raise ValueError('xticks specified must be numeric')
elif len(xticks) != ncols:
raise ValueError('Length of xticks must match number of columns')
x = xticks
else:
x = range(ncols)

if ax == None:
ax = plt.gca()

# if user has not specified colors to use, choose at random
if colors is None:
colors = dict((kls, random_color(kls)) for kls in classes)
else:
if len(colors) != len(classes):
raise ValueError('Number of colors must match number of classes')
colors = dict((kls, colors[i]) for i, kls in enumerate(classes))

for i in range(n):
row = df.irow(i).values
y = row
kls = class_col.iget_value(i)
if com.pprint_thing(kls) not in used_legends:
label = com.pprint_thing(kls)
used_legends.add(label)
ax.plot(x, y, color=random_color(kls),
ax.plot(x, y, color=colors[kls],
label=label, **kwds)
else:
ax.plot(x, y, color=random_color(kls), **kwds)
ax.plot(x, y, color=colors[kls], **kwds)

for i in range(ncols):
for i in x:
ax.axvline(i, linewidth=1, color='black')

ax.set_xticks(x)
ax.set_xticklabels(df.columns)
ax.set_xlim(x[0], x[-1])
ax.legend(loc='upper right')
ax.grid()
return ax
Expand Down