|
| 1 | +import numpy as np |
| 2 | +try: |
| 3 | + import matplotlib.pyplot as plt |
| 4 | +except ImportError: # mpl is optional |
| 5 | + pass |
| 6 | +from .kdeplot import fast_kde |
| 7 | +from .utils import get_default_varnames |
| 8 | +from ..stats import hpd |
| 9 | + |
| 10 | + |
def densityplot(trace, models=None, varnames=None, alpha=0.05, point_estimate='mean',
                colors='cycle', outline=True, hpd_markers='', shade=0., figsize=None, textsize=12,
                plot_transformed=False, ax=None):
    """
    Generate KDE plots truncated at their 100*(1-alpha)% credible intervals from a trace or
    list of traces. KDE plots are grouped per variable and colors are assigned to models.

    Parameters
    ----------
    trace : trace or list of traces
        Trace(s) from an MCMC sample.
    models : list
        List with names for the models in the list of traces. Useful when
        plotting more than one trace. If None and several traces are given,
        the names 'm_0', 'm_1', ... are used.
    varnames : list
        List of variables to plot (defaults to None, which results in all
        variables plotted).
    alpha : float
        Alpha value for (1-alpha)*100% credible intervals (defaults to 0.05).
    point_estimate : str or None
        Plot point estimate per variable. Values should be 'mean', 'median' or None.
        Defaults to 'mean'.
    colors : list or string, optional
        List with valid matplotlib colors, one color per model. Alternatively a string can be
        passed. If the string is 'cycle', a color per model is automatically chosen from
        matplotlib's color cycle. If a single color is passed, e.g. 'k', 'C2' or 'red', this
        color will be used for all models. Defaults to 'cycle'.
    outline : boolean
        Use a line to draw the truncated KDE and its vertical limits. Defaults to True.
    hpd_markers : str
        A valid `matplotlib.markers` like 'v', used to indicate the limits of the hpd interval.
        Defaults to empty string (no marker).
    shade : float
        Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
        (opaque). Defaults to 0.
    figsize : tuple
        Figure size. If None, size is (6, number of variables * 2). Only used when a new
        figure is created, i.e. when `ax` is None.
    textsize : int
        Text size of the legend. Default 12.
    plot_transformed : bool
        Flag for plotting automatically transformed variables in addition to original variables.
        Defaults to False.
    ax : axes
        Matplotlib axes, or a sequence/array of axes with one entry per plotted variable.
        If None (default) a new figure with one row per variable is created.

    Returns
    -------
    ax : array of Matplotlib axes
        One axes per plotted variable, in the order of `varnames`.

    Raises
    ------
    ValueError
        If `point_estimate` is not one of the accepted values, if the number of model names
        does not match the number of traces, if fewer colors than traces are given, or if
        the number of supplied axes does not match the number of variables.
    """
    if point_estimate not in ('mean', 'median', None):
        raise ValueError("Point estimate should be 'mean' or 'median'")

    if not isinstance(trace, (list, tuple)):
        trace = [trace]

    n_traces = len(trace)

    if models is None:
        if n_traces > 1:
            models = ['m_{}'.format(i) for i in range(n_traces)]
        else:
            models = ['']
    elif len(models) != n_traces:
        raise ValueError("The number of names for the models does not match the number of models")

    n_models = len(models)

    if isinstance(colors, str):
        if colors == 'cycle':
            # matplotlib's 'CN' shorthand only covers C0..C9, hence the modulo.
            colors = ['C{}'.format(i % 10) for i in range(n_models)]
        else:
            colors = [colors] * n_models
    elif len(colors) < n_traces:
        # Fail early with a clear message instead of an IndexError half-way through plotting.
        raise ValueError("The number of colors is smaller than the number of models")

    if varnames is None:
        # Union of the variables of every trace, keeping first-seen order.
        varnames = []
        for tr in trace:
            for v in get_default_varnames(tr.varnames, plot_transformed):
                if v not in varnames:
                    varnames.append(v)

    n_vars = len(varnames)

    if ax is None:
        if figsize is None:
            figsize = (6, n_vars * 2)
        fig, kplot = plt.subplots(n_vars, 1, squeeze=False, figsize=figsize)
        kplot = kplot.flatten()
    else:
        # Honor user-supplied axes; accept a single axes or any sequence/array of them.
        fig = None
        kplot = np.atleast_1d(ax).ravel()
        if len(kplot) != n_vars:
            raise ValueError("The number of axes does not match the number of variables")

    for v_idx, vname in enumerate(varnames):
        for t_idx, tr in enumerate(trace):
            if vname not in tr.varnames:
                continue
            vec = tr.get_values(vname)
            k = np.size(vec[0])
            if k > 1:
                # Non-scalar variable: transpose so that samples of one component are
                # contiguous, then cut the flattened array into k per-component vectors.
                for component in np.split(vec.T.ravel(), k):
                    _kde_helper(component, vname, colors[t_idx], alpha, point_estimate,
                                hpd_markers, outline, shade, kplot[v_idx])
            else:
                _kde_helper(vec, vname, colors[t_idx], alpha, point_estimate,
                            hpd_markers, outline, shade, kplot[v_idx])

    if n_traces > 1:
        # Empty lines serve only as legend handles mapping colors to model names.
        for m_idx, m in enumerate(models):
            kplot[0].plot([], label=m, c=colors[m_idx])
        kplot[0].legend(fontsize=textsize)

    if fig is not None:
        # Only adjust the layout of a figure created here; caller-owned axes are left alone.
        fig.tight_layout()

    return kplot
