Skip to content

Commit de6ed81

Browse files
author
Tom Augspurger
committedMay 5, 2014
Merge pull request #6956 from anomrake/plotting
BUG: fix handling of color argument for variety of plotting functions
2 parents eb3b677 + 1980c7a commit de6ed81

File tree

4 files changed

+141
-73
lines changed

4 files changed

+141
-73
lines changed
 

‎doc/source/release.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,14 @@ Deprecations
229229
returned if possible, otherwise a copy will be made. Previously the user could think that ``copy=False`` would
230230
ALWAYS return a view. (:issue:`6894`)
231231

232+
- The :func:`parallel_coordinates` function now takes argument ``color``
233+
instead of ``colors``. A ``FutureWarning`` is raised to alert that
234+
the old ``colors`` argument will not be supported in a future release
235+
236+
- The :func:`parallel_coordinates` and :func:`andrews_curves` functions now take
237+
positional argument ``frame`` instead of ``data``. A ``FutureWarning`` is
238+
raised if the old ``data`` argument is used by name.
239+
232240
Prior Version Deprecations/Changes
233241
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
234242

@@ -462,6 +470,10 @@ Bug Fixes
462470
- Bug in timeseries-with-frequency plot cursor display (:issue:`5453`)
463471
- Bug surfaced in groupby.plot when using a ``Float64Index`` (:issue:`7025`)
464472
- Stopped tests from failing if options data isn't able to be downloaded from Yahoo (:issue:`7034`)
473+
- Bug in ``parallel_coordinates`` and ``radviz`` where reordering of class column
474+
caused possible color/class mismatch
475+
- Bug in ``radviz`` and ``andrews_curves`` where multiple values of 'color'
476+
were being passed to plotting method
465477

466478
pandas 0.13.1
467479
-------------

‎doc/source/v0.14.0.txt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -382,6 +382,14 @@ Plotting
382382

383383
Because of the default `align` value changes, coordinates of bar plots are now located on integer values (0.0, 1.0, 2.0 ...). This is intended to make bar plot be located on the same coodinates as line plot. However, bar plot may differs unexpectedly when you manually adjust the bar location or drawing area, such as using `set_xlim`, `set_ylim`, etc. In this cases, please modify your script to meet with new coordinates.
384384

385+
- The :func:`parallel_coordinates` function now takes argument ``color``
386+
instead of ``colors``. A ``FutureWarning`` is raised to alert that
387+
the old ``colors`` argument will not be supported in a future release
388+
389+
- The :func:`parallel_coordinates` and :func:`andrews_curves` functions now take
390+
positional argument ``frame`` instead of ``data``. A ``FutureWarning`` is
391+
raised if the old ``data`` argument is used by name.
392+
385393
.. _whatsnew_0140.prior_deprecations:
386394

387395
Prior Version Deprecations/Changes

‎pandas/tests/test_graphics.py

Lines changed: 52 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,11 +1220,29 @@ def scat2(x, y, by=None, ax=None, figsize=None):
12201220
def test_andrews_curves(self):
12211221
from pandas import read_csv
12221222
from pandas.tools.plotting import andrews_curves
1223-
1223+
from matplotlib import cm
1224+
12241225
path = os.path.join(curpath(), 'data', 'iris.csv')
12251226
df = read_csv(path)
12261227

12271228
_check_plot_works(andrews_curves, df, 'Name')
1229+
_check_plot_works(andrews_curves, df, 'Name',
1230+
color=('#556270', '#4ECDC4', '#C7F464'))
1231+
_check_plot_works(andrews_curves, df, 'Name',
1232+
color=['dodgerblue', 'aquamarine', 'seagreen'])
1233+
_check_plot_works(andrews_curves, df, 'Name', colormap=cm.jet)
1234+
1235+
colors = ['b', 'g', 'r']
1236+
df = DataFrame({"A": [1, 2, 3],
1237+
"B": [1, 2, 3],
1238+
"C": [1, 2, 3],
1239+
"Name": colors})
1240+
ax = andrews_curves(df, 'Name', color=colors)
1241+
legend_colors = [l.get_color() for l in ax.legend().get_lines()]
1242+
self.assertEqual(colors, legend_colors)
1243+
1244+
with tm.assert_produces_warning(FutureWarning):
1245+
andrews_curves(data=df, class_column='Name')
12281246

