Skip to content

Commit ab704bf

Browse files
committed
add temperature
1 parent 3303509 commit ab704bf

File tree

5 files changed

+48
-17
lines changed

5 files changed

+48
-17
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -152,7 +152,6 @@ def test_optimizer_with_full_data(self):
152152
mu_ = Normal('mu', mu=mu0, sd=sd0, testval=0)
153153
Normal('x', mu=mu_, sd=sd, observed=data)
154154
inf = self.inference(start={})
155-
inf.fit(10)
156155
approx = inf.fit(self.NITER,
157156
obj_optimizer=self.optimizer,
158157
callbacks=self.conv_cb,)
@@ -295,11 +294,9 @@ class TestSVGD(TestApproximates.Base):
295294

296295

297296
class TestASVGD(TestApproximates.Base):
298-
NITER = 15000
299-
inference = ASVGD
297+
NITER = 5000
298+
inference = functools.partial(ASVGD, temperature=1.5)
300299
test_aevb = _test_aevb
301-
optimizer = pm.adagrad_window(learning_rate=0.002)
302-
conv_cb = []
303300

304301

305302
class TestEmpirical(SeededTest):

pymc3/variational/inference.py

Lines changed: 22 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -41,11 +41,15 @@ class Inference(object):
4141
See (AEVB; Kingma and Welling, 2014) for details
4242
model : Model
4343
PyMC3 Model
44+
op_kwargs : dict
45+
kwargs passed to :class:`Operator`
4446
kwargs : kwargs
4547
additional kwargs for :class:`Approximation`
4648
"""
4749

48-
def __init__(self, op, approx, tf, local_rv=None, model=None, **kwargs):
50+
def __init__(self, op, approx, tf, local_rv=None, model=None, op_kwargs=None, **kwargs):
51+
if op_kwargs is None:
52+
op_kwargs = dict()
4953
self.hist = np.asarray(())
5054
if isinstance(approx, type) and issubclass(approx, Approximation):
5155
approx = approx(
@@ -56,7 +60,7 @@ def __init__(self, op, approx, tf, local_rv=None, model=None, **kwargs):
5660
else: # pragma: no cover
5761
raise TypeError(
5862
'approx should be Approximation instance or Approximation subclass')
59-
self.objective = op(approx)(tf)
63+
self.objective = op(approx, **op_kwargs)(tf)
6064

6165
approx = property(lambda self: self.objective.approx)
6266

@@ -535,6 +539,8 @@ class SVGD(Inference):
535539
PyMC3 model for inference
536540
kernel : `callable`
537541
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
542+
temperature : float
543+
parameter controlling exploration; a higher temperature gives a broader posterior estimate
538544
scale_cost_to_minibatch : bool, default False
539545
Scale cost to minibatch instead of full dataset
540546
start : `dict`
@@ -552,10 +558,14 @@ class SVGD(Inference):
552558
- Qiang Liu, Dilin Wang (2016)
553559
Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm
554560
arXiv:1608.04471
561+
562+
- Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017)
563+
Stein Variational Policy Gradient
564+
arXiv:1704.02399
555565
"""
556566

