Skip to content

Commit 4759d31

Browse files
committed
Initial commit of clustered heatmap
1 parent 286811a commit 4759d31

File tree

1 file changed

+279
-0
lines changed

1 file changed

+279
-0
lines changed

pandas/tools/plotting.py

+279
Original file line numberDiff line numberDiff line change
@@ -2488,6 +2488,285 @@ def _maybe_convert_date(x):
24882488
x = conv_func(x)
24892489
return x
24902490

2491+
# helper for cleaning up axes by removing ticks, tick labels, frame, etc.
2492+
def _clean_axis(ax):
2493+
"""Remove ticks, tick labels, and frame from axis"""
2494+
ax.get_xaxis().set_ticks([])
2495+
ax.get_yaxis().set_ticks([])
2496+
for sp in ax.spines.values():
2497+
sp.set_visible(False)
2498+
2499+
2500+
def _color_list_to_matrix_and_cmap(color_list, ind, row=True):
    """
    Turn a list of side colors into an integer matrix and a matching colormap.

    Helper for ``heatmap()``.  Every distinct color in ``color_list`` is
    assigned an integer code, the codes are reordered by ``ind`` and shaped
    into a one-cell-wide strip that can be drawn with ``pcolormesh``.

    This only works for 1-column color lists.
    TODO: Support multiple color labels on an element in the heatmap

    Parameters
    ----------
    color_list : sequence of colors
        One color per heatmap row (or column), in the original, unclustered
        order.
    ind : sequence of int
        Positions used to reorder ``color_list``, e.g. the ``'leaves'`` entry
        of a dendrogram result.
    row : bool, default True
        If True the result is a column vector of shape ``(len(ind), 1)``
        (row-side colors); otherwise a row vector of shape ``(1, len(ind))``
        (column-side colors).

    Returns
    -------
    matrix : numpy.ndarray
        Integer color codes, reordered by ``ind``.
    cmap : matplotlib.colors.ListedColormap
        Colormap whose i-th entry is the color that was given code ``i``.
    """
    from matplotlib.colors import ListedColormap

    def _key(color):
        # Lists/arrays (e.g. RGB triples) are not hashable; use a tuple key.
        if isinstance(color, (list, np.ndarray)):
            return tuple(color)
        return color

    # Number the colors in order of first appearance.  Iterating a ``set``
    # would give an arbitrary order that can differ between runs, and a set
    # is not a valid sequence for ListedColormap.
    unique_colors = []
    col_to_value = {}
    for col in color_list:
        key = _key(col)
        if key not in col_to_value:
            col_to_value[key] = len(unique_colors)
            unique_colors.append(col)

    codes = np.array([col_to_value[_key(col)] for col in color_list],
                     dtype=int)
    matrix = codes[np.asarray(ind, dtype=int)]

    # Shape from the reordered data itself so that ``ind`` may be shorter
    # than ``color_list`` without the reshape failing.
    matrix = matrix.reshape((-1, 1) if row else (1, -1))

    cmap = ListedColormap(unique_colors)
    return matrix, cmap
2524+
2525+
2526+
2527+
2528+
2529+
def _heatmap_dendrogram(values, cluster, method, link_color,
                        **dendrogram_kws):
    """Cluster the rows of ``values`` and draw the dendrogram on the current
    axes; return the dendrogram dict (identity ``'leaves'`` if not clustered).
    """
    n_obs = values.shape[0]
    # Hierarchical clustering needs at least two observations.
    if not cluster or n_obs < 2:
        return {'leaves': list(range(n_obs))}

    import scipy.cluster.hierarchy as sch
    import scipy.spatial.distance as distance

    # ``linkage`` expects a *condensed* distance matrix.  A square matrix
    # would be interpreted as raw observation vectors and clustered on the
    # wrong quantity, so the ``pdist`` output is passed through unchanged.
    linkage = sch.linkage(distance.pdist(values), method=method)
    return sch.dendrogram(linkage,
                          color_threshold=np.inf,
                          link_color_func=lambda _link_id: link_color,
                          **dendrogram_kws)


def _resolve_heatmap_labels(labels, default_labels, leaves, arg_name, axis):
    """Return tick labels in display (``leaves``) order, or None for no labels.
    """
    try:
        custom = list(labels)
    except TypeError:
        # Not iterable (e.g. a bool): treat as an on/off flag.
        return default_labels[leaves] if labels else None

    if len(custom) != len(default_labels):
        raise ValueError("Length of '%s' must be the same as "
                         "df.shape[%d]" % (arg_name, axis))
    # Custom labels are given in ``df`` order; reorder them the same way the
    # data is reordered so each label stays attached to its row/column.
    return [custom[i] for i in leaves]


