Skip to content

Commit ee8a83b

Browse files
ferrinetwiecki
authored andcommitted
fix potentials in inference
1 parent 2696094 commit ee8a83b

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

pymc3/tests/test_variational_inference.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,7 @@ def test_init_from_noize(self):
324324

325325
_model = models.simple_model()[1]
326326
with _model:
327+
pm.Potential('pot', tt.ones((10, 10)))
327328
_advi = ADVI()
328329
_fullrank_advi = FullRankADVI()
329330
_svgd = SVGD()

pymc3/variational/opvi.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -270,14 +270,17 @@ def __init__(self, approx):
270270
input = property(lambda self: self.approx.flat_view.input)
271271

272272
def logp(self, z):
273-
factors = [tt.sum(var.logpt)for var in self.model.basic_RVs + self.model.potentials]
273+
factors = ([tt.sum(var.logpt)for var in self.model.basic_RVs] +
274+
[tt.sum(var) for var in self.model.potentials])
275+
274276
p = self.approx.to_flat_input(tt.add(*factors))
275277
p = theano.clone(p, {self.input: z})
276278
return p
277279

278280
def logp_norm(self, z):
279281
t = self.approx.normalizing_constant
280-
factors = [tt.sum(var.logpt) / t for var in self.model.basic_RVs + self.model.potentials]
282+
factors = ([tt.sum(var.logpt) / t for var in self.model.basic_RVs] +
283+
[tt.sum(var) / t for var in self.model.potentials])
281284
logpt = tt.add(*factors)
282285
p = self.approx.to_flat_input(logpt)
283286
p = theano.clone(p, {self.input: z})

0 commit comments

Comments
 (0)