12291247
@slow
12301248
def test_parallel_coordinates(self):
@@ -1235,20 +1253,31 @@ def test_parallel_coordinates(self):
12351253
df = read_csv(path)
12361254
_check_plot_works(parallel_coordinates, df, 'Name')
12371255
_check_plot_works(parallel_coordinates, df, 'Name',
1238-
colors=('#556270', '#4ECDC4', '#C7F464'))
1239-
_check_plot_works(parallel_coordinates, df, 'Name',
1240-
colors=['dodgerblue', 'aquamarine', 'seagreen'])
1256+
color=('#556270', '#4ECDC4', '#C7F464'))
12411257
_check_plot_works(parallel_coordinates, df, 'Name',
1242-
colors=('#556270', '#4ECDC4', '#C7F464'))
1243-
_check_plot_works(parallel_coordinates, df, 'Name',
1244-
colors=['dodgerblue', 'aquamarine', 'seagreen'])
1258+
color=['dodgerblue', 'aquamarine', 'seagreen'])
12451259
_check_plot_works(parallel_coordinates, df, 'Name', colormap=cm.jet)
12461260

12471261
df = read_csv(path, header=None, skiprows=1, names=[1, 2, 4, 8,
12481262
'Name'])
12491263
_check_plot_works(parallel_coordinates, df, 'Name', use_columns=True)
12501264
_check_plot_works(parallel_coordinates, df, 'Name',
12511265
xticks=[1, 5, 25, 125])
1266+
1267+
colors = ['b', 'g', 'r']
1268+
df = DataFrame({"A": [1, 2, 3],
1269+
"B": [1, 2, 3],
1270+
"C": [1, 2, 3],
1271+
"Name": colors})
1272+
ax = parallel_coordinates(df, 'Name', color=colors)
1273+
legend_colors = [l.get_color() for l in ax.legend().get_lines()]
1274+
self.assertEqual(colors, legend_colors)
1275+
1276+
with tm.assert_produces_warning(FutureWarning):
1277+
parallel_coordinates(df, 'Name', colors=colors)
1278+
1279+
with tm.assert_produces_warning(FutureWarning):
1280+
parallel_coordinates(data=df, class_column='Name')
12521281

12531282
@slow
12541283
def test_radviz(self):
@@ -1259,8 +1288,24 @@ def test_radviz(self):
12591288
path = os.path.join(curpath(), 'data', 'iris.csv')
12601289
df = read_csv(path)
12611290
_check_plot_works(radviz, df, 'Name')
1291+
_check_plot_works(radviz, df, 'Name',
1292+
color=('#556270', '#4ECDC4', '#C7F464'))
1293+
_check_plot_works(radviz, df, 'Name',
1294+
color=['dodgerblue', 'aquamarine', 'seagreen'])
12621295
_check_plot_works(radviz, df, 'Name', colormap=cm.jet)
12631296

1297+
colors = [[0., 0., 1., 1.],
1298+
[0., 0.5, 1., 1.],
1299+
[1., 0., 0., 1.]]
1300+
df = DataFrame({"A": [1, 2, 3],
1301+
"B": [2, 1, 3],
1302+
"C": [3, 2, 1],
1303+
"Name": ['b', 'g', 'r']})
1304+
ax = radviz(df, 'Name', color=colors)
1305+
legend_colors = [c.get_facecolor().squeeze().tolist()
1306+
for c in ax.collections]
1307+
self.assertEqual(colors, legend_colors)
1308+
12641309
@slow
12651310
def test_plot_int_columns(self):
12661311
df = DataFrame(randn(100, 4)).cumsum()

‎pandas/tools/plotting.py

Lines changed: 69 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import numpy as np
1010

11-
from pandas.util.decorators import cache_readonly
11+
from pandas.util.decorators import cache_readonly, deprecate_kwarg
1212
import pandas.core.common as com
1313
from pandas.core.index import MultiIndex
1414
from pandas.core.series import Series, remove_na
@@ -354,19 +354,22 @@ def _get_marker_compat(marker):
354354
return 'o'
355355
return marker
356356