def heatmap(df, title=None, colorbar_label='values',
            col_side_colors=None, row_side_colors=None,
            color_scale='linear', cmap=None,
            row_linkage_method='complete',
            col_linkage_method='complete',
            figsize=None,
            label_rows=True,
            label_cols=True,
            xlabel_fontsize=12,
            ylabel_fontsize=10,
            cluster_cols=True,
            cluster_rows=True,
            plot_df=None):
    """
    Draw a hierarchically clustered heatmap of a DataFrame.

    Rows and columns of ``df`` are clustered (Euclidean distance via
    ``scipy.spatial.distance.pdist``), dendrograms are drawn along the top
    and left, optional side-color strips annotate rows/columns, and a
    colorbar for the value scale is placed in the upper-left corner.

    Author: Olga Botvinnik

    Parameters
    ----------
    df : DataFrame
        Data used for clustering and for choosing the color normalization.
    title : str, optional
        Title set on the heatmap axes.
    colorbar_label : str, default 'values'
        Label of the scale colorbar.
    col_side_colors, row_side_colors : sequence of colors, optional
        One color per column / row of ``df`` (in ``df`` order), drawn as a
        thin strip between the dendrogram and the heatmap.
    color_scale : {'linear', 'log'}, default 'linear'
        With 'log' a ``LogNorm`` is used.  Otherwise, if the data span both
        negative and positive values, a ``Normalize`` symmetric around zero
        is used; in all other cases no explicit norm is passed.
    cmap : matplotlib colormap, optional
        Used as given.  If omitted, 'RdBu_r' is used for data spanning zero
        and 'Blues_r' otherwise, with masked ("bad") cells shown in white.
    row_linkage_method, col_linkage_method : str, default 'complete'
        ``method`` passed to ``scipy.cluster.hierarchy.linkage``.
    figsize : tuple, optional
        Figure size; defaults to ``(0.25 * n_cols, min(0.75 * n_rows, 40))``.
    label_rows : bool or list of str, default True
        Can be boolean or a list of strings, with exactly the length of the
        number of rows in df.
    label_cols : bool or list of str, default True
        Can be boolean or a list of strings, with exactly the length of the
        number of columns in df.
    xlabel_fontsize : int, default 12
        Font size of the column labels (drawn rotated by 90 degrees).
    ylabel_fontsize : int, default 10
        Font size of the row labels.
    cluster_cols, cluster_rows : bool, default True
        Whether to cluster (and draw a dendrogram for) columns / rows.  When
        False the original order is kept.
    plot_df : DataFrame, optional
        Values actually drawn in the heatmap; must have the same index and
        columns as ``df``.  Defaults to ``df``.

    Returns
    -------
    fig : matplotlib Figure
    row_dendrogram : dict
        Result of ``scipy.cluster.hierarchy.dendrogram`` for the rows, or
        ``{'leaves': [0, 1, ...]}`` when rows are not clustered.
    col_dendrogram : dict
        Same for the columns.

    Raises
    ------
    TypeError
        If ``plot_df`` does not have the same index or columns as ``df``.
    ValueError
        If a list passed as ``label_rows`` / ``label_cols`` has the wrong
        length.
    """
    import copy

    import matplotlib.gridspec as gridspec
    import matplotlib.pyplot as plt
    from matplotlib.colors import LogNorm, Normalize

    almost_black = '#262626'

    if plot_df is None:
        plot_df = df

    if not plot_df.index.equals(df.index):
        raise TypeError('plot_df must have the exact same indices as df')
    if not plot_df.columns.equals(df.columns):
        raise TypeError('plot_df must have the exact same columns as df')

    n_rows, n_cols = df.shape

    # ---- color normalization -------------------------------------------
    # ``min``/``max`` skip NaN by default, so missing values are ignored.
    data_min = df.min().min()
    data_max = df.max().max()
    divergent = data_max > 0 and data_min < 0

    if color_scale == 'log':
        # LogNorm cannot handle values <= 0, hence the small positive floor.
        vmin = max(np.floor(data_min), 1e-10)
        vmax = np.ceil(data_max)
        my_norm = LogNorm(vmin, vmax)
    elif divergent:
        # Symmetric limits keep zero at the center of the diverging colormap.
        vmax = max(abs(data_max), abs(data_min))
        my_norm = Normalize(vmin=-vmax, vmax=vmax)
    else:
        my_norm = None

    if cmap is None:
        # Copy before ``set_bad`` so matplotlib's shared colormap instance
        # is not modified for every other plot in the process.
        cmap = copy.copy(plt.get_cmap('RdBu_r' if divergent else 'Blues_r'))
        cmap.set_bad('white')

    # ---- figure layout ---------------------------------------------------
    dendrogram_fraction = 0.25
    side_color_fraction = 0.05

    width_ratios = [dendrogram_fraction, 1]
    if row_side_colors is not None:
        width_ratios.insert(1, side_color_fraction)
    height_ratios = [dendrogram_fraction, 1]
    if col_side_colors is not None:
        height_ratios.insert(1, side_color_fraction)
    ncols = len(width_ratios)
    nrows = len(height_ratios)

    if figsize is None:
        figsize = (n_cols * 0.25, min(n_rows * .75, 40))

    fig = plt.figure(figsize=figsize)
    heatmap_gridspec = gridspec.GridSpec(nrows, ncols,
                                         wspace=0.0, hspace=0.0,
                                         width_ratios=width_ratios,
                                         height_ratios=height_ratios)

    # ---- column dendrogram -----------------------------------------------
    # ``dendrogram`` draws on the current axes, which is the subplot that
    # was just added.
    column_dendrogram_ax = fig.add_subplot(heatmap_gridspec[0, ncols - 1])
    column_dendrogram_distances = _heatmap_dendrogram(
        df.T.values, cluster_cols, col_linkage_method, almost_black)
    col_leaves = list(column_dendrogram_distances['leaves'])
    _clean_axis(column_dendrogram_ax)

    # ---- column side colors ----------------------------------------------
    if col_side_colors is not None:
        column_colorbar_ax = fig.add_subplot(heatmap_gridspec[1, ncols - 1])
        col_side_matrix, col_cmap = _color_list_to_matrix_and_cmap(
            col_side_colors, ind=col_leaves, row=False)
        column_colorbar_ax.pcolormesh(col_side_matrix, cmap=col_cmap,
                                      edgecolors='white', linewidth=0.1)
        column_colorbar_ax.set_xlim(0, col_side_matrix.shape[1])
        _clean_axis(column_colorbar_ax)

    # ---- row dendrogram --------------------------------------------------
    # NOTE(review): the meaning of 'left'/'right' may differ between SciPy
    # releases; confirm the tree's root faces away from the heatmap.
    row_dendrogram_ax = fig.add_subplot(heatmap_gridspec[nrows - 1, 0])
    row_dendrogram_distances = _heatmap_dendrogram(
        df.values, cluster_rows, row_linkage_method, almost_black,
        orientation='right')
    row_leaves = list(row_dendrogram_distances['leaves'])
    _clean_axis(row_dendrogram_ax)

    # ---- row side colors -------------------------------------------------
    if row_side_colors is not None:
        row_colorbar_ax = fig.add_subplot(heatmap_gridspec[nrows - 1, 1])
        row_side_matrix, row_cmap = _color_list_to_matrix_and_cmap(
            row_side_colors, ind=row_leaves, row=True)
        row_colorbar_ax.pcolormesh(row_side_matrix, cmap=row_cmap,
                                   edgecolors='white', linewidth=0.1)
        row_colorbar_ax.set_ylim(0, row_side_matrix.shape[0])
        _clean_axis(row_colorbar_ax)

    # ---- heatmap ---------------------------------------------------------
    heatmap_ax = fig.add_subplot(heatmap_gridspec[nrows - 1, ncols - 1])
    # The leaves are integer positions, so select positionally; label-based
    # lookup would pick the wrong rows for an integer-labelled index.
    ordered = plot_df.iloc[row_leaves, col_leaves]
    heatmap_ax_pcolormesh = heatmap_ax.pcolormesh(ordered.values,
                                                  norm=my_norm, cmap=cmap)
    heatmap_ax.set_ylim(0, n_rows)
    heatmap_ax.set_xlim(0, n_cols)
    _clean_axis(heatmap_ax)

    # ---- row labels ------------------------------------------------------
    yticklabels = _resolve_heatmap_labels(label_rows, df.index, row_leaves,
                                          'label_rows', 0)
    if yticklabels is not None:
        # +0.5 centers each tick on its cell.
        heatmap_ax.set_yticks(np.arange(n_rows) + 0.5)
        heatmap_ax.yaxis.set_ticks_position('right')
        heatmap_ax.set_yticklabels(yticklabels, fontsize=ylabel_fontsize)

    # Add title if there is one:
    if title is not None:
        heatmap_ax.set_title(title)

    # ---- column labels ---------------------------------------------------
    xticklabels = _resolve_heatmap_labels(label_cols, df.columns, col_leaves,
                                          'label_cols', 1)
    if xticklabels is not None:
        heatmap_ax.set_xticks(np.arange(n_cols) + 0.5)
        # Rotated 90 degrees so long column names do not overlap.
        for label in heatmap_ax.set_xticklabels(xticklabels,
                                                fontsize=xlabel_fontsize):
            label.set_rotation(90)

    # Remove the tick lines but keep the labels.
    for tick_line in (list(heatmap_ax.get_xticklines()) +
                      list(heatmap_ax.get_yticklines())):
        tick_line.set_markersize(0)

    # ---- scale colorbar (upper left corner) ------------------------------
    scale_colorbar_ax = fig.add_subplot(heatmap_gridspec[0:(nrows - 1), 0])
    cb = fig.colorbar(heatmap_ax_pcolormesh, cax=scale_colorbar_ax)
    cb.set_label(colorbar_label)
    # Ticks and label go on the left to avoid problems with tight_layout.
    cb.ax.yaxis.set_ticks_position('left')
    cb.ax.yaxis.set_label_position('left')
    cb.outline.set_linewidth(0)
    # Make colorbar tick labels smaller.
    for tick_label in cb.ax.yaxis.get_ticklabels():
        tick_label.set_fontsize(tick_label.get_fontsize() - 3)

    fig.tight_layout()
    return fig, row_dendrogram_distances, column_dendrogram_distances
2769+
24912770
if __name__ == '__main__':
24922771
# import pandas.rpy.common as com
24932772
# sales = com.load_data('sanfrancisco.home.sales', package='nutshell')

0 commit comments

Comments
 (0)