557567
def __init__(self, n_particles=100, jitter=.01, model=None, kernel=test_functions.rbf,
558-
scale_cost_to_minibatch=False, start=None, histogram=None,
568+
temperature=1, scale_cost_to_minibatch=False, start=None, histogram=None,
559569
random_seed=None, local_rv=None):
560570
if histogram is None:
561571
histogram = Empirical.from_noise(
@@ -597,6 +607,8 @@ class ASVGD(Inference):
597607
See (AEVB; Kingma and Welling, 2014) for details
598608
kernel : `callable`
599609
kernel function for KSD :math:`f(histogram) -> (k(x,.), \nabla_x k(x,.))`
610+
temperature : float
611+
parameter controlling exploration; a higher temperature gives a broader posterior estimate
600612
model : :class:`Model`
601613
kwargs : kwargs for :class:`Approximation`
602614
@@ -608,17 +620,22 @@ class ASVGD(Inference):
608620
609621
- Dilin Wang, Qiang Liu (2016)
610622
Learning to Draw Samples: With Application to Amortized MLE for Generative Adversarial Learning
611-
https://arxiv.org/abs/1611.01722
623+
arXiv:1611.01722
624+
625+
- Yang Liu, Prajit Ramachandran, Qiang Liu, Jian Peng (2017)
626+
Stein Variational Policy Gradient
627+
arXiv:1704.02399
612628
"""
613629

614630
def __init__(self, approx=FullRank, local_rv=None,
615-
kernel=test_functions.rbf, model=None, **kwargs):
631+
kernel=test_functions.rbf, temperature=1, model=None, **kwargs):
616632
super(ASVGD, self).__init__(
617633
op=AKSD,
618634
approx=approx,
619635
local_rv=local_rv,
620636
tf=kernel,
621637
model=model,
638+
op_kwargs=dict(temperature=temperature),
622639
**kwargs
623640
)
624641

pymc3/variational/operators.py

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from theano import theano, tensor as tt
23
from pymc3.variational.opvi import Operator, ObjectiveFunction, _warn_not_used
34
from pymc3.variational.stein import Stein
@@ -95,15 +96,29 @@ class KSD(Operator):
9596
SUPPORT_AEVB = False
9697
OBJECTIVE = KSDObjective
9798

98-
def __init__(self, approx):
99+
def __init__(self, approx, temperature=1):
99100
Operator.__init__(self, approx)
101+
self.temperature = temperature
100102
self.input_matrix = tt.matrix('KSD input matrix')
101103

102104
def apply(self, f):
103105
# f: kernel function for KSD f(histogram) -> (k(x,.), \nabla_x k(x,.))
104-
stein = Stein(self.approx, f, self.input_matrix)
106+
stein = Stein(
107+
approx=self.approx,
108+
kernel=f,
109+
input_matrix=self.input_matrix,
110+
temperature=self.temperature)
105111
return pm.floatX(-1) * stein.grad
106112

107113

108114
class AKSD(KSD):
115+
def __init__(self, approx, temperature=1):
116+
warnings.warn('You are using experimental inference Operator. '
117+
'It requires careful choice of temperature, default is 1. '
118+
'Default temperature works well for low dimensional problems and '
119+
'for significant `n_obj_mc`. Temperature > 1 gives more exploration '
120+
'power to algorithm, < 1 leads to undesirable results. Please take '
121+
'it in account when looking at inference result. Posterior variance '
122+
'is often **underestimated** when using temperature = 1.', stacklevel=2)
123+
super(AKSD, self).__init__(approx, temperature)
109124
SUPPORT_AEVB = True

pymc3/variational/stein.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,16 @@
11
from theano import theano, tensor as tt
22
from pymc3.variational.test_functions import rbf
3-
from pymc3.theanof import memoize
3+
from pymc3.theanof import memoize, floatX
44

55
__all__ = [
66
'Stein'
77
]
88

99

1010
class Stein(object):
11-
def __init__(self, approx, kernel=rbf, input_matrix=None):
11+
def __init__(self, approx, kernel=rbf, input_matrix=None, temperature=1):
1212
self.approx = approx
13+
self.temperature = floatX(temperature)
1314
self._kernel_f = kernel
1415
if input_matrix is None:
1516
input_matrix = tt.matrix('stein_input_matrix')
@@ -22,8 +23,9 @@ def grad(self):
2223
t = self.approx.normalizing_constant
2324
Kxy, dxkxy = self.Kxy, self.dxkxy
2425
dlogpdx = self.dlogp # Normalized
25-
n = self.input_matrix.shape[0].astype('float32')
26-
svgd_grad = (tt.dot(Kxy, dlogpdx) + dxkxy/t) / n
26+
n = floatX(self.input_matrix.shape[0])
27+
temperature = self.temperature
28+
svgd_grad = (tt.dot(Kxy, dlogpdx)/temperature + dxkxy/t) / n
2729
return svgd_grad
2830

2931
@property

pymc3/variational/test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from theano import tensor as tt
1+
from theano import tensor as tt, theano
22
from .opvi import TestFunction
33
from pymc3.theanof import floatX
44

0 commit comments

Comments
 (0)