| 121 | + |
| 122 | + |
def _kde_helper(vec, vname, c, alpha, point_estimate, hpd_markers,
                outline, shade, ax):
    """
    Draw the KDE of a single vector of samples, truncated at its hpd interval, on `ax`.

    Parameters
    ----------
    vec : array
        1D array from trace
    vname : str
        variable name, used as the axes title
    c : str
        matplotlib color
    alpha : float
        Alpha value for (1-alpha)*100% credible intervals.
    point_estimate : str or None
        'mean' or 'median'. If None no point estimate is drawn.
    hpd_markers : str
        A valid `matplotlib.markers` string drawn at both limits of the hpd interval.
        An empty string disables the markers.
    outline : boolean
        If True, draw the truncated density as a line together with vertical lines at
        both limits of the interval.
    shade : float
        Alpha blending value for the shaded area under the curve, between 0 (no shade) and 1
        (opaque).
    ax : matplotlib axes
    """
    # NOTE(review): fast_kde is assumed to return the density values plus the lower and
    # upper bounds of the grid they were evaluated on -- confirm against kdeplot.fast_kde.
    density, l, u = fast_kde(vec)
    x = np.linspace(l, u, len(density))
    hpd_ = hpd(vec, alpha)
    # Keep only the grid points that fall inside the hpd interval.
    cut = (x >= hpd_[0]) & (x <= hpd_[1])

    x_cut = x[cut]
    density_cut = density[cut]
    xmin, xmax = x_cut[0], x_cut[-1]
    ymin, ymax = density_cut[0], density_cut[-1]

    if outline:
        ax.plot(x_cut, density_cut, color=c)
        # Vertical segments closing the truncated curve at both ends.
        ax.plot([xmin, xmin], [-0.5, ymin], color=c, ls='-')
        ax.plot([xmax, xmax], [-0.5, ymax], color=c, ls='-')

    if hpd_markers:
        # Use the marker requested by the caller rather than a hard-coded 'v'.
        ax.plot(xmin, 0, hpd_markers, color=c, markeredgecolor='k')
        ax.plot(xmax, 0, hpd_markers, color=c, markeredgecolor='k')

    if shade:
        ax.fill_between(x, density, where=cut, color=c, alpha=shade)

    if point_estimate is not None:
        if point_estimate == 'mean':
            ps = np.mean(vec)
        else:
            # densityplot only lets 'mean', 'median' or None through.
            ps = np.median(vec)
        ax.plot(ps, 0, 'o', color=c, markeredgecolor='k')

    ax.set_yticks([])
    ax.set_title(vname)
    for pos in ['left', 'right', 'top']:
        ax.spines[pos].set_visible(False)
0 commit comments