|
4 | 4 | import matplotlib.pyplot as plt
|
5 | 5 | import pymc3 as pm
|
6 | 6 | from .stats import quantiles, hpd
|
| 7 | +from scipy.signal import gaussian, convolve |
7 | 8 |
|
8 | 9 | __all__ = ['traceplot', 'kdeplot', 'kde2plot',
|
9 | 10 | 'forestplot', 'autocorrplot', 'plot_posterior']
|
@@ -120,16 +121,15 @@ def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
|
120 | 121 | for i in range(data.shape[1]):
|
121 | 122 | d = data[:, i]
|
122 | 123 | try:
|
123 |
| - density = kde.gaussian_kde(d) |
124 |
| - l = np.min(d) |
125 |
| - u = np.max(d) |
126 |
| - x = np.linspace(0, 1, 100) * (u - l) + l |
| 124 | + density, l, u = fast_kde(d) |
| 125 | + x = np.linspace(l, u, len(density)) |
127 | 126 |
|
128 | 127 | if prior is not None:
|
129 | 128 | p = prior.logp(x).eval()
|
130 | 129 | ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)
|
131 | 130 |
|
132 |
| - ax.plot(x, density(x)) |
| 131 | + ax.plot(x, density) |
| 132 | + ax.set_ylim(bottom=0) |
133 | 133 |
|
134 | 134 | except LinAlgError:
|
135 | 135 | errored.append(i)
|
@@ -721,11 +721,9 @@ def set_key_if_doesnt_exist(d, key, value):
|
721 | 721 | d[key] = value
|
722 | 722 |
|
723 | 723 | if kde_plot:
|
724 |
| - density = kde.gaussian_kde(trace_values) |
725 |
| - l = np.min(trace_values) |
726 |
| - u = np.max(trace_values) |
727 |
| - x = np.linspace(0, 1, 100) * (u - l) + l |
728 |
| - ax.plot(x, density(x), **kwargs) |
| 724 | + density, l, u = fast_kde(trace_values) |
| 725 | + x = np.linspace(l, u, len(density)) |
| 726 | + ax.plot(x, density, **kwargs) |
729 | 727 | else:
|
730 | 728 | set_key_if_doesnt_exist(kwargs, 'bins', 30)
|
731 | 729 | set_key_if_doesnt_exist(kwargs, 'edgecolor', 'w')
|
@@ -790,3 +788,57 @@ def get_trace_dict(tr, varnames):
|
790 | 788 |
|
791 | 789 | fig.tight_layout()
|
792 | 790 | return ax
|
| 791 | + |
def fast_kde(x):
    """
    A fast, convolution-based Gaussian kernel density estimate (KDE)
    evaluated on a regular grid.
    The code was adapted from https://github.com/mfouesneau/faststats

    The data are binned into a fixed 256-cell histogram, the histogram is
    mirrored at both ends to limit boundary bias, and the result is
    convolved with a Gaussian window whose width follows Scott's rule.

    Parameters
    ----------
    x : Numpy array or list
        Sample values. The input is converted to a flat float array, so
        plain Python sequences are accepted.

    Returns
    -------
    grid : numpy.ndarray
        The density on 256 evenly spaced points; the matching abscissa is
        ``np.linspace(xmin, xmax, len(grid))``.
    xmin : float
        Minimum value of x.
    xmax : float
        Maximum value of x.

    Raises
    ------
    ValueError
        If x is empty or contains non-finite values.
    numpy.linalg.LinAlgError
        If the data are too degenerate to smooth (all values identical, or
        so concentrated that the kernel is narrower than one grid cell).
        This is the exception type the plotting code already handles
        around KDE evaluation.
    """
    # Imported from its canonical module: the ``scipy.signal.gaussian``
    # alias is no longer available in recent SciPy releases, whereas
    # ``scipy.signal.windows.gaussian`` exists in old and new versions.
    from scipy.signal.windows import gaussian

    x = np.asarray(x, dtype=float).ravel()
    n = x.size
    if n == 0:
        raise ValueError("fast_kde requires at least one data point")

    nx = 256
    xmin, xmax = x.min(), x.max()
    if not (np.isfinite(xmin) and np.isfinite(xmax)):
        raise ValueError("fast_kde requires finite data")
    if not xmax > xmin:
        # Zero spread: grid step and bandwidth would both be zero.
        raise np.linalg.LinAlgError(
            "fast_kde cannot estimate a density from constant data")

    # Grid step of the abscissa callers build with linspace(xmin, xmax, nx).
    dx = (xmax - xmin) / (nx - 1)
    counts = np.histogram(x, bins=nx)[0]

    # The spread is measured in grid-cell units (standard deviation of the
    # bin index of every sample), so the kernel width below is directly a
    # number of grid samples.
    bin_index = np.digitize(x, np.linspace(xmin, xmax, nx))
    std_x = np.std(bin_index)

    # Scott's rule scaling factor for the bandwidth.
    scotts_factor = n ** (-0.2)
    kern_nx = int(np.round(scotts_factor * 2 * np.pi * std_x))
    if kern_nx < 1:
        # An empty kernel cannot be convolved.
        raise np.linalg.LinAlgError(
            "fast_kde bandwidth is narrower than the grid resolution")

    # Gaussian window evaluated on the kernel grid.
    kernel = gaussian(kern_nx, scotts_factor * std_x)

    # Mirror the histogram at both ends to correct for the data boundaries.
    # The pad is capped at nx - 1 so that exactly ``npad`` cells are added
    # on each side; otherwise the window extracted after the convolution
    # would be shifted by one cell for wide kernels.
    npad = min(nx - 1, 2 * kern_nx)
    padded = np.pad(counts, npad, mode='reflect')
    density = convolve(padded, kernel, mode='same')[npad: npad + nx]

    # n * dx turns counts into a density; the square root is the area of
    # the (untruncated) Gaussian kernel in grid-cell units.
    norm_factor = n * dx * (2 * np.pi * std_x ** 2 * scotts_factor ** 2) ** 0.5

    return density / norm_factor, xmin, xmax
0 commit comments