@@ -129,6 +129,50 @@ def _gcf():
129
129
import matplotlib .pyplot as plt
130
130
return plt .gcf ()
131
131
132
def andrews_curves(data, class_column, samples=200):
    """
    Plot one Andrews curve per row of ``data`` using matplotlib.pyplot.

    The values of every column except ``class_column`` are used, in column
    order, as the coefficients of a finite Fourier series::

        f(t) = x1 / sqrt(2) + x2 * sin(t) + x3 * cos(t)
               + x4 * sin(2t) + x5 * cos(2t) + ...

    which is evaluated at ``samples`` evenly spaced points starting at -pi
    (the end point +pi itself is not sampled).  Rows of the same class are
    drawn in the same colour and each class appears once in the legend.

    Parameters:
    data: A DataFrame containing data to be plotted, preferably
    normalized to (0.0, 1.0).
    class_column: Name of the column containing class names.
    samples: Number of points to plot in each curve.
    """
    from math import sqrt, pi, sin, cos
    import matplotlib.pyplot as plt
    import random

    def function(amplitudes):
        # Pair the coefficients after the first one as (sin, cos) terms.
        # With an even number of features the last sin coefficient has no
        # cos partner; pad with 0.0 so it is not silently dropped.
        coeffs = list(amplitudes[1:])
        if len(coeffs) % 2:
            coeffs.append(0.0)
        # Materialise the pairs once: f is called for every sample point.
        pairs = list(zip(coeffs[0::2], coeffs[1::2]))
        constant = amplitudes[0] / sqrt(2.0)

        def f(x):
            result = constant
            for harmonic, (x_even, x_odd) in enumerate(pairs, 1):
                result += (x_even * sin(harmonic * x) +
                           x_odd * cos(harmonic * x))
            return result
        return f

    def random_color(column):
        # A private generator seeded with the class value gives the same
        # colour for the same class without reseeding the global RNG.
        rng = random.Random(column)
        return [rng.random() for _ in range(3)]

    # Use the underlying arrays so rows are addressed by position; indexing
    # the Series with ``[i]`` is label-based and breaks on a non-default index.
    class_col = data[class_column].values
    columns = [data[col].values for col in data.columns
               if col != class_column]
    x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]

    colors = {}
    for i, kls in enumerate(class_col):
        row = [column[i] for column in columns]
        f = function(row)
        y = [f(t) for t in x]
        label = None
        if kls not in colors:
            # First curve of this class: pick its colour and label it so the
            # legend gets exactly one entry per class.
            colors[kls] = random_color(kls)
            label = kls
        plt.plot(x, y, color=colors[kls], label=label)
    # Positional limits: the xmin/xmax keyword names are not accepted by
    # current matplotlib releases.
    plt.xlim(-pi, pi)
    plt.legend(loc='upper right')
    plt.grid()
175
+
132
176
def grouped_hist (data , column = None , by = None , ax = None , bins = 50 , log = False ,
133
177
figsize = None , layout = None , sharex = False , sharey = False ,
134
178
rot = 90 ):
0 commit comments