@@ -416,21 +416,40 @@ def bootstrap_plot(series, fig=None, size=50, samples=500, **kwds):
416
416
417
417
418
418
def parallel_coordinates (data , class_column , cols = None , ax = None , colors = None ,
419
- ** kwds ):
419
+ use_columns = False , xticks = None , ** kwds ):
420
420
"""Parallel coordinates plotting.
421
421
422
- Parameters:
423
- -----------
424
- data: A DataFrame containing data to be plotted
425
- class_column: Column name containing class names
426
- cols: A list of column names to use, optional
427
- ax: matplotlib axis object, optional
428
- colors: A list or tuple of colors to use for the different classes, optional
429
- kwds: A list of keywords for matplotlib plot method
422
+ Parameters
423
+ ----------
424
+ data: DataFrame
425
+ A DataFrame containing data to be plotted
426
+ class_column: str
427
+ Column name containing class names
428
+ cols: list, optional
429
+ A list of column names to use
430
+ ax: matplotlib.axis, optional
431
+ matplotlib axis object
432
+ colors: list or tuple, optional
433
+ Colors to use for the different classes
434
+ use_columns: bool, optional
435
+ If true, columns will be used as xticks
436
+ xticks: list or tuple, optional
437
+ A list of values to use for xticks
438
+ kwds: list, optional
439
+ A list of keywords for matplotlib plot method
430
440
431
- Returns:
432
- --------
441
+ Returns
442
+ -------
433
443
ax: matplotlib axis object
444
+
445
+ Examples
446
+ --------
447
+ >>> from pandas import read_csv
448
+ >>> from pandas.tools.plotting import parallel_coordinates
449
+ >>> from matplotlib import pyplot as plt
450
+ >>> df = read_csv('https://raw.github.com/pydata/pandas/master/pandas/tests/data/iris.csv')
451
+ >>> parallel_coordinates(df, 'Name', colors=('#556270', '#4ECDC4', '#C7F464'))
452
+ >>> plt.show()
434
453
"""
435
454
import matplotlib .pyplot as plt
436
455
import random
@@ -450,7 +469,20 @@ def random_color(column):
450
469
used_legends = set ([])
451
470
452
471
ncols = len (df .columns )
453
- x = range (ncols )
472
+
473
+ # determine values to use for xticks
474
+ if use_columns is True :
475
+ if not np .all (np .isreal (list (df .columns ))):
476
+ raise ValueError ('Columns must be numeric to be used as xticks' )
477
+ x = df .columns
478
+ elif xticks is not None :
479
+ if not np .all (np .isreal (xticks )):
480
+ raise ValueError ('xticks specified must be numeric' )
481
+ elif len (xticks ) != ncols :
482
+ raise ValueError ('Length of xticks must match number of columns' )
483
+ x = xticks
484
+ else :
485
+ x = range (ncols )
454
486
455
487
if ax == None :
456
488
ax = plt .gca ()
@@ -475,11 +507,12 @@ def random_color(column):
475
507
else :
476
508
ax .plot (x , y , color = colors [kls ], ** kwds )
477
509
478
- for i in range ( ncols ) :
510
+ for i in x :
479
511
ax .axvline (i , linewidth = 1 , color = 'black' )
480
512
481
513
ax .set_xticks (x )
482
514
ax .set_xticklabels (df .columns )
515
+ ax .set_xlim (x [0 ], x [- 1 ])
483
516
ax .legend (loc = 'upper right' )
484
517
ax .grid ()
485
518
return ax
0 commit comments