Skip to content

Commit cbaae56

Browse files
author
Kyle
committed
PERF: Updated andrews_curves to use a Numpy array for samples
DOC: Added some documentation to andrews_curves
TST: Added a variable length test to TestDataFramePlots.test_andrews_curves
1 parent 091df3e commit cbaae56

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

pandas/tests/test_graphics_others.py

+20
Original file line numberDiff line numberDiff line change
@@ -463,6 +463,26 @@ def test_andrews_curves(self):
463463
cmaps = lmap(cm.jet, np.linspace(0, 1, df['Name'].nunique()))
464464
self._check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df['Name'][:10])
465465

466+
length = 10
467+
df = DataFrame({"A": random.rand(length),
468+
"B": random.rand(length),
469+
"C": random.rand(length),
470+
"Name": ["A"] * length})
471+
472+
_check_plot_works(andrews_curves, frame=df, class_column='Name')
473+
474+
rgba = ('#556270', '#4ECDC4', '#C7F464')
475+
ax = _check_plot_works(andrews_curves, frame=df, class_column='Name', color=rgba)
476+
self._check_colors(ax.get_lines()[:10], linecolors=rgba, mapping=df['Name'][:10])
477+
478+
cnames = ['dodgerblue', 'aquamarine', 'seagreen']
479+
ax = _check_plot_works(andrews_curves, frame=df, class_column='Name', color=cnames)
480+
self._check_colors(ax.get_lines()[:10], linecolors=cnames, mapping=df['Name'][:10])
481+
482+
ax = _check_plot_works(andrews_curves, frame=df, class_column='Name', colormap=cm.jet)
483+
cmaps = lmap(cm.jet, np.linspace(0, 1, df['Name'].nunique()))
484+
self._check_colors(ax.get_lines()[:10], linecolors=cmaps, mapping=df['Name'][:10])
485+
466486
colors = ['b', 'g', 'r']
467487
df = DataFrame({"A": [1, 2, 3],
468488
"B": [1, 2, 3],

pandas/tools/plotting.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,15 @@ def normalize(series):
507507
def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
508508
colormap=None, **kwds):
509509
"""
510+
Generates a matplotlib plot of Andrews curves, for visualising clusters of multivariate data.
511+
512+
Andrews curves have the functional form:
513+
514+
f(t) = x_1/sqrt(2) + x_2 sin(t) + x_3 cos(t) + x_4 sin(2t) + x_5 cos(2t) + ...
515+
516+
Where x coefficients correspond to the values of each dimension and t is linearly spaced between -pi and +pi. Each
517+
row of frame then corresponds to a single curve.
518+
510519
Parameters:
511520
-----------
512521
frame : DataFrame
@@ -527,28 +536,28 @@ def andrews_curves(frame, class_column, ax=None, samples=200, color=None,
527536
ax: Matplotlib axis object
528537
529538
"""
530-
from math import sqrt, pi, sin, cos
539+
from math import sqrt, pi
531540
import matplotlib.pyplot as plt
532541

533542
def function(amplitudes):
534-
def f(x):
543+
def f(t):
535544
x1 = amplitudes[0]
536545
result = x1 / sqrt(2.0)
537546
harmonic = 1.0
538547
for x_even, x_odd in zip(amplitudes[1::2], amplitudes[2::2]):
539-
result += (x_even * sin(harmonic * x) +
540-
x_odd * cos(harmonic * x))
548+
result += (x_even * np.sin(harmonic * t) +
549+
x_odd * np.cos(harmonic * t))
541550
harmonic += 1.0
542551
if len(amplitudes) % 2 != 0:
543-
result += amplitudes[-1] * sin(harmonic * x)
552+
result += amplitudes[-1] * np.sin(harmonic * t)
544553
return result
545554
return f
546555

547556
n = len(frame)
548557
class_col = frame[class_column]
549558
classes = frame[class_column].drop_duplicates()
550559
df = frame.drop(class_column, axis=1)
551-
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
560+
t = np.linspace(-pi, pi, samples)
552561
used_legends = set([])
553562

554563
color_values = _get_standard_colors(num_colors=len(classes),
@@ -560,14 +569,14 @@ def f(x):
560569
for i in range(n):
561570
row = df.iloc[i].values
562571
f = function(row)
563-
y = [f(t) for t in x]
572+
y = f(t)
564573
kls = class_col.iat[i]
565574
label = com.pprint_thing(kls)
566575
if label not in used_legends:
567576
used_legends.add(label)
568-
ax.plot(x, y, color=colors[kls], label=label, **kwds)
577+
ax.plot(t, y, color=colors[kls], label=label, **kwds)
569578
else:
570-
ax.plot(x, y, color=colors[kls], **kwds)
579+
ax.plot(t, y, color=colors[kls], **kwds)
571580

572581
ax.legend(loc='upper right')
573582
ax.grid()

0 commit comments

Comments
 (0)