357-
358-
def radviz(frame, class_column, ax=None, colormap=None, **kwds):
357+
def radviz(frame, class_column, ax=None, color=None, colormap=None, **kwds):
359358
"""RadViz - a multivariate data visualization algorithm
360359
361360
Parameters:
362361
-----------
363-
frame: DataFrame object
364-
class_column: Column name that contains information about class membership
362+
frame: DataFrame
363+
class_column: str
364+
Column name containing class names
365365
ax: Matplotlib axis object, optional
366+
color: list or tuple, optional
367+
Colors to use for the different classes
366368
colormap : str or matplotlib colormap object, default None
367369
Colormap to select colors from. If string, load colormap with that name
368370
from matplotlib.
369-
kwds: Matplotlib scatter method keyword arguments, optional
371+
kwds: keywords
372+
Options to pass to matplotlib scatter plotting method
370373
371374
Returns:
372375
--------
@@ -380,44 +383,42 @@ def normalize(series):
380383
b = max(series)
381384
return (series - a) / (b - a)
382385

383-
column_names = [column_name for column_name in frame.columns
384-
if column_name != class_column]
385-
386-
df = frame[column_names].apply(normalize)
386+
n = len(frame)
387+
classes = frame[class_column].drop_duplicates()
388+
class_col = frame[class_column]
389+
df = frame.drop(class_column, axis=1).apply(normalize)
387390

388391
if ax is None:
389392
ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
390393

391-
classes = set(frame[class_column])
392394
to_plot = {}
393-
394395
colors = _get_standard_colors(num_colors=len(classes), colormap=colormap,
395-
color_type='random', color=kwds.get('color'))
396+
color_type='random', color=color)
396397

397-
for class_ in classes:
398-
to_plot[class_] = [[], []]
398+
for kls in classes:
399+
to_plot[kls] = [[], []]
399400

400401
n = len(frame.columns) - 1
401402
s = np.array([(np.cos(t), np.sin(t))
402403
for t in [2.0 * np.pi * (i / float(n))
403404
for i in range(n)]])
404405

405-
for i in range(len(frame)):
406-
row = df.irow(i).values
406+
for i in range(n):
407+
row = df.iloc[i].values
407408
row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
408409
y = (s * row_).sum(axis=0) / row.sum()
409-
class_name = frame[class_column].iget(i)
410-
to_plot[class_name][0].append(y[0])
411-
to_plot[class_name][1].append(y[1])
410+
kls = class_col.iat[i]
411+
to_plot[kls][0].append(y[0])
412+
to_plot[kls][1].append(y[1])
412413

413-
for i, class_ in enumerate(classes):
414-
ax.scatter(to_plot[class_][0], to_plot[class_][1], color=colors[i],
415-
label=com.pprint_thing(class_), **kwds)
414+
for i, kls in enumerate(classes):
415+
ax.scatter(to_plot[kls][0], to_plot[kls][1], color=colors[i],
416+
label=com.pprint_thing(kls), **kwds)
416417
ax.legend()
417418

418419
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
419420

420-
for xy, name in zip(s, column_names):
421+
for xy, name in zip(s, df.columns):
421422

422423
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
423424

@@ -437,21 +438,24 @@ def normalize(series):
437438
ax.axis('equal')
438439
return ax
439440

440-
441-
def andrews_curves(data, class_column, ax=None, samples=200, colormap=None,
442-
**kwds):
441+
@deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
442+
def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
443+
colormap=None, **kwds):
443444
"""
444445
Parameters:
445446
-----------
446-
data : DataFrame
447+
frame : DataFrame
447448
Data to be plotted, preferably normalized to (0.0, 1.0)
448449
class_column : Name of the column containing class names
449450
ax : matplotlib axes object, default None
450451
samples : Number of points to plot in each curve
452+
color: list or tuple, optional
453+
Colors to use for the different classes
451454
colormap : str or matplotlib colormap object, default None
452455
Colormap to select colors from. If string, load colormap with that name
453456
from matplotlib.
454-
kwds : Optional plotting arguments to be passed to matplotlib
457+
kwds: keywords
458+
Options to pass to matplotlib plotting method
455459
456460
Returns:
457461
--------
@@ -475,30 +479,31 @@ def f(x):
475479
return result
476480
return f
477481

478-
n = len(data)
479-
class_col = data[class_column]
480-
uniq_class = class_col.drop_duplicates()
481-
columns = [data[col] for col in data.columns if (col != class_column)]
482+
n = len(frame)
483+
class_col = frame[class_column]
484+
classes = frame[class_column].drop_duplicates()
485+
df = frame.drop(class_column, axis=1)
482486
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
483487
used_legends = set([])
484488

