
Commit 569a82e

Merge pull request #2416 from ferrine/approximations_refactoring
Refactoring to support Structured VI
2 parents 33c8bff + 023876b commit 569a82e

32 files changed: +3341 additions, −2718 deletions

docs/source/notebooks/GLM-hierarchical-advi-minibatch.ipynb

Lines changed: 35 additions & 29 deletions
Large diffs are not rendered by default.

docs/source/notebooks/api_quickstart.ipynb

Lines changed: 130 additions & 204 deletions
Large diffs are not rendered by default.

docs/source/notebooks/bayesian_neural_network_advi.ipynb

Lines changed: 49 additions & 147 deletions
Large diffs are not rendered by default.

docs/source/notebooks/convolutional_vae_keras_advi.ipynb

Lines changed: 29 additions & 11 deletions
@@ -31,7 +31,9 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "%autosave 0\n",
@@ -121,7 +123,9 @@
   {
    "cell_type": "code",
    "execution_count": 4,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "from keras.models import Sequential\n",
@@ -172,7 +176,9 @@
   {
    "cell_type": "code",
    "execution_count": 8,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "def cnn_enc(xs, latent_dim, nb_filters=64, nb_conv=3, intermediate_dim=128):\n",
@@ -216,7 +222,9 @@
   {
    "cell_type": "code",
    "execution_count": 9,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "class Encoder:\n",
@@ -292,7 +300,9 @@
   {
    "cell_type": "code",
    "execution_count": 10,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "def cnn_dec(zs, nb_filters=64, nb_conv=3, output_shape=(1, 28, 28)):\n",
@@ -331,7 +341,9 @@
   {
    "cell_type": "code",
    "execution_count": 11,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "class Decoder:\n",
@@ -385,7 +397,9 @@
   {
    "cell_type": "code",
    "execution_count": 12,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "# Constants\n",
@@ -454,7 +468,9 @@
   {
    "cell_type": "code",
    "execution_count": 14,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
     "with pm.Model() as model:\n",
@@ -478,10 +494,12 @@
   {
    "cell_type": "code",
    "execution_count": 15,
-   "metadata": {},
+   "metadata": {
+    "collapsed": true
+   },
    "outputs": [],
    "source": [
-    "local_RVs = OrderedDict({zs: (enc.means, enc.rhos)})"
+    "local_RVs = OrderedDict({zs: dict(mu=enc.means, rho=enc.rhos)})"
    ]
   },
   {
@@ -671,7 +689,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.6.1"
+   "version": "3.6.0b4"
   },
   "latex_envs": {
    "bibliofile": "biblio.bib",

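The local_RVs hunk above is the user-facing API change in this notebook: with the refactored variational module, a local (auto-encoded) random variable is declared with an explicit dict of variational parameters rather than a bare (mu, rho) tuple. Below is a minimal sketch of the new mapping, with placeholder shared variables standing in for the notebook's Keras encoder outputs (enc.means, enc.rhos); the commented-out pm.fit call is indicative only, since keyword names vary between PyMC3 releases.

from collections import OrderedDict

import numpy as np
import theano
import pymc3 as pm

minibatch_size, latent_dim = 100, 2

# Placeholders for the encoder outputs; in the notebook these are
# enc.means and enc.rhos produced by the Keras convolutional encoder.
encoder_means = theano.shared(np.zeros((minibatch_size, latent_dim)))
encoder_rhos = theano.shared(np.zeros((minibatch_size, latent_dim)))

with pm.Model() as model:
    # Latent codes of the current minibatch.
    zs = pm.Normal('zs', mu=0, sd=1, shape=(minibatch_size, latent_dim))
    # ... decoder and likelihood of the observed images would follow ...

# Old interface: a (mu, rho) tuple per local random variable.
# local_RVs = OrderedDict({zs: (encoder_means, encoder_rhos)})

# Refactored interface: name the variational parameters explicitly.
local_RVs = OrderedDict({zs: dict(mu=encoder_means, rho=encoder_rhos)})

# The mapping is then handed to the inference call, roughly:
# approx = pm.fit(n=10000, method='advi', local_rv=local_RVs, model=model)
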
docs/source/notebooks/empirical-approx-overview.ipynb

Lines changed: 39 additions & 70 deletions
Large diffs are not rendered by default.

docs/source/notebooks/gaussian-mixture-model-advi.ipynb

Lines changed: 49 additions & 52 deletions
Large diffs are not rendered by default.

docs/source/notebooks/gaussian_process.ipynb

Lines changed: 45 additions & 24 deletions
Large diffs are not rendered by default.

docs/source/notebooks/lda-advi-aevb.ipynb

Lines changed: 116 additions & 82 deletions
Large diffs are not rendered by default.

docs/source/notebooks/normalizing_flows_overview.ipynb

Lines changed: 131 additions & 179 deletions
Large diffs are not rendered by default.

docs/source/notebooks/variational_api_quickstart.ipynb

Lines changed: 168 additions & 189 deletions
Large diffs are not rendered by default.

pymc3/backends/base.py

Lines changed: 13 additions & 4 deletions
@@ -4,10 +4,12 @@
 creating custom backends).
 """
 import numpy as np
-from ..model import modelcontext
 import warnings
 import theano.tensor as tt
 
+from ..model import modelcontext
+
+
 class BackendError(Exception):
     pass
 
@@ -24,11 +26,13 @@ class BaseTrace(object):
     vars : list of variables
         Sampling values will be stored for these variables. If None,
         `model.unobserved_RVs` is used.
+    test_point : dict
+        use different test point that might be with changed variables shapes
     """
 
     supports_sampler_stats = False
 
-    def __init__(self, name, model=None, vars=None):
+    def __init__(self, name, model=None, vars=None, test_point=None):
         self.name = name
 
         model = modelcontext(model)
@@ -41,7 +45,13 @@ def __init__(self, name, model=None, vars=None):
 
         # Get variable shapes. Most backends will need this
        # information.
-        var_values = list(zip(self.varnames, self.fn(model.test_point)))
+        if test_point is None:
+            test_point = model.test_point
+        else:
+            test_point_ = model.test_point.copy()
+            test_point_.update(test_point)
+            test_point = test_point_
+        var_values = list(zip(self.varnames, self.fn(test_point)))
         self.var_shapes = {var: value.shape
                            for var, value in var_values}
         self.var_dtypes = {var: value.dtype
@@ -72,7 +82,6 @@ def _set_sampler_vars(self, sampler_vars):
 
         self.sampler_vars = sampler_vars
 
-
     def setup(self, draws, chain, sampler_vars=None):
         """Perform chain-specific setup.
 
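
The practical effect of the new test_point argument, which the backend subclasses below simply pass through to BaseTrace, is that trace storage is sized from a caller-supplied point instead of model.test_point; only the overridden keys need to be supplied, the rest are merged in as shown in the hunk above. A rough usage sketch (variable names are illustrative; the printed shapes assume the backend's shape probe accepts the resized value, which is what the merged test point is for):

import numpy as np
import pymc3 as pm

with pm.Model() as model:
    x = pm.Normal('x', mu=0, sd=1, shape=10)

# Default behaviour: shapes come from model.test_point.
trace = pm.backends.NDArray(model=model)
print(trace.var_shapes['x'])   # -> (10,)

# Override a single entry, e.g. to record the variable under a
# different (minibatch-sized) shape; other keys are filled in from
# model.test_point automatically.
trace = pm.backends.NDArray(model=model, test_point={'x': np.zeros(25)})
print(trace.var_shapes['x'])   # -> (25,)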

pymc3/backends/hdf5.py

Lines changed: 13 additions & 11 deletions
@@ -18,24 +18,26 @@ def activator(instance):
 class HDF5(base.BaseTrace):
     """HDF5 trace object
 
-        Parameters
-        ----------
-        name : str
-            Name of backend. This has no meaning for the HDF5 backend.
-        model : Model
-            If None, the model is taken from the `with` context.
-        vars : list of variables
-            Sampling values will be stored for these variables. If None,
-            `model.unobserved_RVs` is used.
+    Parameters
+    ----------
+    name : str
+        Name of backend. This has no meaning for the HDF5 backend.
+    model : Model
+        If None, the model is taken from the `with` context.
+    vars : list of variables
+        Sampling values will be stored for these variables. If None,
+        `model.unobserved_RVs` is used.
+    test_point : dict
+        use different test point that might be with changed variables shapes
     """
 
     supports_sampler_stats = True
 
-    def __init__(self, name=None, model=None, vars=None):
+    def __init__(self, name=None, model=None, vars=None, test_point=None):
         self.hdf5_file = None
         self.draw_idx = 0
         self.draws = None
-        super(HDF5, self).__init__(name, model, vars)
+        super(HDF5, self).__init__(name, model, vars, test_point)
 
     def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
         with self.activate_file:

pymc3/backends/ndarray.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@ class NDArray(base.BaseTrace):
 
     supports_sampler_stats = True
 
-    def __init__(self, name=None, model=None, vars=None):
-        super(NDArray, self).__init__(name, model, vars)
+    def __init__(self, name=None, model=None, vars=None, test_point=None):
+        super(NDArray, self).__init__(name, model, vars, test_point)
         self.draw_idx = 0
         self.draws = None
         self.samples = {}

pymc3/backends/sqlite.py

Lines changed: 4 additions & 2 deletions
@@ -70,10 +70,12 @@ class SQLite(base.BaseTrace):
     vars : list of variables
         Sampling values will be stored for these variables. If None,
         `model.unobserved_RVs` is used.
+    test_point : dict
+        use different test point that might be with changed variables shapes
     """
 
-    def __init__(self, name, model=None, vars=None):
-        super(SQLite, self).__init__(name, model, vars)
+    def __init__(self, name, model=None, vars=None, test_point=None):
+        super(SQLite, self).__init__(name, model, vars, test_point)
         self._var_cols = {}
         self.var_inserts = {}  # varname -> insert statement
         self.draw_idx = 0

pymc3/backends/text.py

Lines changed: 4 additions & 2 deletions
@@ -36,12 +36,14 @@ class Text(base.BaseTrace):
     vars : list of variables
         Sampling values will be stored for these variables. If None,
         `model.unobserved_RVs` is used.
+    test_point : dict
+        use different test point that might be with changed variables shapes
     """
 
-    def __init__(self, name, model=None, vars=None):
+    def __init__(self, name, model=None, vars=None, test_point=None):
         if not os.path.exists(name):
             os.mkdir(name)
-        super(Text, self).__init__(name, model, vars)
+        super(Text, self).__init__(name, model, vars, test_point)
 
         self.flat_names = {v: ttab.create_flat_names(v, shape)
                            for v, shape in self.var_shapes.items()}

pymc3/blocking.py

Lines changed: 16 additions & 20 deletions
@@ -23,28 +23,26 @@ class ArrayOrdering(object):
 
     def __init__(self, vars):
         self.vmap = []
-        self._by_name = {}
-        size = 0
+        self.by_name = {}
+        self.size = 0
 
         for var in vars:
             name = var.name
             if name is None:
                 raise ValueError('Unnamed variable in ArrayOrdering.')
-            if name in self._by_name:
+            if name in self.by_name:
                 raise ValueError('Name of variable not unique: %s.' % name)
             if not hasattr(var, 'dshape') or not hasattr(var, 'dsize'):
                 raise ValueError('Shape of variable not known %s' % name)
 
-            slc = slice(size, size + var.dsize)
+            slc = slice(self.size, self.size + var.dsize)
             varmap = VarMap(name, slc, var.dshape, var.dtype)
             self.vmap.append(varmap)
-            self._by_name[name] = varmap
-            size += var.dsize
-
-        self.size = size
+            self.by_name[name] = varmap
+            self.size += var.dsize
 
     def __getitem__(self, key):
-        return self._by_name[key]
+        return self.by_name[key]
 
 
 class DictToArrayBijection(object):
@@ -122,24 +120,22 @@ class ListArrayOrdering(object):
     """
 
     def __init__(self, list_arrays, intype='numpy'):
+        if intype not in {'tensor', 'numpy'}:
+            raise ValueError("intype not in {'tensor', 'numpy'}")
         self.vmap = []
-        dim = 0
-
-        count = 0
+        self.intype = intype
+        self.size = 0
         for array in list_arrays:
-            if intype == 'tensor':
+            if self.intype == 'tensor':
                 name = array.name
                 array = array.tag.test_value
-            elif intype == 'numpy':
+            else:
                 name = 'numpy'
 
-            slc = slice(dim, dim + array.size)
+            slc = slice(self.size, self.size + array.size)
             self.vmap.append(DataMap(
-                count, slc, array.shape, array.dtype, name))
-            dim += array.size
-            count += 1
-
-        self.size = dim
+                len(self.vmap), slc, array.shape, array.dtype, name))
+            self.size += array.size
 
 
 class ListToArrayBijection(object):
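
Both orderings now keep a running self.size and index entries by len(self.vmap) instead of separate dim/count counters, and expose the name lookup publicly as by_name. The following self-contained toy version reproduces the same slice bookkeeping over plain NumPy arrays; the namedtuple only mimics pymc3.blocking.DataMap and is not the real class.

from collections import namedtuple

import numpy as np

# Illustrative stand-in for pymc3.blocking.DataMap: position in the list,
# slice into the flattened vector, original shape, dtype, and a name.
ToyMap = namedtuple('ToyMap', 'list_ind slc shp dtype name')

def build_ordering(arrays):
    """Record where each array lives inside the concatenated 1-D vector."""
    vmap, size = [], 0
    for array in arrays:
        slc = slice(size, size + array.size)
        vmap.append(ToyMap(len(vmap), slc, array.shape, array.dtype, 'numpy'))
        size += array.size
    return vmap, size

arrays = [np.zeros((2, 3)), np.arange(4.0)]
vmap, size = build_ordering(arrays)
flat = np.concatenate([a.ravel() for a in arrays])
assert size == flat.size

# Recover the second array from the flat vector via its recorded slice/shape.
dm = vmap[1]
restored = flat[dm.slc].reshape(dm.shp)
assert np.array_equal(restored, arrays[1])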

pymc3/distributions/dist_math.py

Lines changed: 3 additions & 3 deletions
@@ -91,7 +91,7 @@ def i0(x):
     return tt.switch(tt.lt(x, 5), 1. + x**2 / 4. + x**4 / 64. + x**6 / 2304. + x**8 / 147456.
                      + x**10 / 14745600. + x**12 / 2123366400.,
                      np.e**x / (2. * np.pi * x)**0.5 * (1. + 1. / (8. * x) + 9. / (128. * x**2) + 225. / (3072 * x**3)
-                     + 11025. / (98304. * x**4)))
+                                                        + 11025. / (98304. * x**4)))
 
 
 def i1(x):
@@ -108,14 +108,14 @@ def sd2rho(sd):
     """
     `sd -> rho` theano converter
     :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
-    return tt.log(tt.exp(sd) - 1.)
+    return tt.log(tt.exp(tt.abs_(sd)) - 1.)
 
 
 def rho2sd(rho):
     """
     `rho -> sd` theano converter
     :math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
-    return tt.log1p(tt.exp(rho))
+    return tt.nnet.softplus(rho)
 
 
 def log_normal(x, mean, **kwargs):
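
The two converter changes are numerically motivated: softplus(rho) = log(1 + exp(rho)) is the same mapping rho2sd computed before, but evaluated stably for large |rho|, and taking abs(sd) inside sd2rho keeps the logarithm's argument positive for any input sign. A quick NumPy check of the round trip (plain NumPy stand-ins for illustration, not the Theano functions themselves):

import numpy as np

def rho2sd(rho):
    # softplus: sd = log(1 + exp(rho)), computed without overflow
    return np.logaddexp(0.0, rho)

def sd2rho(sd):
    # inverse softplus, applied to |sd| as in the patched version
    return np.log(np.expm1(np.abs(sd)))

sd = np.array([0.1, 1.0, 5.0])
assert np.allclose(rho2sd(sd2rho(sd)), sd)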
