Skip to content

Commit 6691d7e

Browse files
authored
Merge pull request #3557 from fonnesbeck/alt_vi_backend
Added backend option for VI sample trace
2 parents 0ea44a4 + 96273ee commit 6691d7e

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
- Sequential Monte Carlo - Approximate Bayesian Computation step method is now available. The implementation is in an experimental stage and will be further improved.
99
- Added `Matern12` covariance function for Gaussian processes. This is the Matern kernel with nu=1/2.
1010
- Progressbar reports number of divergences in real time, when available [#3547](https://github.com/pymc-devs/pymc3/pull/3547).
11+
- Sampling from variational approximation now allows for alternative trace backends [#3550].
1112

1213
### Maintenance
1314
- Moved math operations out of `Rice`, `TruncatedNormal`, `Triangular` and `ZeroInflatedNegativeBinomial` `random` methods. Math operations on values returned by `draw_values` might not broadcast well, and all the `size` aware broadcasting is left to `generate_samples`. Fixes [#3481](https://github.com/pymc-devs/pymc3/issues/3481) and [#3508](https://github.com/pymc-devs/pymc3/issues/3508)

pymc3/tests/test_variational_inference.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,14 +151,22 @@ def three_var_approx(three_var_model, three_var_groups):
151151
def three_var_approx_single_group_mf(three_var_model):
152152
return MeanField(model=three_var_model)
153153

154-
155-
def test_sample_simple(three_var_approx):
156-
trace = three_var_approx.sample(500)
157-
assert set(trace.varnames) == {'one', 'one_log__', 'three', 'two'}
158-
assert len(trace) == 500
159-
assert trace[0]['one'].shape == (10, 2)
160-
assert trace[0]['two'].shape == (10, )
161-
assert trace[0]['three'].shape == (10, 1, 2)
154+
@pytest.fixture(
155+
params = [
156+
('ndarray', None),
157+
('text', 'test'),
158+
('sqlite', 'test.sqlite'),
159+
('hdf5', 'test.h5')
160+
]
161+
)
162+
def test_sample_simple(three_var_approx, request):
163+
backend, name = request.param
164+
trace = three_var_approx.sample(100, backend=backend, name=name)
165+
assert set(trace.varnames) == {'one', 'one_log__', 'three', 'two'}
166+
assert len(trace) == 100
167+
assert trace[0]['one'].shape == (10, 2)
168+
assert trace[0]['two'].shape == (10, )
169+
assert trace[0]['three'].shape == (10, 1, 2)
162170

163171

164172
@pytest.fixture

pymc3/variational/opvi.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
from ..blocking import (
4646
ArrayOrdering, DictToArrayBijection, VarMap
4747
)
48+
from ..backends import NDArray, Text, SQLite, HDF5
4849
from ..model import modelcontext
4950
from ..theanof import tt_rng, change_flags, identity
5051
from ..util import get_default_varnames
@@ -1569,7 +1570,8 @@ def inner(draws=100):
15691570

15701571
return inner
15711572

1572-
def sample(self, draws=500, include_transformed=True):
1573+
def sample(self, draws=500, include_transformed=True, backend='ndarray',
1574+
name=None):
15731575
"""Draw samples from variational posterior.
15741576
15751577
Parameters
@@ -1578,6 +1580,11 @@ def sample(self, draws=500, include_transformed=True):
15781580
Number of random samples.
15791581
include_transformed : `bool`
15801582
If True, transformed variables are also sampled. Default is False.
1583+
backend : `str`
1584+
Trace backend type to use. Valid entries include: 'ndarray' (default),
1585+
'text', 'sqlite', 'hdf5'.
1586+
name : `str`
1587+
Name for backend (required for non-NDArray backends). Default is None.
15811588
15821589
Returns
15831590
-------
@@ -1588,8 +1595,10 @@ def sample(self, draws=500, include_transformed=True):
15881595
include_transformed=include_transformed)
15891596
samples = self.sample_dict_fn(draws) # type: dict
15901597
points = ({name: records[i] for name, records in samples.items()} for i in range(draws))
1591-
trace = pm.sampling.NDArray(model=self.model, vars=vars_sampled, test_point={
1592-
name: records[0] for name, records in samples.items()
1598+
_backends = dict(ndarray=NDArray, text=Text, hdf5=HDF5, sqlite=SQLite)
1599+
1600+
trace = _backends[backend](name=name, model=self.model, vars=vars_sampled, test_point={
1601+
name: records[0] for name, records in samples.items()
15931602
})
15941603
try:
15951604
trace.setup(draws=draws, chain=0)

0 commit comments

Comments
 (0)