485-
colors = _get_standard_colors(num_colors=len(uniq_class), colormap=colormap,
486-
color_type='random', color=kwds.get('color'))
487-
col_dict = dict([(klass, col) for klass, col in zip(uniq_class, colors)])
489+
color_values = _get_standard_colors(num_colors=len(classes),
490+
colormap=colormap, color_type='random',
491+
color=color)
492+
colors = dict(zip(classes, color_values))
488493
if ax is None:
489494
ax = plt.gca(xlim=(-pi, pi))
490495
for i in range(n):
491-
row = [columns[c][i] for c in range(len(columns))]
496+
row = df.iloc[i].values
492497
f = function(row)
493498
y = [f(t) for t in x]
494-
label = None
495-
if com.pprint_thing(class_col[i]) not in used_legends:
496-
label = com.pprint_thing(class_col[i])
499+
kls = class_col.iat[i]
500+
label = com.pprint_thing(kls)
501+
if label not in used_legends:
497502
used_legends.add(label)
498-
ax.plot(x, y, color=col_dict[class_col[i]], label=label, **kwds)
503+
ax.plot(x, y, color=colors[kls], label=label, **kwds)
499504
else:
500-
ax.plot(x, y, color=col_dict[class_col[i]], **kwds)
501-
505+
ax.plot(x, y, color=colors[kls], **kwds)
506+
502507
ax.legend(loc='upper right')
503508
ax.grid()
504509
return ax
@@ -564,31 +569,32 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
564569
plt.setp(axis.get_yticklabels(), fontsize=8)
565570
return fig
566571

567-
568-
def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
569-
use_columns=False, xticks=None, colormap=None, **kwds):
572+
@deprecate_kwarg(old_arg_name='colors', new_arg_name='color')
573+
@deprecate_kwarg(old_arg_name='data', new_arg_name='frame')
574+
def parallel_coordinates(frame, class_column, cols=None, ax=None, color=None,
575+
use_columns=False, xticks=None, colormap=None,
576+
**kwds):
570577
"""Parallel coordinates plotting.
571578
572579
Parameters
573580
----------
574-
data: DataFrame
575-
A DataFrame containing data to be plotted
581+
frame: DataFrame
576582
class_column: str
577583
Column name containing class names
578584
cols: list, optional
579585
A list of column names to use
580586
ax: matplotlib.axis, optional
581587
matplotlib axis object
582-
colors: list or tuple, optional
588+
color: list or tuple, optional
583589
Colors to use for the different classes
584590
use_columns: bool, optional
585591
If true, columns will be used as xticks
586592
xticks: list or tuple, optional
587593
A list of values to use for xticks
588594
colormap: str or matplotlib colormap, default None
589595
Colormap to use for line colors.
590-
kwds: list, optional
591-
A list of keywords for matplotlib plot method
596+
kwds: keywords
597+
Options to pass to matplotlib plotting method
592598
593599
Returns
594600
-------
@@ -600,20 +606,19 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
600606
>>> from pandas.tools.plotting import parallel_coordinates
601607
>>> from matplotlib import pyplot as plt
602608
>>> df = read_csv('https://raw.github.com/pydata/pandas/master/pandas/tests/data/iris.csv')
603-
>>> parallel_coordinates(df, 'Name', colors=('#556270', '#4ECDC4', '#C7F464'))
609+
>>> parallel_coordinates(df, 'Name', color=('#556270', '#4ECDC4', '#C7F464'))
604610
>>> plt.show()
605611
"""
606612
import matplotlib.pyplot as plt
607613

608-
609-
n = len(data)
610-
classes = set(data[class_column])
611-
class_col = data[class_column]
614+
n = len(frame)
615+
classes = frame[class_column].drop_duplicates()
616+
class_col = frame[class_column]
612617

613618
if cols is None:
614-
df = data.drop(class_column, axis=1)
619+
df = frame.drop(class_column, axis=1)
615620
else:
616-
df = data[cols]
621+
df = frame[cols]
617622

618623
used_legends = set([])
619624

@@ -638,19 +643,17 @@ def parallel_coordinates(data, class_column, cols=None, ax=None, colors=None,
638643

639644
color_values = _get_standard_colors(num_colors=len(classes),
640645
colormap=colormap, color_type='random',
641-
color=colors)
646+
color=color)
642647

643648
colors = dict(zip(classes, color_values))
644649

645650
for i in range(n):
646-
row = df.irow(i).values
647-
y = row
648-
kls = class_col.iget_value(i)
649-
if com.pprint_thing(kls) not in used_legends:
650-
label = com.pprint_thing(kls)
651+
y = df.iloc[i].values
652+
kls = class_col.iat[i]
653+
label = com.pprint_thing(kls)
654+
if label not in used_legends:
651655
used_legends.add(label)
652-
ax.plot(x, y, color=colors[kls],
653-
label=label, **kwds)
656+
ax.plot(x, y, color=colors[kls], label=label, **kwds)
654657
else:
655658
ax.plot(x, y, color=colors[kls], **kwds)
656659

0 commit comments

Comments
 (0)
Please sign in to comment.