Skip to content

Commit a54a53e

Browse files
aloctavodia authored and twiecki committed
Kde boundaries (#1567)
* use a fft-based kde instead of scipy.gaussian_kde * fix typo
1 parent 68e53eb commit a54a53e

File tree

1 file changed

+62
-10
lines changed

1 file changed

+62
-10
lines changed

pymc3/plots.py

Lines changed: 62 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import matplotlib.pyplot as plt
55
import pymc3 as pm
66
from .stats import quantiles, hpd
7+
from scipy.signal import gaussian, convolve
78

89
__all__ = ['traceplot', 'kdeplot', 'kde2plot',
910
'forestplot', 'autocorrplot', 'plot_posterior']
@@ -120,16 +121,15 @@ def kdeplot_op(ax, data, prior=None, prior_alpha=1, prior_style='--'):
120121
for i in range(data.shape[1]):
121122
d = data[:, i]
122123
try:
123-
density = kde.gaussian_kde(d)
124-
l = np.min(d)
125-
u = np.max(d)
126-
x = np.linspace(0, 1, 100) * (u - l) + l
124+
density, l, u = fast_kde(d)
125+
x = np.linspace(l, u, len(density))
127126

128127
if prior is not None:
129128
p = prior.logp(x).eval()
130129
ax.plot(x, np.exp(p), alpha=prior_alpha, ls=prior_style)
131130

132-
ax.plot(x, density(x))
131+
ax.plot(x, density)
132+
ax.set_ylim(bottom=0)
133133

134134
except LinAlgError:
135135
errored.append(i)
@@ -721,11 +721,9 @@ def set_key_if_doesnt_exist(d, key, value):
721721
d[key] = value
722722

723723
if kde_plot:
724-
density = kde.gaussian_kde(trace_values)
725-
l = np.min(trace_values)
726-
u = np.max(trace_values)
727-
x = np.linspace(0, 1, 100) * (u - l) + l
728-
ax.plot(x, density(x), **kwargs)
724+
density, l, u = fast_kde(trace_values)
725+
x = np.linspace(l, u, len(density))
726+
ax.plot(x, density, **kwargs)
729727
else:
730728
set_key_if_doesnt_exist(kwargs, 'bins', 30)
731729
set_key_if_doesnt_exist(kwargs, 'edgecolor', 'w')
@@ -790,3 +788,57 @@ def get_trace_dict(tr, varnames):
790788

791789
fig.tight_layout()
792790
return ax
791+
792+
def fast_kde(x):
    """
    A fft-based Gaussian kernel density estimate (KDE) for computing
    the KDE on a regular grid.

    The samples are binned into a fixed-size histogram, which is then
    smoothed by convolving it with a Gaussian kernel whose width follows
    Scott's rule. The histogram is mirrored at both ends before the
    convolution to limit the loss of density at the data boundaries.

    The code was adapted from https://github.com/mfouesneau/faststats

    Parameters
    ----------

    x : Numpy array or list
        One-dimensional collection of samples.

    Returns
    -------

    grid: A gridded 1D KDE of the input points (x), evaluated on 256
        points.
    xmin: minimum value of x
    xmax: maximum value of x

    Raises
    ------

    ValueError
        If x is empty, or if the data are too concentrated (e.g. all
        values identical) for the smoothing kernel to span at least one
        grid cell.
    """
    # The documented contract accepts plain lists, so convert before
    # relying on ndarray-only methods such as ``.min()``.
    x = np.asarray(x)
    if x.size == 0:
        raise ValueError('fast_kde requires at least one data point')

    xmin, xmax = x.min(), x.max()

    n = len(x)
    nx = 256

    # compute histogram
    bins = np.linspace(xmin, xmax, nx)
    # Bin index of every sample; its spread gives the data scale in grid units
    xyi = np.digitize(x, bins)
    dx = (xmax - xmin) / (nx - 1)
    grid = np.histogram(x, bins=nx)[0]

    # Scaling factor for bandwidth
    scotts_factor = n ** (-0.2)
    # Determine the bandwidth using Scott's rule
    std_x = np.std(xyi)
    kern_nx = int(np.round(scotts_factor * 2 * np.pi * std_x))

    # A zero-width kernel cannot be convolved or normalised; fail with a
    # clear message instead of returning NaN/inf or an obscure error.
    if kern_nx < 1:
        raise ValueError('fast_kde cannot estimate a density: the data are '
                         'too concentrated to define a smoothing kernel')

    # Evaluate the gaussian function on the kernel grid
    kernel = np.reshape(gaussian(kern_nx, scotts_factor * std_x), kern_nx)

    # Compute the KDE
    # use symmetric padding to correct for data boundaries in the kde
    npad = min(nx, 2 * kern_nx)

    # NOTE(review): the left mirror skips the edge bin (indices npad..1)
    # while the right mirror includes it and is one element shorter
    # (indices nx-1..nx-npad+1). Kept unchanged so results stay identical;
    # confirm whether this asymmetry is intended.
    grid = np.concatenate([grid[npad: 0: -1], grid, grid[nx: nx - npad: -1]])
    # Drop the padding again so the result lines up with the nx grid cells
    grid = convolve(grid, kernel, mode='same')[npad: npad + nx]

    # The kernel has unit peak, so divide by its area (sigma * sqrt(2 * pi)),
    # by the sample count and by the grid spacing.
    norm_factor = n * dx * (2 * np.pi * std_x ** 2 * scotts_factor ** 2) ** 0.5

    grid = grid / norm_factor

    return grid, xmin, xmax

0 commit comments

Comments
 (0)