Commit 4a713dc

ferrine authored and twiecki committed
WIP: Implement opvi (#1694)
* migrate useful functions from previous PR (cherry picked from commit 9f61ab4)
* opvi draft (cherry picked from commit d0997ff)
* made some test work (cherry picked from commit b1a87d5)
* refactored approximation to support aevb (without test)
* refactor opvi: delete unnecessary methods from operator, change method order
* change log_q_local computation
* add full rank approximation
* add more_params argument to ObjectiveFunction.updates (aevb case)
* refactor density computation in full rank approximation
* typo: cast dict values to list
* typo: cast dict values to list
* typo: undefined T in dist_math
* refactor gradient scaling as suggested in approximateinference.org/accepted/RoederEtAl2016.pdf
* implement Langevin-Stein (LS) operator
* fix docstring
* add blank line in docs
* refactor ObjectiveFunction
* add not working LS Op test
* experiments with not working LS Op
* change activations
* refactor networks
* add step_function
* remove Langevin Stein, done refactoring
* remove Langevin Stein, done refactoring
* change optimizers
* refactor init params
* implement tests
* implement Inference
* code style
* test fix
* add minibatch test (fails now)
* add more tests for minibatch training
* add logdet to FullRank approximation
* add conversion of arrays to floatX
* tiny changes
* change number of iterations
* fix test and pylint check
* memoize functions in Objective function
* Optimize code a lot
* a bit more efficient pickling
* add docs
* Add MeanField -> FullRank parameter transfer
* refactor MeanField and FullRank a bit
* fix FullRank bug with shapes in random
* refactor Model.flatten (CC @taku-y)
* add `approximate` to inference
* rename approximate -> fit
* change abbreviations
* Fix bug with scaling input variable in aevb
* fix theano bottleneck in graph
* more efficient scaling for local vars
* fix typo in local Q
* add aevb test
* refactor memoize to work with my objects
* add tests for numpy view usage
* pickle-hash fix
* pickle-hash fix again
* add node sampling + make up some code
* add notebook with example
* sample_proba explained
1 parent 01e5aef commit 4a713dc

13 files changed: +2745 −28 lines

docs/source/notebooks/bayesian_neural_network_opvi-advi.ipynb (865 additions & 0 deletions)
Large diffs are not rendered by default.

pymc3/distributions/dist_math.py (117 additions & 0 deletions)

```diff
@@ -9,6 +9,9 @@
 import theano.tensor as tt

 from .special import gammaln
+from ..math import logdet as _logdet
+
+c = - 0.5 * np.log(2 * np.pi)


 def bound(logp, *conditions, **kwargs):
@@ -96,3 +99,117 @@ def i1(x):
              x**9 / 1474560 + x**11 / 176947200 + x**13 / 29727129600,
              np.e**x / (2 * np.pi * x)**0.5 * (1 - 3 / (8 * x) + 15 / (128 * x**2) + 315 / (3072 * x**3)
              + 14175 / (98304 * x**4)))
+
+
+def sd2rho(sd):
+    """
+    `sd -> rho` theano converter
+    :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
+    return tt.log(tt.exp(sd) - 1)
+
+
+def rho2sd(rho):
+    """
+    `rho -> sd` theano converter
+    :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
+    return tt.log1p(tt.exp(rho))
+
+
+def log_normal(x, mean, **kwargs):
+    """
+    Calculate logarithm of normal distribution at point `x`
+    with given `mean` and `std`
+
+    Parameters
+    ----------
+    x : Tensor
+        point of evaluation
+    mean : Tensor
+        mean of normal distribution
+    kwargs : one of parameters `{sd, tau, w, rho}`
+
+    Notes
+    -----
+    There are four variants for density parametrization.
+    They are:
+        1) standard deviation - `std`
+        2) `w`, logarithm of `std` :math:`w = log(std)`
+        3) `rho` that follows this equation :math:`rho = log(exp(std) - 1)`
+        4) `tau` that follows this equation :math:`tau = std^{-1}`
+    """
+    sd = kwargs.get('sd')
+    w = kwargs.get('w')
+    rho = kwargs.get('rho')
+    tau = kwargs.get('tau')
+    eps = kwargs.get('eps', 0.0)
+    check = sum(map(lambda a: a is not None, [sd, w, rho, tau]))
+    if check > 1:
+        raise ValueError('more than one required kwarg is passed')
+    if check == 0:
+        raise ValueError('none of required kwargs is passed')
+    if sd is not None:
+        std = sd
+    elif w is not None:
+        std = tt.exp(w)
+    elif rho is not None:
+        std = rho2sd(rho)
+    else:
+        std = tau**(-1)
+    std += eps
+    return c - tt.log(tt.abs_(std)) - (x - mean) ** 2 / (2 * std ** 2)
+
+
+def log_normal_mv(x, mean, gpu_compat=False, **kwargs):
+    """
+    Calculate logarithm of normal distribution at point `x`
+    with given `mean` and `sigma` matrix
+
+    Parameters
+    ----------
+    x : Tensor
+        point of evaluation
+    mean : Tensor
+        mean of normal distribution
+    kwargs : one of parameters `{cov, tau, chol}`
+
+    Flags
+    -----
+    gpu_compat : False, because LogDet is not GPU compatible yet.
+        If set to True, a GPU-compatible (but numerically unstable)
+        log(det) is used instead.
+
+    Notes
+    -----
+    There are three variants for density parametrization.
+    They are:
+        1) covariance matrix - `cov`
+        2) precision matrix - `tau`
+        3) cholesky decomposition matrix - `chol`
+    """
+    if gpu_compat:
+        def logdet(m):
+            return tt.log(tt.abs_(tt.nlinalg.det(m)))
+    else:
+        logdet = _logdet
+
+    T = kwargs.get('tau')
+    S = kwargs.get('cov')
+    L = kwargs.get('chol')
+    check = sum(map(lambda a: a is not None, [T, S, L]))
+    if check > 1:
+        raise ValueError('more than one required kwarg is passed')
+    if check == 0:
+        raise ValueError('none of required kwargs is passed')
+    # avoid unnecessary computations
+    if L is not None:
+        S = L.dot(L.T)
+        T = tt.nlinalg.matrix_inverse(S)
+        log_det = -logdet(S)
+    elif T is not None:
+        log_det = logdet(T)
+    else:
+        T = tt.nlinalg.matrix_inverse(S)
+        log_det = -logdet(S)
+    delta = x - mean
+    k = T.shape[0]  # T is set in every branch above; S stays None if only tau is given
+    result = k * tt.log(2 * np.pi) - log_det
+    result += delta.dot(T).dot(delta)
+    return -1 / 2. * result
```
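A hedged smoke test for the new helpers (mine, not part of the commit): the four scalar parametrizations accepted by `log_normal` should produce the same density, and `log_normal_mv` should agree whether it is given a covariance matrix or its Cholesky factor. Note that in this code `tau` denotes the reciprocal of `std`, not the usual precision `std**-2`.

```python
import numpy as np
import theano.tensor as tt
from pymc3.distributions.dist_math import log_normal, log_normal_mv, sd2rho

x, mu, sd = 1.0, 0.0, 2.0
variants = [
    log_normal(x, mu, sd=sd),           # standard deviation
    log_normal(x, mu, w=np.log(sd)),    # w = log(sd)
    log_normal(x, mu, rho=sd2rho(sd)),  # rho = log(exp(sd) - 1)
    log_normal(x, mu, tau=1.0 / sd),    # tau = sd**-1, as defined above
]
print([float(v.eval()) for v in variants])  # four identical log-densities

chol = np.array([[2.0, 0.0], [0.3, 1.5]])   # lower-triangular factor
xv = tt.as_tensor_variable(np.array([0.5, -1.0]))
print(log_normal_mv(xv, np.zeros(2), cov=chol.dot(chol.T)).eval())
print(log_normal_mv(xv, np.zeros(2), chol=chol).eval())  # same value
```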

pymc3/math.py (5 additions & 1 deletion)

```diff
@@ -37,13 +37,17 @@ def logit(p):
     return tt.log(p / (1 - p))


+def flatten_list(tensors):
+    return tt.concatenate([var.ravel() for var in tensors])
+
+
 class LogDet(Op):
     """Computes the logarithm of absolute determinant of a square
     matrix M, log(abs(det(M))), on CPU. Avoids det(M) overflow/
     underflow.

     Note: Once PR #3959 (https://github.com/Theano/Theano/pull/3959/) by harpone is merged,
-    this must be removed.
+    this must be removed.
     """
     def make_node(self, x):
         x = theano.tensor.as_tensor_variable(x)
```
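`flatten_list` is a one-liner, but it underpins the new `Model.flatten` below, so a tiny illustration (mine, not from the diff) may help: each tensor is raveled and the pieces are concatenated into a single vector.

```python
import numpy as np
import theano.tensor as tt
from pymc3.math import flatten_list

a = tt.as_tensor_variable(np.arange(6.0).reshape(2, 3))
b = tt.as_tensor_variable(np.array([10.0, 11.0]))
# a 2x3 matrix and a length-2 vector collapse into one length-8 vector
print(flatten_list([a, b]).eval())
# [ 0.  1.  2.  3.  4.  5. 10. 11.]
```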

pymc3/memoize.py (13 additions & 4 deletions)

```diff
@@ -1,4 +1,5 @@
 import functools
+import pickle


 def memoize(obj):
@@ -23,8 +24,16 @@ def hashable(a):
     Turn some unhashable objects into hashable ones.
     """
     if isinstance(a, dict):
-        return hashable(a.items())
+        return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items()))
     try:
-        return tuple(map(hashable, a))
-    except:
-        return a
+        return hash(a)
+    except TypeError:
+        pass
+    # Not hashable >>>
+    try:
+        return hash(pickle.dumps(a))
+    except Exception:
+        if hasattr(a, '__dict__'):
+            return hashable(a.__dict__)
+        else:
+            return id(a)
```
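The reworked `hashable` is what makes the pickle-hash memoization work on NumPy/Theano objects. A sketch of its behaviour (my example, not from the commit): plain hashables go through `hash`, dicts are hashed item-wise, and unhashable objects fall back to a hash of their pickle, then to `__dict__`, then to `id`.

```python
import numpy as np
from pymc3.memoize import hashable

print(hashable((1, 2, 3)))           # hashable already: plain hash()
print(hashable({'a': 1, 'b': [2]}))  # dict: tuple of hashed (key, value) pairs
# ndarray is unhashable, so it reduces to hash(pickle.dumps(...)),
# which is stable for equal arrays within one process
print(hashable(np.ones(3)) == hashable(np.ones(3)))  # True
```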

pymc3/model.py (36 additions & 4 deletions)

```diff
@@ -1,14 +1,15 @@
+import collections
 import threading
 import six

 import numpy as np
 import scipy.sparse as sps
-import theano
-import theano.tensor as tt
 import theano.sparse as sparse
+from theano import theano, tensor as tt
 from theano.tensor.var import TensorVariable

 import pymc3 as pm
+from pymc3.math import flatten_list
 from .memoize import memoize
 from .theanof import gradient, hessian, inputvars, generator
 from .vartypes import typefilter, discrete_types, continuous_types, isgenerator
@@ -19,6 +20,8 @@
     'Point', 'Deterministic', 'Potential'
 ]

+FlatView = collections.namedtuple('FlatView', 'input, replacements, view')
+

 class InstanceMethod(object):
     """Class for hiding references to instance methods so they can be pickled.
@@ -172,8 +175,10 @@ def fastd2logp(self, vars=None):
     @property
     def logpt(self):
         """Theano scalar of log-probability of the model"""
-
-        return tt.sum(self.logp_elemwiset) * self.scaling
+        if getattr(self, 'total_size', None) is not None:
+            return tt.sum(self.logp_elemwiset) * self.scaling
+        else:
+            return tt.sum(self.logp_elemwiset)

     @property
     def scaling(self):
@@ -659,6 +664,33 @@ def profile(self, outs, n=1000, point=None, profile=True, *args, **kwargs):

         return f.profile

+    def flatten(self, vars=None):
+        """Flattens model's input and returns:
+        FlatView with
+            * input vector variable
+            * replacements `input_var -> vars`
+            * view {variable: VarMap}
+
+        Parameters
+        ----------
+        vars : list of variables or None
+            if None, then all model.free_RVs are used for flattening input
+
+        Returns
+        -------
+        flat_view
+        """
+        if vars is None:
+            vars = self.free_RVs
+        order = ArrayOrdering(vars)
+        inputvar = tt.vector('flat_view', dtype=theano.config.floatX)
+        inputvar.tag.test_value = flatten_list(vars).tag.test_value
+        replacements = {self.named_vars[name]: inputvar[slc].reshape(shape).astype(dtype)
+                        for name, slc, shape, dtype in order.vmap}
+        view = {vm.var: vm for vm in order.vmap}
+        flat_view = FlatView(inputvar, replacements, view)
+        return flat_view
+

 def fn(outs, mode=None, model=None, *args, **kwargs):
     """Compiles a Theano function which returns the values of `outs` and
```

pymc3/tests/helpers.py (8 additions & 1 deletion)

```diff
@@ -1,6 +1,8 @@
 import unittest
-import numpy.random as nr
 from logging.handlers import BufferingHandler
+import numpy.random as nr
+from theano.sandbox.rng_mrg import MRG_RandomStreams
+from ..theanof import set_tt_rng, tt_rng


 class SeededTest(unittest.TestCase):
@@ -12,6 +14,11 @@ def setUpClass(cls):

     def setUp(self):
         nr.seed(self.random_seed)
+        self.old_tt_rng = tt_rng()
+        set_tt_rng(MRG_RandomStreams(self.random_seed))
+
+    def tearDown(self):
+        set_tt_rng(self.old_tt_rng)

 class TestHandler(BufferingHandler):
     def __init__(self, matcher):
```
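For context, this is roughly what the new seeding does (an illustrative sketch, not from the commit): `tt_rng()` and `set_tt_rng` from `pymc3.theanof` get and swap the global Theano random stream, so any test drawing from it becomes reproducible.

```python
import theano
from theano.sandbox.rng_mrg import MRG_RandomStreams
from pymc3.theanof import set_tt_rng, tt_rng

old = tt_rng()                     # save the global stream, as setUp does
set_tt_rng(MRG_RandomStreams(42))  # reseed it for reproducible draws
f = theano.function([], tt_rng().normal((3,)))
print(f())                         # same triple right after every reseed
set_tt_rng(old)                    # restore it, as tearDown does
```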

pymc3/tests/test_examples.py (1 addition & 0 deletions)

```diff
@@ -234,6 +234,7 @@ class TestLatentOccupancy(SeededTest):
     Copyright (c) 2008 University of Otago. All rights reserved.
     """
     def setUp(self):
+        super(TestLatentOccupancy, self).setUp()
         # Sample size
         n = 100
         # True mean count, given occupancy
```

pymc3/tests/test_math.py (3 additions & 1 deletion)

```diff
@@ -4,13 +4,15 @@
 from pymc3.math import LogDet, logdet, probit, invprobit
 from .helpers import SeededTest

+
 def test_probit():
     p = np.array([0.01, 0.25, 0.5, 0.75, 0.99])
     np.testing.assert_allclose(invprobit(probit(p)).eval(), p, atol=1e-5)

-class TestLogDet(SeededTest):

+class TestLogDet(SeededTest):
     def setUp(self):
+        super(TestLogDet, self).setUp()
         utt.seed_rng()
         self.op_class = LogDet
         self.op = logdet
```
