@@ -147,6 +147,65 @@ def _get_marker_compat(marker):
147
147
return 'o'
148
148
return marker
149
149
150
def radviz(frame, class_column, ax=None, **kwds):
    """RadViz - a multivariate data visualization algorithm

    Every column other than ``class_column`` is assigned an anchor on the
    unit circle; anchors are evenly spaced, in column order. Each of those
    columns is min-max scaled to [0, 1], and every row is drawn at the
    average of the anchor positions weighted by the row's scaled values.
    Rows are plotted as one scatter series per class.

    Parameters:
    -----------
    frame: DataFrame object
    class_column: Column name that contains information about class membership
    ax: Matplotlib axis object, optional
        When omitted, the current pyplot axes are used and their x/y limits
        are set to [-1, 1].
    kwds: Matplotlib scatter method keyword arguments, optional

    Returns:
    --------
    ax: Matplotlib axis object

    Raises:
    -------
    ValueError: if ``frame`` has no column other than ``class_column``.

    Notes:
    ------
    A constant column is scaled to all zeros, i.e. it exerts no pull.
    A row whose scaled values are all zero has no defined position; its
    coordinates are NaN.
    """
    import random

    import matplotlib.patches as patches

    def class_color(label):
        # A private generator keeps the colour reproducible per class
        # without reseeding the module-level ``random`` state.
        try:
            rng = random.Random(label)
        except TypeError:
            # Newer Python versions only accept a few builtin seed types.
            rng = random.Random(str(label))
        return [rng.random() for _ in range(3)]

    def normalize(series):
        values = np.asarray(series, dtype=float)
        low = values.min()
        span = values.max() - low
        if span == 0:
            # Constant column: avoid 0/0, which would turn every point
            # of the plot into NaN.
            return np.zeros_like(values)
        return (values - low) / span

    column_names = [name for name in frame.columns if name != class_column]
    n = len(column_names)
    if n == 0:
        raise ValueError("radviz requires at least one column besides "
                         "the class column")

    # One row per observation, one column per feature, all in [0, 1].
    weights = np.column_stack([normalize(frame[name])
                               for name in column_names])

    # Anchor i sits at angle 2*pi*i/n on the unit circle.
    angles = 2.0 * np.pi * (np.arange(n) / float(n))
    anchors = np.column_stack((np.cos(angles), np.sin(angles)))

    # Weighted average of the anchors for every row at once.
    points = weights.dot(anchors) / weights.sum(axis=1)[:, np.newaxis]

    # Pair labels and points by position, so the frame's index labels are
    # irrelevant; the dict keeps classes in order of first appearance.
    grouped = {}
    for label, point in zip(frame[class_column], points):
        xs, ys = grouped.setdefault(label, ([], []))
        xs.append(point[0])
        ys.append(point[1])

    if ax is None:
        import matplotlib.pyplot as plt
        # gca() no longer accepts axes keyword arguments in current
        # matplotlib, so the limits are set explicitly.
        ax = plt.gca()
        ax.set_xlim(-1, 1)
        ax.set_ylim(-1, 1)

    for label, (xs, ys) in grouped.items():
        ax.scatter(xs, ys, color=class_color(label), label=str(label), **kwds)

    ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))

    offset = 0.025
    for (x, y), name in zip(anchors, column_names):
        ax.add_patch(patches.Circle((x, y), radius=0.025, facecolor='gray'))
        # Push each label away from the circle centre, by quadrant.
        if x < 0.0:
            dx, ha = -offset, 'right'
        else:
            dx, ha = offset, 'left'
        if y < 0.0:
            dy, va = -offset, 'top'
        else:
            dy, va = offset, 'bottom'
        ax.text(x + dx, y + dy, name, ha=ha, va=va, size='small')

    ax.legend(loc='upper right')
    ax.axis('equal')
    return ax
208
+
150
209
def andrews_curves (data , class_column , ax = None , samples = 200 ):
151
210
"""
152
211
Parameters:
0 commit comments