Refactor Normal and MvNormal #2847
Closed
Changes from 23 commits (68 commits total)

Commits
- 87c4795 gBokiau: Remove unused variables
- 380253e gBokiau: backyard cleaning
- 0dece51 gBokiau: WIP - switch to Theano implementations
- 9d5c4e4 gBokiau: Harmonize how covariance matrices are initialized
- a7e182e gBokiau: Using `OpFromGraph` for MvNormal logp's
- 60493c9 gBokiau: fix imports
- edf0ae8 gBokiau: delay floatX(k)
- 4ca2caa gBokiau: Fix ifelse statements
- a28d3e1 gBokiau: logp returns a vector, not a matrix
- ead0d1f gBokiau: Fix float mismatches
- 9f9da75 gBokiau: TODO check why logp doesn't always return a vector
- 41fc49a gBokiau: amend tests and fix typos
- 8e8f84f gBokiau: Hopefully solve float errors
- 8b2b217 gBokiau: inelegant solution for mvt for the time being
- 23de07e gBokiau: minor fixes
- 8be16c5 gBokiau: fix typo in test, hopefully final shot at fixing float32-mode errors
- e3dbb16 gBokiau: GP: Delegate cholesky to MvNormal
- b4effab gBokiau: harmonize floatX in mv
- d26d096 gBokiau: more float fixes
- f2fc715 gBokiau: return -inf instead of nan
- 6fb9d8c gBokiau: …more typehinting
- 8e3c7cf gBokiau: erring on the safe side of type hinting
- 596877f gBokiau: again, not sure about this
- 44e7a25 gBokiau: styling and correct tests
- 81a0c78 gBokiau: This might or might not work
- cb738cb gBokiau: typo
- d10c302 gBokiau: Apply same logic to timeseries
- 8afb4b2 gBokiau: fixing and extending tests
- a8de397 gBokiau: one more typo
- 7833d50 gBokiau: fix ommissions in tests
- a81eb16 gBokiau: more typos
- 1e40870 gBokiau: Adjust logic for MvStudent
- 0426333 gBokiau: fix approach to replacement
- 1acae23 gBokiau: Fixing more omissions
- ecff6fc gBokiau: Typos. Not very elegant, will have to review structure.
- 5479445 gBokiau: Not sure why TestScalarParameterSamples::test_mv_t fails on onedim
- d9ccfdf gBokiau: oversights
- 896c0c8 gBokiau: one more oversight
- 2a85594 gBokiau: I don't think want the cholesky to return NaN's in GP
- a28ba4d gBokiau: Small style improvements
- edf2b30 gBokiau: Omitted to remove
- 38f3ed0 gBokiau: Style
- 5e7cf9c gBokiau: some fixes
- c905e8f gBokiau: Cleaning up
- 61fb81d gBokiau: typo
- 5db7527 gBokiau: Fix 1-dim shape params
- 5fb74b0 gBokiau: same typo
- 6025c41 gBokiau: Fix tau + postpone gradients with anything but cov for later
- ceddfce gBokiau: bypass diagonal test when cholesky is given
- 82f56df gBokiau: omissions, tau should now be ok everywhere
- 12435fc gBokiau: Avoiding repetition
- b380f0e gBokiau: Throwing FloatX where I can
- 5131403 gBokiau: more FloatX
- 06c6c90 gBokiau: woops
- 2d97c55 gBokiau: This complains that logp isn't scalar.
- d1feb14 gBokiau: Try FullRangGroup choleksy stabilisation
- bece0d8 gBokiau: floatX's
- 3fff0fa gBokiau: typo
- 577cf92 gBokiau: maybe this
- daee63e gBokiau: and yet more floatX's
- 774a402 gBokiau: and more still
- effa3d2 gBokiau: …then maybe this.
- 54e4a76 gBokiau: typo
- fc8590e gBokiau: another typo
- afaa3aa gBokiau: reverting to tt.switch after all
- 721c24a gBokiau: not doing OpFromGraph atm
- e9e9d05 gBokiau: …right.
- d0035b1 gBokiau: omitted to replace
```diff
@@ -6,9 +6,10 @@
 from __future__ import division

 import numpy as np
-import scipy.linalg
 import theano.tensor as tt
 import theano
+from theano.ifelse import ifelse
+from theano.tensor import slinalg

 from .special import gammaln
 from pymc3.theanof import floatX
```
```diff
@@ -143,15 +144,14 @@ def log_normal(x, mean, **kwargs):
     return f(c) - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2. * std ** 2)


-def MvNormalLogp():
+def MvNormalLogp(with_choleksy=False):
     """Compute the log pdf of a multivariate normal distribution.

-    This should be used in MvNormal.logp once Theano#5908 is released.
-
     Parameters
     ----------
     cov : tt.matrix
-        The covariance matrix.
+        The covariance matrix or its Cholesky decomposition (the latter if
+        `with_choleksy` is set to True when instantiating the Op).
     delta : tt.matrix
         Array of deviations from the mean.
     """
```
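The quantity this Op computes can be sketched in plain NumPy (a hypothetical reference implementation, not code from this PR): for the Cholesky factor L of cov and n rows of deviations, the summed log pdf is -0.5 * (n*k*log(2π) + 2n*Σ log L_ii + ‖L⁻¹Δᵀ‖²_F).

```python
import numpy as np

def mv_normal_logp(cov, delta):
    """Summed log pdf of N(0, cov) over the rows of `delta`, using the
    same Cholesky-based formula as the Op:
    -0.5 * (n*k*log(2*pi) + 2*n*sum(log(diag(L))) + ||L^{-1} delta^T||_F^2)
    """
    n, k = delta.shape
    chol = np.linalg.cholesky(cov)                  # L with L @ L.T == cov
    delta_trans = np.linalg.solve(chol, delta.T).T  # whitened deviations
    result = n * k * np.log(2 * np.pi)
    result += 2.0 * n * np.log(np.diag(chol)).sum()
    result += (delta_trans ** 2).sum()
    return -0.5 * result
```

The Cholesky form avoids ever building cov⁻¹ explicitly: log|cov| falls out as twice the sum of the log-diagonal of L, and the quadratic form becomes a sum of squares of triangular-solve results.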
```diff
@@ -160,24 +160,37 @@ def MvNormalLogp():
         delta = tt.matrix('delta')
         delta.tag.test_value = floatX(np.zeros((2, 3)))

-        solve_lower = tt.slinalg.Solve(A_structure='lower_triangular')
-        solve_upper = tt.slinalg.Solve(A_structure='upper_triangular')
-        cholesky = Cholesky(nofail=True, lower=True)
+        solve_lower = slinalg.Solve(A_structure='lower_triangular', overwrite_b=True)
+        solve_upper = slinalg.Solve(A_structure='upper_triangular', overwrite_b=True)

         n, k = delta.shape
-        n, k = f(n), f(k)
-        chol_cov = cholesky(cov)
-        diag = tt.nlinalg.diag(chol_cov)
-        ok = tt.all(diag > 0)
+        n = f(n)
+
+        if not with_choleksy:
+            # add inplace=True when/if implemented by Theano
+            cholesky = slinalg.Cholesky(lower=True, on_error="nan")
+            cov = cholesky(cov)
+            # The Cholesky op will return NaNs if the cov is not positive definite
+            # -- checking the first value is sufficient
+            ok = ~tt.isnan(cov[0, 0])
+            # will all be NaN if the Cholesky was no-go, which is fine
+            diag = tt.ExtractDiag(view=True)(cov)
+        else:
+            diag = tt.ExtractDiag(view=True)(cov)
+            # Here we must check if the Cholesky is positive definite
+            ok = tt.all(diag > 0)

-        chol_cov = tt.switch(ok, chol_cov, tt.fill(chol_cov, 1))
+        # `solve_lower` throws errors with NaNs hence we replace the cov with
+        # identity and return -Inf later
+        chol_cov = ifelse(ok, cov, tt.eye(k, dtype=theano.config.floatX))
         delta_trans = solve_lower(chol_cov, delta.T).T

-        result = n * k * tt.log(f(2) * np.pi)
+        result = n * f(k) * tt.log(f(2 * np.pi))
         result += f(2) * n * tt.sum(tt.log(diag))
         result += (delta_trans ** f(2)).sum()
         result = f(-.5) * result
-        logp = tt.switch(ok, result, -np.inf)
+
+        logp = ifelse(ok, f(result), f(-np.inf * tt.ones_like(result)))

         def dlogp(inputs, gradients):
             g_logp, = gradients
```

Review comment on the `tt.ExtractDiag(view=True)` line: "Good point to use view"
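The `on_error="nan"` Cholesky plus the `ifelse` guard implement the PR's "return -inf instead of nan" policy for non-positive-definite inputs. In eager NumPy the same policy can be sketched with exception handling (hypothetical helper names, not part of the PR):

```python
import numpy as np

def chol_or_none(cov):
    """Eager analogue of slinalg.Cholesky(on_error="nan"): signal failure
    instead of raising, so the caller can substitute a sentinel logp."""
    try:
        return np.linalg.cholesky(cov)
    except np.linalg.LinAlgError:
        return None

def logp_or_neginf(cov, delta):
    """Apply the diff's policy: non-positive-definite cov gives logp = -inf."""
    n, k = delta.shape
    chol = chol_or_none(cov)
    if chol is None:
        return -np.inf
    delta_trans = np.linalg.solve(chol, delta.T).T
    return -0.5 * (n * k * np.log(2 * np.pi)
                   + 2.0 * n * np.log(np.diag(chol)).sum()
                   + (delta_trans ** 2).sum())

# [[1, 2], [2, 1]] has a negative eigenvalue, so the Cholesky fails:
logp_or_neginf(np.array([[1.0, 2.0], [2.0, 1.0]]), np.zeros((1, 2)))  # -> -inf
```

Returning -inf keeps samplers well-behaved: a proposal with an invalid covariance is simply rejected, rather than poisoning the chain with NaNs or aborting with an exception.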
```diff
@@ -186,109 +199,33 @@ def dlogp(inputs, gradients):
             g_logp.tag.test_value = floatX(1.)
             n, k = delta.shape

-            chol_cov = cholesky(cov)
-            diag = tt.nlinalg.diag(chol_cov)
-            ok = tt.all(diag > 0)
+            if not with_choleksy:
+                cov = cholesky(cov)
+                ok = ~tt.isnan(cov[0, 0])
+            else:
+                diag = tt.ExtractDiag(view=True)(cov)
+                ok = tt.all(diag > 0)

-            chol_cov = tt.switch(ok, chol_cov, tt.fill(chol_cov, 1))
+            I_k = tt.eye(k, dtype=theano.config.floatX)
+            chol_cov = ifelse(ok, cov, I_k)
             delta_trans = solve_lower(chol_cov, delta.T).T

-            inner = n * tt.eye(k) - tt.dot(delta_trans.T, delta_trans)
+            inner = n * I_k - tt.dot(delta_trans.T, delta_trans)
             g_cov = solve_upper(chol_cov.T, inner)
             g_cov = solve_upper(chol_cov.T, g_cov.T)

             tau_delta = solve_upper(chol_cov.T, delta_trans.T)
             g_delta = tau_delta.T

-            g_cov = tt.switch(ok, g_cov, -np.nan)
-            g_delta = tt.switch(ok, g_delta, -np.nan)
+            g_cov = ifelse(ok, f(g_cov), f(-np.nan * tt.zeros_like(g_cov)))
+            g_delta = ifelse(ok, f(g_delta), f(-np.nan * tt.zeros_like(g_delta)))

             return [-0.5 * g_cov * g_logp, -g_delta * g_logp]

         return theano.OpFromGraph(
             [cov, delta], [logp], grad_overrides=dlogp, inline=True)


-class Cholesky(theano.Op):
-    """
-    Return a triangular matrix square root of positive semi-definite `x`.
-
-    This is a copy of the cholesky op in theano, that doesn't throw an
-    error if the matrix is not positive definite, but instead returns
-    nan.
-
-    This has been merged upstream and we should switch to that
-    version after the next theano release.
-
-    L = cholesky(X, lower=True) implies dot(L, L.T) == X.
-    """
-    __props__ = ('lower', 'destructive', 'nofail')
-
-    def __init__(self, lower=True, nofail=False):
-        self.lower = lower
-        self.destructive = False
-        self.nofail = nofail
-
-    def make_node(self, x):
-        x = tt.as_tensor_variable(x)
-        if x.ndim != 2:
-            raise ValueError('Matrix must me two dimensional.')
-        return tt.Apply(self, [x], [x.type()])
-
-    def perform(self, node, inputs, outputs):
-        x = inputs[0]
-        z = outputs[0]
-        try:
-            z[0] = scipy.linalg.cholesky(x, lower=self.lower).astype(x.dtype)
-        except (ValueError, scipy.linalg.LinAlgError):
-            if self.nofail:
-                z[0] = np.eye(x.shape[-1])
-                z[0][0, 0] = np.nan
-            else:
-                raise
-
-    def grad(self, inputs, gradients):
-        """
-        Cholesky decomposition reverse-mode gradient update.
-
-        Symbolic expression for reverse-mode Cholesky gradient taken from [0]_
-
-        References
-        ----------
-        .. [0] I. Murray, "Differentiation of the Cholesky decomposition",
-           http://arxiv.org/abs/1602.07527
-
-        """
-        x = inputs[0]
-        dz = gradients[0]
-        chol_x = self(x)
-        ok = tt.all(tt.nlinalg.diag(chol_x) > 0)
-        chol_x = tt.switch(ok, chol_x, tt.fill_diagonal(chol_x, 1))
-        dz = tt.switch(ok, dz, floatX(1))
-
-        # deal with upper triangular by converting to lower triangular
-        if not self.lower:
-            chol_x = chol_x.T
-            dz = dz.T
-
-        def tril_and_halve_diagonal(mtx):
-            """Extracts lower triangle of square matrix and halves diagonal."""
-            return tt.tril(mtx) - tt.diag(tt.diagonal(mtx) / 2.)
-
-        def conjugate_solve_triangular(outer, inner):
-            """Computes L^{-T} P L^{-1} for lower-triangular L."""
-            solve = tt.slinalg.Solve(A_structure="upper_triangular")
-            return solve(outer.T, solve(outer.T, inner.T).T)
-
-        s = conjugate_solve_triangular(
-            chol_x, tril_and_halve_diagonal(chol_x.T.dot(dz)))
-
-        if self.lower:
-            grad = tt.tril(s + s.T) - tt.diag(tt.diagonal(s))
-        else:
-            grad = tt.triu(s + s.T) - tt.diag(tt.diagonal(s))
-        return [tt.switch(ok, grad, floatX(np.nan))]
-
-
 class SplineWrapper(theano.Op):
```
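The `dlogp` override computes the covariance gradient as -0.5 · (n·Σ⁻¹ − Σ⁻¹ΔᵀΔΣ⁻¹) using two triangular solves against the Cholesky factor. A NumPy sketch of the full-covariance path, with a directional finite-difference check along a symmetric perturbation (hypothetical helper names, not code from the PR):

```python
import numpy as np

def mvn_logp(cov, delta):
    """Log pdf of N(0, cov) summed over the rows of `delta`."""
    n, k = delta.shape
    chol = np.linalg.cholesky(cov)
    dt = np.linalg.solve(chol, delta.T).T
    return -0.5 * (n * k * np.log(2 * np.pi)
                   + 2.0 * n * np.log(np.diag(chol)).sum()
                   + (dt ** 2).sum())

def mvn_dlogp_dcov(cov, delta):
    """Gradient wrt cov, mirroring the two solve_upper calls in dlogp:
    -0.5 * L^{-T} (n*I - L^{-1} D^T D L^{-T}) L^{-1}."""
    n, k = delta.shape
    chol = np.linalg.cholesky(cov)
    dt = np.linalg.solve(chol, delta.T).T
    inner = n * np.eye(k) - dt.T @ dt
    g = np.linalg.solve(chol.T, inner)   # solve_upper(chol_cov.T, inner)
    g = np.linalg.solve(chol.T, g.T)     # solve_upper(chol_cov.T, g_cov.T)
    return -0.5 * g
```

Since `inner` is symmetric, the two solves expand to n·Σ⁻¹ − Σ⁻¹(ΔᵀΔ)Σ⁻¹, which is exactly the derivative of -0.5·(n·log|Σ| + tr(Σ⁻¹ΔᵀΔ)); the -0.5 factor is applied in the returned expression, matching `return [-0.5 * g_cov * g_logp, ...]` in the diff.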
Review comments

Reviewer: I know this wasn't your idea, but I like `floatX` a lot more. I had to go search for the definition of this.

gBokiau: True. It's a mess, though; I just seem to be wrapping floatX around everything.
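For context on that exchange: assuming `pymc3.theanof.floatX` (imported at the top of the diff) simply coerces values to Theano's configured float dtype, a minimal stand-in looks like this (the `FLOATX` constant is illustrative, standing in for `theano.config.floatX`):

```python
import numpy as np

FLOATX = "float64"  # illustrative stand-in for theano.config.floatX

def floatX(x):
    """Cast an array-like or scalar to the configured float dtype, so that
    float32/float64 mismatches cannot creep into a computation."""
    return np.asarray(x, dtype=FLOATX)

# Many of the commits above amount to wrapping constants like this:
half = floatX(-.5)                  # dtype-stable scalar constant
zeros = floatX(np.zeros((2, 3)))    # dtype-stable array
```

This is why so many commits in the list are "more floatX": every literal that enters the graph must be cast, or Theano silently upcasts intermediate results and float32 test modes break.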