Refactoring to support Structured VI #2416


Merged: 97 commits, Sep 8, 2017

Commits
15f0beb
start refactoring
ferrine Jul 13, 2017
95cd81e
fix typo
ferrine Jul 14, 2017
5cec1df
some more refactoring
ferrine Jul 15, 2017
6b23dc8
Delete old approximation class
ferrine Jul 15, 2017
b584fc1
Yet another design
ferrine Jul 19, 2017
e4e8db4
design refactoring
ferrine Jul 20, 2017
9f6f50a
design refactoring
ferrine Jul 22, 2017
1f58bde
no-scan batched diag
ferrine Jul 23, 2017
0bddeae
even faster batched_diag
ferrine Jul 23, 2017
d6a2e55
even faster batched_diag
ferrine Jul 23, 2017
8df4790
flows refactoring
ferrine Jul 23, 2017
08cd389
reshape user params
ferrine Jul 23, 2017
43b7041
wrote simple init test
ferrine Jul 24, 2017
7909871
wrote simple test for sampling
ferrine Jul 25, 2017
f1a6dce
add Empirical to test
ferrine Jul 25, 2017
47e4d85
some flows pass aevb sample test
ferrine Jul 27, 2017
704f90f
fix planar/radial aevb flow
ferrine Jul 29, 2017
59ad34d
create test for mini logq in (aevb)groups
ferrine Jul 30, 2017
cc92840
some assertions for register_flow
ferrine Jul 31, 2017
7135bdc
fix MvNormal error
ferrine Aug 1, 2017
659fb67
move batched diag
ferrine Aug 1, 2017
1439c1d
implement more tests
ferrine Aug 1, 2017
688fa11
refactor api following the discussion
ferrine Aug 3, 2017
6b51c7c
design for single group shortcuts
ferrine Aug 4, 2017
a006cbd
use another default
ferrine Aug 4, 2017
24894ef
Add test_elbo test
ferrine Aug 4, 2017
13831dd
Inferences work, sample_node does not
ferrine Aug 4, 2017
28f01d3
fix typo
ferrine Aug 5, 2017
abb9001
changed scale flow parametrization, also a bit more stable sd2rho
ferrine Aug 5, 2017
3f2299f
implement some more tests, fix bugs
ferrine Aug 5, 2017
7af4cb4
rename file
ferrine Aug 5, 2017
b7c0b13
tests pass locally for me
ferrine Aug 5, 2017
5c3d95b
fix some typos
ferrine Aug 5, 2017
4fd8490
reduce n in profiling test
ferrine Aug 5, 2017
ca2c19d
move conftest import
ferrine Aug 5, 2017
347482c
fix typos in __all__
ferrine Aug 5, 2017
2d46990
fix raised exception
ferrine Aug 5, 2017
223e4cd
new line
ferrine Aug 5, 2017
045882e
implement block diagonal matrix
ferrine Aug 6, 2017
39cb01d
Implement batched approximations
ferrine Aug 6, 2017
3d73939
move not_raises to helpers
ferrine Aug 6, 2017
05d9768
fix some errors, add docstring
ferrine Aug 6, 2017
556abd3
fix typo in test
ferrine Aug 6, 2017
759715a
correct name of property
ferrine Aug 6, 2017
c65a145
better error message
ferrine Aug 6, 2017
7f607e6
add missing scaling to local logq
ferrine Aug 6, 2017
e9656cf
convert bool to int8
ferrine Aug 6, 2017
4d3d402
fix refactoring typos
ferrine Aug 6, 2017
eb4728a
FullRank fix, CC @junpenglao
ferrine Aug 7, 2017
2337102
good point @junpenglao, it was unintended
ferrine Aug 8, 2017
b0f6c0e
add a shortcut for block diagonal
ferrine Aug 8, 2017
4879b6b
fix py2 error
ferrine Aug 9, 2017
d556b5f
use another signature in block_diagonal
ferrine Aug 10, 2017
1615b5c
follow @taku-y suggestions
ferrine Aug 10, 2017
5652687
add better documentation
ferrine Aug 10, 2017
725e965
fix syntax error for py2
ferrine Aug 10, 2017
c8f9c77
another syntax error
ferrine Aug 10, 2017
7e5e3f0
added `refine` method to Inference, updated docs
ferrine Aug 11, 2017
534f9e9
fix typo, refine docs
ferrine Aug 11, 2017
869da0e
fixes to docs
ferrine Aug 11, 2017
e7a7a23
FullRank rowwise fix
ferrine Aug 14, 2017
ee76876
iterate with loss typo
ferrine Aug 14, 2017
b5c807d
better refine layout
ferrine Aug 14, 2017
1bd74fb
fix error in single group, fix linter
ferrine Aug 14, 2017
44557fa
update bayes nnet notebook
ferrine Aug 17, 2017
2be5fba
fix typo
ferrine Aug 17, 2017
b3f9de1
update GLM-hierarchical-advi-minibatch.ipynb
ferrine Aug 17, 2017
899e2ef
update gaussian-mixture-model-advi.ipynb
ferrine Aug 17, 2017
62b341f
better test
ferrine Aug 17, 2017
ab95093
update gaussian_process.ipynb
ferrine Aug 17, 2017
da14207
fix shape issue
ferrine Aug 25, 2017
d6c872f
move get transformed to util
ferrine Aug 26, 2017
62cd85c
pylint fix
ferrine Aug 26, 2017
c97a4de
change __hash__
ferrine Aug 26, 2017
85a1f0e
fix __str__ for not initialized groups
ferrine Aug 26, 2017
9168182
change replacements apply method
ferrine Aug 26, 2017
bc73513
update lda-advi-aevb.ipynb
ferrine Aug 27, 2017
b5ac106
use strict replacements for flat_input
ferrine Aug 27, 2017
53f8c4f
fix typo in error
ferrine Aug 27, 2017
d437aab
update empirical-approx-overview.ipynb
ferrine Aug 27, 2017
4f185a7
update normalizing_flows_overview.ipynb
ferrine Aug 27, 2017
85af2f6
update api_quickstart.ipynb
ferrine Aug 27, 2017
c2cf5d3
found and fixed a bug with aevb
ferrine Aug 28, 2017
2f226a3
update convolutional_vae_keras_advi.ipynb
ferrine Aug 28, 2017
51108d4
fix typo
ferrine Aug 29, 2017
298fa23
move to floatX
ferrine Aug 29, 2017
0a02391
update variational_api_quickstart.ipynb
ferrine Aug 29, 2017
d20b4cc
update normalizing_flows_overview.ipynb
ferrine Aug 31, 2017
b8d4816
update warning
ferrine Aug 31, 2017
f0f2e58
fix typos
ferrine Aug 31, 2017
2dfd78c
refine docs, add dev-docs
ferrine Sep 1, 2017
cc8fd34
Better docs for approximations
ferrine Sep 1, 2017
ff5d276
Add some notes, CC @junpenglao
ferrine Sep 1, 2017
f16c823
Refactor notes, CC @junpenglao
ferrine Sep 1, 2017
1453b85
autopep8 pymc3/variational --select=E101,E121 --in-place --recursive
ferrine Sep 1, 2017
6af086a
remove duplicate doc for asvgd
ferrine Sep 1, 2017
023876b
errmsg fix
ferrine Sep 4, 2017
64 changes: 35 additions & 29 deletions docs/source/notebooks/GLM-hierarchical-advi-minibatch.ipynb

Large diffs are not rendered by default.

334 changes: 130 additions & 204 deletions docs/source/notebooks/api_quickstart.ipynb


196 changes: 49 additions & 147 deletions docs/source/notebooks/bayesian_neural_network_advi.ipynb


40 changes: 29 additions & 11 deletions docs/source/notebooks/convolutional_vae_keras_advi.ipynb
@@ -31,7 +31,9 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"%autosave 0\n",
@@ -121,7 +123,9 @@
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"from keras.models import Sequential\n",
@@ -172,7 +176,9 @@
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def cnn_enc(xs, latent_dim, nb_filters=64, nb_conv=3, intermediate_dim=128):\n",
@@ -216,7 +222,9 @@
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Encoder:\n",
@@ -292,7 +300,9 @@
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"def cnn_dec(zs, nb_filters=64, nb_conv=3, output_shape=(1, 28, 28)):\n",
@@ -331,7 +341,9 @@
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"class Decoder:\n",
Expand Down Expand Up @@ -385,7 +397,9 @@
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"# Constants\n",
Expand Down Expand Up @@ -454,7 +468,9 @@
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"with pm.Model() as model:\n",
@@ -478,10 +494,12 @@
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"local_RVs = OrderedDict({zs: (enc.means, enc.rhos)})"
"local_RVs = OrderedDict({zs: dict(mu=enc.means, rho=enc.rhos)})"
]
},
{
@@ -671,7 +689,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
"version": "3.6.0b4"
},
"latex_envs": {
"bibliofile": "biblio.bib",
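The one-line notebook change above tracks this PR's new AEVB API: a local random variable now maps to a dict of named parameters rather than a positional `(means, rhos)` tuple. A trivial plain-Python sketch of the shape of that mapping — `Enc`, `zs`, and the string stand-ins are invented; in the notebook these are Keras/Theano tensors:

```python
from collections import OrderedDict

# Stand-ins for the encoder outputs; in the real notebook these are
# symbolic tensors produced by the Keras encoder network.
class Enc:
    means = 'means-tensor'
    rhos = 'rhos-tensor'

enc = Enc()
zs = 'zs-random-variable'

# Old API (positional):  OrderedDict({zs: (enc.means, enc.rhos)})
# New API (named parameters):
local_RVs = OrderedDict({zs: dict(mu=enc.means, rho=enc.rhos)})
print(local_RVs[zs]['mu'])  # means-tensor
```

Naming the parameters makes the mapping self-documenting and lets the approximation accept parameter sets other than exactly `(mu, rho)`.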
109 changes: 39 additions & 70 deletions docs/source/notebooks/empirical-approx-overview.ipynb


101 changes: 49 additions & 52 deletions docs/source/notebooks/gaussian-mixture-model-advi.ipynb


69 changes: 45 additions & 24 deletions docs/source/notebooks/gaussian_process.ipynb


198 changes: 116 additions & 82 deletions docs/source/notebooks/lda-advi-aevb.ipynb


310 changes: 131 additions & 179 deletions docs/source/notebooks/normalizing_flows_overview.ipynb


357 changes: 168 additions & 189 deletions docs/source/notebooks/variational_api_quickstart.ipynb


17 changes: 13 additions & 4 deletions pymc3/backends/base.py
@@ -4,10 +4,12 @@
creating custom backends).
"""
import numpy as np
from ..model import modelcontext
import warnings
import theano.tensor as tt

from ..model import modelcontext


class BackendError(Exception):
pass

@@ -24,11 +26,13 @@ class BaseTrace(object):
vars : list of variables
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
test_point : dict
use a different test point, e.g. when variable shapes have changed
[Review comment — Member] do you mean a different test point can be used in case of variable shape changes?

[Reply — Member Author] Yes, that works for me

"""

supports_sampler_stats = False

def __init__(self, name, model=None, vars=None):
def __init__(self, name, model=None, vars=None, test_point=None):
self.name = name

model = modelcontext(model)
@@ -41,7 +45,13 @@ def __init__(self, name, model=None, vars=None):

# Get variable shapes. Most backends will need this
# information.
var_values = list(zip(self.varnames, self.fn(model.test_point)))
if test_point is None:
test_point = model.test_point
else:
test_point_ = model.test_point.copy()
test_point_.update(test_point)
test_point = test_point_
var_values = list(zip(self.varnames, self.fn(test_point)))
self.var_shapes = {var: value.shape
for var, value in var_values}
self.var_dtypes = {var: value.dtype
@@ -72,7 +82,6 @@ def _set_sampler_vars(self, sampler_vars):

self.sampler_vars = sampler_vars


def setup(self, draws, chain, sampler_vars=None):
"""Perform chain-specific setup.

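The hunk above adds an optional `test_point` argument to `BaseTrace.__init__`: when supplied, it is merged over the model's default test point before variable shapes are inferred. The merge logic can be sketched in isolation — a minimal stand-alone function mirroring the diff, not the actual PyMC3 code path:

```python
def resolve_test_point(model_test_point, test_point=None):
    """Merge a user-supplied test_point over the model's default one.

    Mirrors the logic added to BaseTrace.__init__: the default dict is
    copied first, so the model's own test point is never mutated.
    """
    if test_point is None:
        return model_test_point
    merged = model_test_point.copy()
    merged.update(test_point)
    return merged

default = {'mu': 0.0, 'x': [0.0, 0.0]}
merged = resolve_test_point(default, {'x': [1.0, 1.0, 1.0]})
print(merged['x'])   # [1.0, 1.0, 1.0] — the override wins
print(default['x'])  # [0.0, 0.0] — the default is untouched
```

Because the backends size their storage from `self.fn(test_point)`, overriding an entry (e.g. with a different shape for a minibatch variable) changes the recorded `var_shapes` without touching the model itself.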
24 changes: 13 additions & 11 deletions pymc3/backends/hdf5.py
@@ -18,24 +18,26 @@ def activator(instance):
class HDF5(base.BaseTrace):
"""HDF5 trace object

Parameters
----------
name : str
Name of backend. This has no meaning for the HDF5 backend.
model : Model
If None, the model is taken from the `with` context.
vars : list of variables
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
Parameters
----------
name : str
Name of backend. This has no meaning for the HDF5 backend.
model : Model
If None, the model is taken from the `with` context.
vars : list of variables
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
test_point : dict
use a different test point, e.g. when variable shapes have changed
"""

supports_sampler_stats = True

def __init__(self, name=None, model=None, vars=None):
def __init__(self, name=None, model=None, vars=None, test_point=None):
self.hdf5_file = None
self.draw_idx = 0
self.draws = None
super(HDF5, self).__init__(name, model, vars)
super(HDF5, self).__init__(name, model, vars, test_point)

def _get_sampler_stats(self, varname, sampler_idx, burn, thin):
with self.activate_file:
4 changes: 2 additions & 2 deletions pymc3/backends/ndarray.py
@@ -22,8 +22,8 @@ class NDArray(base.BaseTrace):

supports_sampler_stats = True

def __init__(self, name=None, model=None, vars=None):
super(NDArray, self).__init__(name, model, vars)
def __init__(self, name=None, model=None, vars=None, test_point=None):
super(NDArray, self).__init__(name, model, vars, test_point)
self.draw_idx = 0
self.draws = None
self.samples = {}
6 changes: 4 additions & 2 deletions pymc3/backends/sqlite.py
@@ -70,10 +70,12 @@ class SQLite(base.BaseTrace):
vars : list of variables
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
test_point : dict
use a different test point, e.g. when variable shapes have changed
"""

def __init__(self, name, model=None, vars=None):
super(SQLite, self).__init__(name, model, vars)
def __init__(self, name, model=None, vars=None, test_point=None):
super(SQLite, self).__init__(name, model, vars, test_point)
self._var_cols = {}
self.var_inserts = {} # varname -> insert statement
self.draw_idx = 0
6 changes: 4 additions & 2 deletions pymc3/backends/text.py
@@ -36,12 +36,14 @@ class Text(base.BaseTrace):
vars : list of variables
Sampling values will be stored for these variables. If None,
`model.unobserved_RVs` is used.
test_point : dict
use a different test point, e.g. when variable shapes have changed
"""

def __init__(self, name, model=None, vars=None):
def __init__(self, name, model=None, vars=None, test_point=None):
if not os.path.exists(name):
os.mkdir(name)
super(Text, self).__init__(name, model, vars)
super(Text, self).__init__(name, model, vars, test_point)

self.flat_names = {v: ttab.create_flat_names(v, shape)
for v, shape in self.var_shapes.items()}
36 changes: 16 additions & 20 deletions pymc3/blocking.py
@@ -23,28 +23,26 @@ class ArrayOrdering(object):

def __init__(self, vars):
self.vmap = []
self._by_name = {}
size = 0
self.by_name = {}
self.size = 0

for var in vars:
name = var.name
if name is None:
raise ValueError('Unnamed variable in ArrayOrdering.')
if name in self._by_name:
if name in self.by_name:
raise ValueError('Name of variable not unique: %s.' % name)
if not hasattr(var, 'dshape') or not hasattr(var, 'dsize'):
raise ValueError('Shape of variable not known %s' % name)

slc = slice(size, size + var.dsize)
slc = slice(self.size, self.size + var.dsize)
varmap = VarMap(name, slc, var.dshape, var.dtype)
self.vmap.append(varmap)
self._by_name[name] = varmap
size += var.dsize

self.size = size
self.by_name[name] = varmap
self.size += var.dsize

def __getitem__(self, key):
return self._by_name[key]
return self.by_name[key]


class DictToArrayBijection(object):
@@ -122,24 +120,22 @@ class ListArrayOrdering(object):
"""

def __init__(self, list_arrays, intype='numpy'):
if intype not in {'tensor', 'numpy'}:
raise ValueError("intype not in {'tensor', 'numpy'}")
self.vmap = []
dim = 0

count = 0
self.intype = intype
self.size = 0
for array in list_arrays:
if intype == 'tensor':
if self.intype == 'tensor':
name = array.name
array = array.tag.test_value
elif intype == 'numpy':
else:
name = 'numpy'

slc = slice(dim, dim + array.size)
slc = slice(self.size, self.size + array.size)
self.vmap.append(DataMap(
count, slc, array.shape, array.dtype, name))
dim += array.size
count += 1

self.size = dim
len(self.vmap), slc, array.shape, array.dtype, name))
self.size += array.size


class ListToArrayBijection(object):
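The `ArrayOrdering` refactor above replaces the private `_by_name` dict and the local `size` counter with public `by_name`/`size` attributes that accumulate inside the loop. A standalone sketch of that slice bookkeeping — names, shapes, and the simplified `VarMap` are invented here; the real PyMC3 class works on Theano variables with `dshape`/`dsize` attributes:

```python
from collections import namedtuple

VarMap = namedtuple('VarMap', 'var slc shp dtyp')

class Ordering:
    """Minimal sketch of the refactored ArrayOrdering bookkeeping:
    the running offset lives directly in self.size, and each variable
    gets a slice into the flattened parameter vector."""

    def __init__(self, shapes):  # shapes: dict of name -> shape tuple
        self.vmap = []
        self.by_name = {}
        self.size = 0
        for name, shape in shapes.items():
            dsize = 1
            for d in shape:
                dsize *= d
            slc = slice(self.size, self.size + dsize)
            vm = VarMap(name, slc, shape, 'float64')
            self.vmap.append(vm)
            self.by_name[name] = vm
            self.size += dsize

    def __getitem__(self, key):
        return self.by_name[key]

ordering = Ordering({'mu': (3,), 'L': (2, 2)})
print(ordering.size)      # 7
print(ordering['L'].slc)  # slice(3, 7, None)
```

Exposing `by_name` and `size` as public attributes lets the new structured-VI code look up a variable's slice in the packed vector directly, instead of going through `__getitem__` alone.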
6 changes: 3 additions & 3 deletions pymc3/distributions/dist_math.py
@@ -91,7 +91,7 @@ def i0(x):
return tt.switch(tt.lt(x, 5), 1. + x**2 / 4. + x**4 / 64. + x**6 / 2304. + x**8 / 147456.
+ x**10 / 14745600. + x**12 / 2123366400.,
np.e**x / (2. * np.pi * x)**0.5 * (1. + 1. / (8. * x) + 9. / (128. * x**2) + 225. / (3072 * x**3)
+ 11025. / (98304. * x**4)))
+ 11025. / (98304. * x**4)))


def i1(x):
@@ -108,14 +108,14 @@ def sd2rho(sd):
"""
`sd -> rho` theano converter
:math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
return tt.log(tt.exp(sd) - 1.)
return tt.log(tt.exp(tt.abs_(sd)) - 1.)


def rho2sd(rho):
"""
`rho -> sd` theano converter
:math:`mu + sd*e = mu + log(1+exp(rho))*e`"""
return tt.log1p(tt.exp(rho))
return tt.nnet.softplus(rho)


def log_normal(x, mean, **kwargs):
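The converter changes above make `sd2rho` take `tt.abs_(sd)` before the inverse softplus (guarding against invalid negative inputs, per the "a bit more stable sd2rho" commit) and rewrite `rho2sd` as `tt.nnet.softplus`, which computes the same `log1p(exp(rho))` but with an overflow-safe implementation for large `rho`. A NumPy sketch of the round trip — the Theano code in the diff is the real implementation:

```python
import numpy as np

def sd2rho(sd):
    # inverse softplus; abs() guards against (invalid) negative sd,
    # mirroring tt.log(tt.exp(tt.abs_(sd)) - 1.)
    return np.log(np.exp(np.abs(sd)) - 1.0)

def rho2sd(rho):
    # softplus; numerically what tt.nnet.softplus(rho) computes
    return np.log1p(np.exp(rho))

sd = np.array([0.1, 1.0, 5.0])
roundtrip = rho2sd(sd2rho(sd))
print(np.allclose(roundtrip, sd))  # True
```

The softplus parametrization keeps the standard deviation strictly positive while letting `rho` range over all of the real line, which is what gradient-based VI needs.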
Loading