-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
SparseLatent GP #2951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
SparseLatent GP #2951
Changes from all commits
e13c4cd
0d34aed
1fa1ea7
768fa70
44cfbb9
c3cdd2c
ab6e01d
2389392
e63278e
0f40ea7
974d84b
522cb15
812d421
e9d3241
9959049
26828b1
f05bac7
cf3c0fa
e9a28b5
fbdbdf8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
from . import cov | ||
from . import mean | ||
from . import util | ||
from .gp import Latent, Marginal, MarginalSparse, TP, MarginalKron | ||
from .gp import Latent, LatentSparse, Marginal, MarginalSparse, TP, MarginalKron |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,6 +9,24 @@ | |
solve_upper = tt.slinalg.Solve(A_structure='upper_triangular') | ||
solve = tt.slinalg.Solve(A_structure='general') | ||
|
||
def invert_dot(L, X): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These are great! For consistency though, it would be good to use these functions throughout the gp source code (I know, a pain) or just write the theano ops as before. |
||
"""Wrapper for common pattern K^{-1} @ X where K = L @ L^T""" | ||
return solve_upper(L.T, solve_lower(L, X)) | ||
|
||
def project_inverse(P, L, diag=True, P_T=None): | ||
"""Wrapper for common pattern P @ K^{-1} @ P^T where K = L @ L^T""" | ||
same_P = P_T is None | ||
if same_P: | ||
P_T = P.T | ||
A = solve_lower(L, P_T) | ||
if diag: | ||
return tt.sum(A * A, axis=0) # the diagonal of A.T @ A | ||
else: | ||
if same_P: | ||
return tt.dot(A.T, A) | ||
else: | ||
return tt.dot(P, invert_dot(L, P_T)) | ||
|
||
|
||
def infer_shape(X, n_points=None): | ||
if n_points is None: | ||
|
@@ -71,7 +89,7 @@ def setter(self, val): | |
return gp_wrapper | ||
|
||
|
||
def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds"): | ||
def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds", label=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice |
||
""" A helper function for plotting 1D GP posteriors from trace """ | ||
import matplotlib.pyplot as plt | ||
|
||
|
@@ -80,11 +98,13 @@ def plot_gp_dist(ax, samples, x, plot_samples=True, palette="Reds"): | |
colors = (percs - np.min(percs)) / (np.max(percs) - np.min(percs)) | ||
samples = samples.T | ||
x = x.flatten() | ||
i_last = len(percs) - 1 | ||
for i, p in enumerate(percs[::-1]): | ||
upper = np.percentile(samples, p, axis=1) | ||
lower = np.percentile(samples, 100-p, axis=1) | ||
color_val = colors[i] | ||
ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=0.8) | ||
lab = label if i == i_last else None | ||
ax.fill_between(x, upper, lower, color=cmap(color_val), alpha=0.8, label=lab) | ||
if plot_samples: | ||
# plot a few samples | ||
idx = np.random.randint(0, samples.shape[1], 30) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
what about
shape_f
?