diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py index c4255e706b19f..ba5ae3b0cb52c 100644 --- a/pandas/tools/plotting.py +++ b/pandas/tools/plotting.py @@ -21,6 +21,7 @@ try: # mpl optional import pandas.tseries.converter as conv + conv.register() # needs to override so set_xlim works with str/number except ImportError: pass @@ -30,70 +31,72 @@ # to True. mpl_stylesheet = { 'axes.axisbelow': True, - 'axes.color_cycle': ['#348ABD', - '#7A68A6', - '#A60628', - '#467821', - '#CF4457', - '#188487', - '#E24A33'], - 'axes.edgecolor': '#bcbcbc', - 'axes.facecolor': '#eeeeee', - 'axes.grid': True, - 'axes.labelcolor': '#555555', - 'axes.labelsize': 'large', - 'axes.linewidth': 1.0, - 'axes.titlesize': 'x-large', - 'figure.edgecolor': 'white', - 'figure.facecolor': 'white', - 'figure.figsize': (6.0, 4.0), - 'figure.subplot.hspace': 0.5, - 'font.family': 'monospace', - 'font.monospace': ['Andale Mono', - 'Nimbus Mono L', - 'Courier New', - 'Courier', - 'Fixed', - 'Terminal', - 'monospace'], - 'font.size': 10, - 'interactive': True, - 'keymap.all_axes': ['a'], - 'keymap.back': ['left', 'c', 'backspace'], - 'keymap.forward': ['right', 'v'], - 'keymap.fullscreen': ['f'], - 'keymap.grid': ['g'], - 'keymap.home': ['h', 'r', 'home'], - 'keymap.pan': ['p'], - 'keymap.save': ['s'], - 'keymap.xscale': ['L', 'k'], - 'keymap.yscale': ['l'], - 'keymap.zoom': ['o'], - 'legend.fancybox': True, - 'lines.antialiased': True, - 'lines.linewidth': 1.0, - 'patch.antialiased': True, - 'patch.edgecolor': '#EEEEEE', - 'patch.facecolor': '#348ABD', - 'patch.linewidth': 0.5, - 'toolbar': 'toolbar2', - 'xtick.color': '#555555', - 'xtick.direction': 'in', - 'xtick.major.pad': 6.0, - 'xtick.major.size': 0.0, - 'xtick.minor.pad': 6.0, - 'xtick.minor.size': 0.0, - 'ytick.color': '#555555', - 'ytick.direction': 'in', - 'ytick.major.pad': 6.0, - 'ytick.major.size': 0.0, - 'ytick.minor.pad': 6.0, - 'ytick.minor.size': 0.0 + 'axes.color_cycle': ['#348ABD', + '#7A68A6', + '#A60628', + '#467821', + '#CF4457', + '#188487', + '#E24A33'], + 'axes.edgecolor': '#bcbcbc', + 'axes.facecolor': '#eeeeee', + 'axes.grid': True, + 'axes.labelcolor': '#555555', + 'axes.labelsize': 'large', + 'axes.linewidth': 1.0, + 'axes.titlesize': 'x-large', + 'figure.edgecolor': 'white', + 'figure.facecolor': 'white', + 'figure.figsize': (6.0, 4.0), + 'figure.subplot.hspace': 0.5, + 'font.family': 'monospace', + 'font.monospace': ['Andale Mono', + 'Nimbus Mono L', + 'Courier New', + 'Courier', + 'Fixed', + 'Terminal', + 'monospace'], + 'font.size': 10, + 'interactive': True, + 'keymap.all_axes': ['a'], + 'keymap.back': ['left', 'c', 'backspace'], + 'keymap.forward': ['right', 'v'], + 'keymap.fullscreen': ['f'], + 'keymap.grid': ['g'], + 'keymap.home': ['h', 'r', 'home'], + 'keymap.pan': ['p'], + 'keymap.save': ['s'], + 'keymap.xscale': ['L', 'k'], + 'keymap.yscale': ['l'], + 'keymap.zoom': ['o'], + 'legend.fancybox': True, + 'lines.antialiased': True, + 'lines.linewidth': 1.0, + 'patch.antialiased': True, + 'patch.edgecolor': '#EEEEEE', + 'patch.facecolor': '#348ABD', + 'patch.linewidth': 0.5, + 'toolbar': 'toolbar2', + 'xtick.color': '#555555', + 'xtick.direction': 'in', + 'xtick.major.pad': 6.0, + 'xtick.major.size': 0.0, + 'xtick.minor.pad': 6.0, + 'xtick.minor.size': 0.0, + 'ytick.color': '#555555', + 'ytick.direction': 'in', + 'ytick.major.pad': 6.0, + 'ytick.major.size': 0.0, + 'ytick.minor.pad': 6.0, + 'ytick.minor.size': 0.0 } + def _get_standard_kind(kind): return {'density': 'kde'}.get(kind, kind) + def _get_standard_colors(num_colors=None, colormap=None, color_type='default', color=None): import matplotlib.pyplot as plt @@ -101,6 +104,7 @@ def _get_standard_colors(num_colors=None, colormap=None, color_type='default', if color is None and colormap is not None: if isinstance(colormap, compat.string_types): import matplotlib.cm as cm + cmap = colormap colormap = cm.get_cmap(colormap) if colormap is None: @@ -118,6 +122,7 @@ def _get_standard_colors(num_colors=None, colormap=None, color_type='default', colors = list(colors) elif color_type == 'random': import random + def random_color(column): random.seed(column) return [random.random() for _ in range(3)] @@ -127,7 +132,7 @@ def random_color(column): raise NotImplementedError if len(colors) != num_colors: - multiple = num_colors//len(colors) - 1 + multiple = num_colors // len(colors) - 1 mod = num_colors % len(colors) colors += multiple * colors @@ -135,6 +140,7 @@ def random_color(column): return colors + class _Options(dict): """ Stores pandas plotting options. @@ -262,6 +268,7 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, ax.hist(values, **hist_kwds) elif diagonal in ('kde', 'density'): from scipy.stats import gaussian_kde + y = values gkde = gaussian_kde(y) ind = np.linspace(y.min(), y.max(), 1000) @@ -279,9 +286,9 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, _label_axis(ax, kind='y', label=a, position='left') - if j!= 0: + if j != 0: ax.yaxis.set_visible(False) - if i != n-1: + if i != n - 1: ax.xaxis.set_visible(False) for ax in axes.flat: @@ -290,10 +297,11 @@ def scatter_matrix(frame, alpha=0.5, figsize=None, ax=None, grid=False, return axes -def _label_axis(ax, kind='x', label='', position='top', - ticks=True, rotate=False): +def _label_axis(ax, kind='x', label='', position='top', + ticks=True, rotate=False): from matplotlib.artist import setp + if kind == 'x': ax.set_xlabel(label, visible=True) ax.xaxis.set_visible(True) @@ -310,21 +318,22 @@ def _label_axis(ax, kind='x', label='', position='top', return - - - def _gca(): import matplotlib.pyplot as plt + return plt.gca() def _gcf(): import matplotlib.pyplot as plt + return plt.gcf() + def _get_marker_compat(marker): import matplotlib.lines as mlines import matplotlib as mpl + if mpl.__version__ < '1.1.0' and marker == '.': return 'o' if marker not in mlines.lineMarkers: @@ -450,6 +459,7 @@ def f(x): if len(amplitudes) % 2 != 0: result += amplitudes[-1] * sin(harmonic * x) return result + return f n = len(data) @@ -685,6 +695,7 @@ def autocorrelation_plot(series, ax=None): ax: Matplotlib axis object """ import matplotlib.pyplot as plt + n = len(series) data = np.asarray(series) if ax is None: @@ -694,6 +705,7 @@ def autocorrelation_plot(series, ax=None): def r(h): return ((data[:n - h] - mean) * (data[h:] - mean)).sum() / float(n) / c0 + x = np.arange(n) + 1 y = lmap(r, x) z95 = 1.959963984540054 @@ -735,6 +747,7 @@ def grouped_hist(data, column=None, by=None, ax=None, bins=50, figsize=None, ------- axes: collection of Matplotlib Axes """ + def plot_group(group, ax): ax.hist(group.dropna().values, bins=bins, **kwargs) @@ -816,6 +829,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True, def _validate_color_args(self): from pandas import DataFrame + if 'color' not in self.kwds and 'colors' in self.kwds: warnings.warn(("'colors' is being deprecated. Please use 'color'" "instead of 'colors'")) @@ -823,13 +837,14 @@ def _validate_color_args(self): self.kwds['color'] = colors if ('color' in self.kwds and - (isinstance(self.data, Series) or - isinstance(self.data, DataFrame) and len(self.data.columns) == 1)): + (isinstance(self.data, Series) or + isinstance(self.data, DataFrame) and len( + self.data.columns) == 1)): # support series.plot(color='green') self.kwds['color'] = [self.kwds['color']] if ('color' in self.kwds or 'colors' in self.kwds) and \ - self.colormap is not None: + self.colormap is not None: warnings.warn("'color' and 'colormap' cannot be used " "simultaneously. Using 'color'") @@ -843,6 +858,7 @@ def _validate_color_args(self): def _iter_data(self): from pandas.core.frame import DataFrame + if isinstance(self.data, (Series, np.ndarray)): yield self.label, np.asarray(self.data) elif isinstance(self.data, DataFrame): @@ -1017,6 +1033,7 @@ def legend_title(self): @cache_readonly def plt(self): import matplotlib.pyplot as plt + return plt _need_to_set_index = False @@ -1099,6 +1116,7 @@ def _get_ax(self, i): def on_right(self, i): from pandas.core.frame import DataFrame + if isinstance(self.secondary_y, bool): return self.secondary_y @@ -1126,6 +1144,7 @@ def _get_style(self, i, col_name): def _get_colors(self): from pandas.core.frame import DataFrame + if isinstance(self.data, DataFrame): num_colors = len(self.data.columns) else: @@ -1150,13 +1169,14 @@ def _get_marked_label(self, label, col_num): class KdePlot(MPLPlot): def __init__(self, data, bw_method=None, ind=None, **kwargs): MPLPlot.__init__(self, data, **kwargs) - self.bw_method=bw_method - self.ind=ind + self.bw_method = bw_method + self.ind = ind def _make_plot(self): from scipy.stats import gaussian_kde from scipy import __version__ as spv from distutils.version import LooseVersion + plotf = self._get_plot_function() colors = self._get_colors() for i, (label, y) in enumerate(self._iter_data()): @@ -1201,12 +1221,13 @@ def _post_plot_logic(self): for ax in self.axes: ax.legend(loc='best') + class ScatterPlot(MPLPlot): def __init__(self, data, x, y, **kwargs): MPLPlot.__init__(self, data, **kwargs) self.kwds.setdefault('c', self.plt.rcParams['patch.facecolor']) if x is None or y is None: - raise ValueError( 'scatter requires and x and y column') + raise ValueError('scatter requires and x and y column') if com.is_integer(x) and not self.data.columns.holds_integer(): x = self.data.columns[x] if com.is_integer(y) and not self.data.columns.holds_integer(): @@ -1228,7 +1249,6 @@ def _post_plot_logic(self): class LinePlot(MPLPlot): - def __init__(self, data, **kwargs): self.mark_right = kwargs.pop('mark_right', True) MPLPlot.__init__(self, data, **kwargs) @@ -1238,6 +1258,7 @@ def __init__(self, data, **kwargs): def _index_freq(self): from pandas.core.frame import DataFrame + if isinstance(self.data, (Series, DataFrame)): freq = getattr(self.data.index, 'freq', None) if freq is None: @@ -1259,9 +1280,11 @@ def _is_dynamic_freq(self, freq): def _no_base(self, freq): # hack this for 0.10.1, creating more technical debt...sigh from pandas.core.frame import DataFrame + if (isinstance(self.data, (Series, DataFrame)) and isinstance(self.data.index, DatetimeIndex)): import pandas.tseries.frequencies as freqmod + base = freqmod.get_freq(freq) x = self.data.index if (base <= freqmod.FreqGroup.FR_DAY): @@ -1333,6 +1356,7 @@ def _make_plot(self): def _make_ts_plot(self, data, **kwargs): from pandas.tseries.plotting import tsplot + kwargs = kwargs.copy() colors = self._get_colors() @@ -1342,7 +1366,7 @@ def _make_ts_plot(self, data, **kwargs): def _plot(data, col_num, ax, label, style, **kwds): newlines = tsplot(data, plotf, ax=ax, label=label, - style=style, **kwds) + style=style, **kwds) ax.grid(self.grid) lines.append(newlines[0]) @@ -1402,6 +1426,7 @@ def _maybe_convert_index(self, data): # tsplot converts automatically, but don't want to convert index # over and over for DataFrames from pandas.core.frame import DataFrame + if (isinstance(data.index, DatetimeIndex) and isinstance(data, DataFrame)): freq = getattr(data.index, 'freq', None) @@ -1455,7 +1480,6 @@ def _post_plot_logic(self): class BarPlot(MPLPlot): - _default_rot = {'bar': 90, 'barh': 0} def __init__(self, data, **kwargs): @@ -1467,7 +1491,7 @@ def __init__(self, data, **kwargs): else: self.tickoffset = 0.375 self.bar_width = 0.5 - self.log = kwargs.pop('log',False) + self.log = kwargs.pop('log', False) MPLPlot.__init__(self, data, **kwargs) def _args_adjust(self): @@ -1478,7 +1502,7 @@ def _args_adjust(self): def bar_f(self): if self.kind == 'bar': def f(ax, x, y, w, start=None, **kwds): - return ax.bar(x, y, w, bottom=start,log=self.log, **kwds) + return ax.bar(x, y, w, bottom=start, log=self.log, **kwds) elif self.kind == 'barh': def f(ax, x, y, w, start=None, log=self.log, **kwds): return ax.barh(x, y, w, left=start, **kwds) @@ -1519,7 +1543,7 @@ def _make_plot(self): start = 0 if mpl_le_1_2_1 else None if self.subplots: - rect = bar_f(ax, self.ax_pos, y, self.bar_width, + rect = bar_f(ax, self.ax_pos, y, self.bar_width, start=start, **kwds) ax.set_title(label) elif self.stacked: @@ -1567,8 +1591,8 @@ def _post_plot_logic(self): if name is not None: ax.set_ylabel(name) - # if self.subplots and self.legend: - # self.axes[0].legend(loc='best') + # if self.subplots and self.legend: + # self.axes[0].legend(loc='best') class BoxPlot(MPLPlot): @@ -1585,7 +1609,6 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, xlim=None, ylim=None, logx=False, logy=False, xticks=None, yticks=None, kind='line', sort_columns=False, fontsize=None, secondary_y=False, **kwds): - """ Make line, bar, or scatter plots of DataFrame series with the index on the x-axis using matplotlib / pylab. @@ -1664,8 +1687,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, raise ValueError('Invalid chart type given %s' % kind) if kind == 'scatter': - plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, - rot=rot,legend=legend, ax=ax, style=style, + plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots, + rot=rot, legend=legend, ax=ax, style=style, fontsize=fontsize, use_index=use_index, sharex=sharex, sharey=sharey, xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, title=title, grid=grid, @@ -1695,7 +1718,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True, else: plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot, - legend=legend, ax=ax, style=style, fontsize=fontsize, + legend=legend, ax=ax, style=style, + fontsize=fontsize, use_index=use_index, sharex=sharex, sharey=sharey, xticks=xticks, yticks=yticks, xlim=xlim, ylim=ylim, title=title, grid=grid, figsize=figsize, logx=logx, @@ -1775,6 +1799,7 @@ def plot_series(series, label=None, kind='line', use_index=True, rot=None, be ignored. """ import matplotlib.pyplot as plt + if ax is None and len(plt.get_fignums()) > 0: ax = _gca() if ax.get_yaxis().get_ticks_position().strip().lower() == 'right': @@ -1829,6 +1854,7 @@ def boxplot(data, column=None, by=None, ax=None, fontsize=None, ax : matplotlib.axes.AxesSubplot """ from pandas import Series, DataFrame + if isinstance(data, Series): data = DataFrame({'x': data}) column = 'x' @@ -1838,11 +1864,12 @@ def _get_colors(): return _get_standard_colors(color=kwds.get('color'), num_colors=1) def maybe_color_bp(bp): - if 'color' not in kwds : + if 'color' not in kwds: from matplotlib.artist import setp - setp(bp['boxes'],color=colors[0],alpha=1) - setp(bp['whiskers'],color=colors[0],alpha=1) - setp(bp['medians'],color=colors[2],alpha=1) + + setp(bp['boxes'], color=colors[0], alpha=1) + setp(bp['whiskers'], color=colors[0], alpha=1) + setp(bp['medians'], color=colors[2], alpha=1) def plot_group(grouped, ax): keys, values = zip(*grouped) @@ -1916,7 +1943,8 @@ def format_date_labels(ax, rot): pass -def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, **kwargs): +def scatter_plot(data, x, y, by=None, ax=None, figsize=None, grid=False, + **kwargs): """ Make a scatter plot from two DataFrame columns @@ -2018,6 +2046,7 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, return axes import matplotlib.pyplot as plt + n = len(data.columns) if layout is not None: @@ -2026,7 +2055,9 @@ def hist_frame(data, column=None, by=None, grid=True, xlabelsize=None, rows, cols = layout if rows * cols < n: - raise ValueError('Layout of %sx%s is incompatible with %s columns' % (rows, cols, n)) + raise ValueError( + 'Layout of %sx%s is incompatible with %s columns' % ( + rows, cols, n)) else: rows, cols = 1, 1 while rows * cols < n: @@ -2100,9 +2131,9 @@ def hist_series(self, by=None, ax=None, grid=True, xlabelsize=None, if kwds.get('layout', None) is not None: raise ValueError("The 'layout' keyword is not supported when " "'by' is None") - # hack until the plotting interface is a bit more unified + # hack until the plotting interface is a bit more unified fig = kwds.pop('figure', plt.gcf() if plt.get_fignums() else - plt.figure(figsize=figsize)) + plt.figure(figsize=figsize)) if (figsize is not None and tuple(figsize) != tuple(fig.get_size_inches())): fig.set_size_inches(*figsize, forward=True) @@ -2194,6 +2225,7 @@ def boxplot_frame_groupby(grouped, subplots=True, column=None, fontsize=None, ret[key] = d else: from pandas.tools.merge import concat + keys, frames = zip(*grouped) if grouped.axis == 0: df = concat(frames, keys=keys, axis=1) @@ -2488,6 +2520,424 @@ def _maybe_convert_date(x): x = conv_func(x) return x +# helper for cleaning up axes by removing ticks, tick labels, frame, etc. +def _clean_axis(ax): + """Remove ticks, tick labels, and frame from axis""" + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + for sp in ax.spines.values(): + sp.set_visible(False) + + +def _color_list_to_matrix_and_cmap(color_list, ind, row=True): + """ + For 'heatmap()' + This only works for 1-column color lists.. + TODO: Support multiple color labels on an element in the heatmap + """ + import matplotlib as mpl + + colors = set(color_list) + col_to_value = dict((col, i) for i, col in enumerate(colors)) + + # ind = column_dendrogram_distances['leaves'] + matrix = np.array([col_to_value[col] for col in color_list])[ind] + # Is this row-side or column side? + if row: + new_shape = (len(color_list), 1) + else: + new_shape = (1, len(color_list)) + matrix = matrix.reshape(new_shape) + + cmap = mpl.colors.ListedColormap(colors) + return matrix, cmap + + +def heatmap(df, + title=None, + title_fontsize=12, + colorbar_label='values', + col_side_colors=None, + row_side_colors=None, + color_scale='linear', + cmap=None, + linkage_method='average', + figsize=None, + label_rows=True, + label_cols=True, + vmin=None, + vmax=None, + xlabel_fontsize=12, + ylabel_fontsize=10, + cluster_cols=True, + cluster_rows=True, + linewidth=0, + edgecolor='white', + plot_df=None, + colorbar_ticklabels_fontsize=10, + colorbar_loc="upper left", + use_fastcluster=False, + metric='euclidean'): + """ + @author Olga Botvinnik olga.botvinnik@gmail.com + + This is liberally borrowed (with permission) from http://bit.ly/1eWcYWc + Many thanks to Christopher DeBoever and Mike Lovci for providing heatmap + guidance. + + + :param title_fontsize: + :param colorbar_ticklabels_fontsize: + :param colorbar_loc: Can be 'upper left' (in the corner), 'right', + or 'bottom' + + + :param df: The dataframe you want to cluster on + :param title: Title of the figure + :param colorbar_label: What to colorbar (color scale of the heatmap) + :param col_side_colors: Label the columns with a color + :param row_side_colors: Label the rows with a color + :param color_scale: Either 'linear' or 'log' + :param cmap: A matplotlib colormap, default is mpl.cm.Blues_r if data is + sequential, or mpl.cm.RdBu_r if data is divergent (has both positive and + negative numbers) + :param figsize: Size of the figure. The default is a function of the + dataframe size. + :param label_rows: Can be boolean or a list of strings, with exactly the + length of the number of rows in df. + :param label_cols: Can be boolean or a list of strings, with exactly the + length of the number of columns in df. + :param col_labels: If True, label with df.columns. If False, unlabeled. + Else, this can be an iterable to relabel the columns with labels of your own + choosing. This is helpful if you have duplicate column names and pandas + won't let you reindex it. + :param row_labels: If True, label with df.index. If False, unlabeled. + Else, this can be an iterable to relabel the row names with labels of your + own choosing. This is helpful if you have duplicate index names and pandas + won't let you reindex it. + :param xlabel_fontsize: Default 12pt + :param ylabel_fontsize: Default 10pt + :param cluster_cols: Boolean, whether or not to cluster the columns + :param cluster_rows: + :param plot_df: The dataframe you want to plot. This can contain NAs and + other nasty things. + :param row_linkage_method: + :param col_linkage_method: + :param vmin: Minimum value to plot on heatmap + :param vmax: Maximum value to plot on heatmap + :param linewidth: Linewidth of lines around heatmap box elements + (default 0) + :param edgecolor: Color of lines around heatmap box elements (default + white) + """ + #@return: fig, row_dendrogram, col_dendrogram + #@rtype: matplotlib.figure.Figure, dict, dict + #@raise TypeError: + import matplotlib.pyplot as plt + import matplotlib.gridspec as gridspec + import scipy.spatial.distance as distance + import scipy.cluster.hierarchy as sch + import matplotlib as mpl + from collections import Iterable + + #if cluster + + if (df.shape[0] > 1000 or df.shape[1] > 1000) or use_fastcluster: + try: + import fastcluster + linkage_function = fastcluster.linkage + except ImportError: + raise warnings.warn('Module "fastcluster" not found. The ' + 'dataframe ' + 'provided has ' + 'shape {}, and one ' + 'of the dimensions has greater than 1000 ' + 'variables. Calculating linkage on such a ' + 'matrix will take a long time with vanilla ' + '"scipy.cluster.hierarchy.linkage", and we ' + 'suggest fastcluster for such large datasets'\ + .format(df.shape), RuntimeWarning) + else: + linkage_function = sch.linkage + + almost_black = '#262626' + sch.set_link_color_palette([almost_black]) + if plot_df is None: + plot_df = df + + if (plot_df.index != df.index).any(): + raise ValueError('plot_df must have the exact same indices as df') + if (plot_df.columns != df.columns).any(): + raise ValueError('plot_df must have the exact same columns as df') + # make norm + + # Check if the matrix has values both above and below zero, or only above + # or only below zero. If both above and below, then the data is + # "divergent" and we will use a colormap with 0 centered at white, + # negative values blue, and positive values red. Otherwise, we will use + # the YlGnBu colormap. + divergent = df.max().max() > 0 and df.min().min() < 0 + + if color_scale == 'log': + if vmin is None: + vmin = max(np.floor(df.dropna(how='all').min().dropna().min()), 1e-10) + if vmax is None: + vmax = np.ceil(df.dropna(how='all').max().dropna().max()) + my_norm = mpl.colors.LogNorm(vmin, vmax) + elif divergent: + abs_max = abs(df.max().max()) + abs_min = abs(df.min().min()) + vmaxx = max(abs_max, abs_min) + my_norm = mpl.colors.Normalize(vmin=-vmaxx, vmax=vmaxx) + else: + my_norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) + + if cmap is None: + cmap = mpl.cm.RdBu_r if divergent else mpl.cm.YlGnBu + cmap.set_bad('white') + + # TODO: Add optimal leaf ordering for clusters + # TODO: if color_scale is 'log', should distance also be on np.log(df)? + # calculate pairwise distances for rows + if color_scale == 'log': + df = np.log10(df) + row_pairwise_dists = distance.squareform(distance.pdist(df, + metric=metric)) + row_linkage = linkage_function(row_pairwise_dists, method=linkage_method) + + # calculate pairwise distances for columns + col_pairwise_dists = distance.squareform(distance.pdist(df.T, + metric=metric)) + # cluster + col_linkage = linkage_function(col_pairwise_dists, method=linkage_method) + + # heatmap with row names + + def get_width_ratios(shape, side_colors, + colorbar_loc, dimension, side_colors_ratio=0.05): + """ + Figures out the ratio of each subfigure within the larger figure. + The dendrograms currently are 2*half_dendrogram, which is a proportion of + the dataframe shape. Right now, this only supports the colormap in + the upper left. The full figure map looks like: + + 0.1 0.1 0.05 1.0 + 0.1 cb column + 0.1 dendrogram + 0.05 col colors + | r d r + | o e o + | w n w + | d + 1.0| r c heatmap + | o o + | g l + | r o + | a r + | m s + + The colorbar is half_dendrogram of the whitespace in the corner between + the row and column dendrogram. Otherwise, it's too big and its + corners touch the heatmap, which I didn't like. + + For example, if there are side_colors, need to provide an extra value + in the ratio tuples, with the width side_colors_ratio. But if there + aren't any side colors, then the tuple is of size 3 (half_dendrogram, + half_dendrogram, 1.0), and if there are then the tuple is of size 4 ( + half_dendrogram, half_dendrogram, 0.05, 1.0) + + :param side_colors: + :type side_colors: + :param colorbar_loc: + :type colorbar_loc: + :param dimension: + :type dimension: + :param side_colors_ratio: + :type side_colors_ratio: + :return: + :rtype: + """ + i = 0 if dimension == 'height' else 1 + half_dendrogram = shape[i] * 0.1/shape[i] + if colorbar_loc not in ('upper left', 'right', 'bottom'): + raise AssertionError("{} is not a valid 'colorbar_loc' (valid: " + "'upper left', 'right', 'bottom')".format( + colorbar_loc)) + if dimension not in ('height', 'width'): + raise AssertionError("{} is not a valid 'dimension' (valid: " + "'height', 'width')".format( + dimension)) + + ratios = [half_dendrogram, half_dendrogram] + if side_colors: + ratios += [side_colors_ratio] + + if (colorbar_loc == 'right' and dimension == 'width') or ( + colorbar_loc == 'bottom' and dimension == 'height'): + return ratios + [1, 0.05] + else: + return ratios + [1] + + + width_ratios = get_width_ratios(df.shape, + row_side_colors, + colorbar_loc, dimension='width') + height_ratios = get_width_ratios(df.shape, + col_side_colors, + colorbar_loc, dimension='height') + nrows = 3 if col_side_colors is None else 4 + ncols = 3 if row_side_colors is None else 4 + + width = df.shape[1] * 0.25 + height = min(df.shape[0] * .75, 40) + if figsize is None: + figsize = (width, height) + #print figsize + + + + fig = plt.figure(figsize=figsize) + heatmap_gridspec = \ + gridspec.GridSpec(nrows, ncols, wspace=0.0, hspace=0.0, + width_ratios=width_ratios, + height_ratios=height_ratios) + # print heatmap_gridspec + + ### col dendrogram ### + col_dendrogram_ax = fig.add_subplot(heatmap_gridspec[1, ncols - 1]) + if cluster_cols: + col_dendrogram = sch.dendrogram(col_linkage, + color_threshold=np.inf, + color_list=[almost_black]) + else: + col_dendrogram = {'leaves': list(range(df.shape[1]))} + _clean_axis(col_dendrogram_ax) + + # TODO: Allow for array of color labels + ### col colorbar ### + if col_side_colors is not None: + column_colorbar_ax = fig.add_subplot(heatmap_gridspec[2, ncols - 1]) + col_side_matrix, col_cmap = _color_list_to_matrix_and_cmap( + col_side_colors, + ind=col_dendrogram['leaves'], + row=False) + column_colorbar_ax_pcolormesh = column_colorbar_ax.pcolormesh( + col_side_matrix, cmap=col_cmap, + edgecolor=edgecolor, linewidth=linewidth) + column_colorbar_ax.set_xlim(0, col_side_matrix.shape[1]) + _clean_axis(column_colorbar_ax) + + ### row dendrogram ### + row_dendrogram_ax = fig.add_subplot(heatmap_gridspec[nrows - 1, 1]) + if cluster_rows: + row_dendrogram = \ + sch.dendrogram(row_linkage, + color_threshold=np.inf, + orientation='right', + color_list=[almost_black]) + else: + row_dendrogram = {'leaves': list(range(df.shape[0]))} + _clean_axis(row_dendrogram_ax) + + ### row colorbar ### + if row_side_colors is not None: + row_colorbar_ax = fig.add_subplot(heatmap_gridspec[nrows - 1, 2]) + row_side_matrix, row_cmap = _color_list_to_matrix_and_cmap( + row_side_colors, + ind=row_dendrogram['leaves'], + row=True) + row_colorbar_ax.pcolormesh(row_side_matrix, cmap=row_cmap, + edgecolors=edgecolor, linewidth=linewidth) + row_colorbar_ax.set_ylim(0, row_side_matrix.shape[0]) + _clean_axis(row_colorbar_ax) + + ### heatmap #### + heatmap_ax = fig.add_subplot(heatmap_gridspec[nrows - 1, ncols - 1]) + heatmap_ax_pcolormesh = \ + heatmap_ax.pcolormesh(plot_df.ix[row_dendrogram['leaves'], + col_dendrogram['leaves']].values, + norm=my_norm, cmap=cmap, + edgecolor=edgecolor, + lw=linewidth) + + heatmap_ax.set_ylim(0, df.shape[0]) + heatmap_ax.set_xlim(0, df.shape[1]) + _clean_axis(heatmap_ax) + + ## row labels ## + if isinstance(label_rows, Iterable): + if len(label_rows) == df.shape[0]: + yticklabels = label_rows + label_rows = True + else: + raise AssertionError("Length of 'label_rows' must be the same as " + "df.shape[0] (len(label_rows)={}, df.shape[" + "0]={})".format(len(label_rows), df.shape[0])) + elif label_rows: + yticklabels = df.index + + if label_rows: + yticklabels = [yticklabels[i] for i in row_dendrogram['leaves']] + heatmap_ax.set_yticks(np.arange(df.shape[0]) + 0.5) + heatmap_ax.yaxis.set_ticks_position('right') + heatmap_ax.set_yticklabels(yticklabels, fontsize=ylabel_fontsize) + + # Add title if there is one: + if title is not None: + col_dendrogram_ax.set_title(title, fontsize=title_fontsize) + + ## col labels ## + if isinstance(label_cols, Iterable): + if len(label_cols) == df.shape[1]: + xticklabels = label_cols + label_cols = True + else: + raise AssertionError("Length of 'label_cols' must be the same as " + "df.shape[1] (len(label_cols)={}, df.shape[" + "1]={})".format(len(label_cols), df.shape[1])) + elif label_cols: + xticklabels = df.columns + + if label_cols: + xticklabels = [xticklabels[i] for i in col_dendrogram['leaves']] + heatmap_ax.set_xticks(np.arange(df.shape[1]) + 0.5) + xticklabels = heatmap_ax.set_xticklabels(xticklabels, + fontsize=xlabel_fontsize) + # rotate labels 90 degrees + for label in xticklabels: + label.set_rotation(90) + + # remove the tick lines + for l in heatmap_ax.get_xticklines() + heatmap_ax.get_yticklines(): + l.set_markersize(0) + + ### scale colorbar ### + scale_colorbar_ax = fig.add_subplot( + heatmap_gridspec[0:(nrows - 1), + 0]) # colorbar for scale in upper left corner + + # note that we could pass the norm explicitly with norm=my_norm + cb = fig.colorbar(heatmap_ax_pcolormesh, + cax=scale_colorbar_ax) + cb.set_label(colorbar_label) + + # move ticks to left side of colorbar to avoid problems with tight_layout + cb.ax.yaxis.set_ticks_position('left') + cb.outline.set_linewidth(0) + + ## Make colorbar narrower + #xmin, xmax, ymin, ymax = cb.ax.axis() + #cb.ax.set_xlim(xmin, xmax/0.2) + + # make colorbar labels smaller + yticklabels = cb.ax.yaxis.get_ticklabels() + for t in yticklabels: + t.set_fontsize(colorbar_ticklabels_fontsize) + + fig.tight_layout() + return fig, row_dendrogram, col_dendrogram + + if __name__ == '__main__': # import pandas.rpy.common as com # sales = com.load_data('sanfrancisco.home.sales', package='nutshell') @@ -2501,6 +2951,7 @@ def _maybe_convert_date(x): import pandas.tools.plotting as plots import pandas.core.frame as fr + reload(plots) reload(fr) from pandas.core.frame import DataFrame