Skip to content

Optimization #1953

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
794 changes: 794 additions & 0 deletions docs/source/notebooks/simple_stochastic_optimization.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from .tests import test

from .data import *
from .optimization import *

import logging
_log = logging.getLogger('pymc3')
Expand Down
83 changes: 83 additions & 0 deletions pymc3/optimization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import tqdm
import numpy as np
import theano
from theano.configparser import change_flags
import pymc3 as pm

__all__ = [
'Optimizer'
]


class Optimizer(object):
    """
    Optimization with posterior replacements

    Wraps a compiled Theano step function that minimizes ``loss`` with
    respect to ``wrt`` after replacing random variables in the graph with
    draws from the approximate posterior.

    Parameters
    ----------
    approx : Approximation
        approximation whose posterior replaces latent variables in `loss`
    loss : scalar Theano expression
        objective to be minimized
    wrt : list of shared variables
        parameters to optimize
    more_replacements : dict, optional
        additional replacements to apply to the graph
    optimizer : callable, optional
        ``callable(loss, wrt) -> updates``; ``pm.adam`` by default
    """
    @change_flags(compute_test_value='off')
    def __init__(self, approx, loss, wrt, more_replacements=None, optimizer=None):
        if optimizer is None:
            optimizer = pm.adam
        self.optimizer = optimizer
        self.wrt = wrt
        self.approx = approx
        # Replace latent random variables in the loss with posterior samples
        self.loss = approx.apply_replacements(loss, more_replacements=more_replacements)
        updates = self.optimizer(self.loss, self.wrt)
        self.step_function = theano.function([], self.loss, updates=updates)
        # Running loss history, extended by fit()
        self.hist = np.asarray(())

    @change_flags(compute_test_value='off')
    def refresh(self, kwargs=None):
        """
        Recompile step_function and reset updates

        This can be needed sometimes when
        you use stateful optimizers like ADAM
        and want to reuse the function for a new problem

        Parameters
        ----------
        kwargs : dict, optional
            extra keyword arguments passed to ``theano.function``
        """
        # BUG FIX: the original unconditionally expanded **kwargs, so
        # calling refresh() with the default (None) raised TypeError.
        if kwargs is None:
            kwargs = {}
        updates = self.optimizer(self.loss, self.wrt)
        self.step_function = theano.function([], self.loss, updates=updates, **kwargs)
        self.hist = np.asarray(())

    def fit(self, n=5000, callbacks=()):
        """
        Perform optimization steps

        Interrupting with ``KeyboardInterrupt`` (or a callback raising
        ``StopIteration``) stops early; the scores collected so far are
        appended to ``self.hist`` either way.

        Parameters
        ----------
        n : int
            number of iterations
        callbacks : list[callable]
            list of callables with following signature
            f(Approximation, loss_history, i) -> None
        """
        progress = tqdm.trange(n)
        # Pre-allocate with NaN so an early stop leaves untouched slots marked
        scores = np.empty(n)
        scores[:] = np.nan
        i = 0
        try:
            for i in progress:
                scores[i] = self.step_function()
                for callback in callbacks:
                    callback(self.approx, scores[:i + 1], i)
                # Refresh the progress description ~1000 times per run
                if i % ((n + 1000) // 1000) == 0:
                    progress.set_description(
                        'E_q[Loss] = %.4f' % scores[max(0, i - n // 50):i + 1].mean()
                    )
        except (KeyboardInterrupt, StopIteration):
            # scores[i] may be NaN (interrupt mid-step), so keep only [:i]
            self.hist = np.concatenate((self.hist, scores[:i]))
        else:
            self.hist = np.concatenate((self.hist, scores))
        finally:
            progress.close()
8 changes: 4 additions & 4 deletions pymc3/variational/opvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,11 +593,11 @@ def apply_replacements(self, node, deterministic=False,

Parameters
----------
node : Theano Variables (or Theano expressions)
node : Theano Variable(s)
node or nodes for replacements
deterministic : bool
whether to use zeros as initial distribution
if True - zero initial point will produce constant latent variables
whether to use point with highest density as initial point for distribution
if True - constant initial point will produce constant latent variables
include : list
latent variables to be replaced
exclude : list
Expand All @@ -623,7 +623,7 @@ def sample_node(self, node, size=100,

Parameters
----------
node : Theano Variables (or Theano expressions)
node : Theano Variable(s)
size : scalar
number of samples
more_replacements : dict
Expand Down