Skip to content

Commit 0e2c9bb

Browse files
committed
Added docstring for radviz
1 parent d58a255 commit 0e2c9bb

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

pandas/tools/plotting.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,65 @@ def _get_marker_compat(marker):
147147
return 'o'
148148
return marker
149149

150+
def radviz(frame, class_column, ax=None, **kwds):
151+
"""RadViz - a multivariate data visualization algorithm
152+
153+
Parameters:
154+
-----------
155+
frame: DataFrame object
156+
class_column: Column name that contains information about class membership
157+
ax: Matplotlib axis object, optional
158+
kwds: Matplotlib scatter method keyword arguments, optional
159+
160+
Returns:
161+
--------
162+
ax: Matplotlib axis object
163+
"""
164+
import matplotlib.pyplot as plt
165+
import matplotlib.patches as patches
166+
import matplotlib.text as text
167+
import random
168+
def random_color(column):
169+
random.seed(column)
170+
return [random.random() for _ in range(3)]
171+
def normalize(series):
172+
a = min(series)
173+
b = max(series)
174+
return (series - a) / (b - a)
175+
column_names = [column_name for column_name in frame.columns if column_name != class_column]
176+
columns = [normalize(frame[column_name]) for column_name in column_names]
177+
if ax == None:
178+
ax = plt.gca(xlim=[-1, 1], ylim=[-1, 1])
179+
classes = set(frame[class_column])
180+
to_plot = {}
181+
for class_ in classes:
182+
to_plot[class_] = [[], []]
183+
n = len(frame.columns) - 1
184+
s = np.array([(np.cos(t), np.sin(t)) for t in [2.0 * np.pi * (i / float(n)) for i in range(n)]])
185+
for i in range(len(frame)):
186+
row = np.array([column[i] for column in columns])
187+
row_ = np.repeat(np.expand_dims(row, axis=1), 2, axis=1)
188+
y = (s * row_).sum(axis=0) / row.sum()
189+
class_name = frame[class_column][i]
190+
to_plot[class_name][0].append(y[0])
191+
to_plot[class_name][1].append(y[1])
192+
for class_ in classes:
193+
ax.scatter(to_plot[class_][0], to_plot[class_][1], color=random_color(class_), label=str(class_), **kwds)
194+
ax.add_patch(patches.Circle((0.0, 0.0), radius=1.0, facecolor='none'))
195+
for xy, name in zip(s, column_names):
196+
ax.add_patch(patches.Circle(xy, radius=0.025, facecolor='gray'))
197+
if xy[0] < 0.0 and xy[1] < 0.0:
198+
ax.text(xy[0] - 0.025, xy[1] - 0.025, name, ha='right', va='top', size='small')
199+
elif xy[0] < 0.0 and xy[1] >= 0.0:
200+
ax.text(xy[0] - 0.025, xy[1] + 0.025, name, ha='right', va='bottom', size='small')
201+
elif xy[0] >= 0.0 and xy[1] < 0.0:
202+
ax.text(xy[0] + 0.025, xy[1] - 0.025, name, ha='left', va='top', size='small')
203+
elif xy[0] >= 0.0 and xy[1] >= 0.0:
204+
ax.text(xy[0] + 0.025, xy[1] + 0.025, name, ha='left', va='bottom', size='small')
205+
ax.legend(loc='upper right')
206+
ax.axis('equal')
207+
return ax
208+
150209
def andrews_curves(data, class_column, ax=None, samples=200):
151210
"""
152211
Parameters:

0 commit comments

Comments
 (0)