Skip to content

Commit edf49e1

Browse files
committed
Group potential and likelihood terms in the same cost in logp_dlogp_function
This is consistent with how `SMC` handles tempered posteriors. Also added a potential term to the tempered tests.
1 parent a92a414 commit edf49e1

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

pymc/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -631,9 +631,7 @@ def logp_dlogp_function(self, grad_vars=None, tempered=False, **kwargs):
631 631   raise ValueError(f"Can only compute the gradient of continuous types: {var}")
632 632
633 633   if tempered:
634     - # TODO: Should this differ from self.datalogpt,
635     - # where the potential terms are added to the observations?
636     - costs = [self.varlogpt + self.potentiallogpt, self.observedlogpt]
    634 + costs = [self.varlogpt, self.datalogpt]
637 635   else:
638 636   costs = [self.logpt()]
639 637
pymc/tests/test_model.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,7 @@ def test_tempered_logp_dlogp():
366 366   with pm.Model() as model:
367 367   pm.Normal("x")
368 368   pm.Normal("y", observed=1)
    369 + pm.Potential("z", at.constant(-1.0, dtype=aesara.config.floatX))
369 370
370 371   func = model.logp_dlogp_function()
371 372   func.set_extra_values({})
@@ -380,21 +381,23 @@ def test_tempered_logp_dlogp():
380 381   func_temp_nograd.set_extra_values({})
381 382
382 383   x = np.ones(1, dtype=func.dtype)
383     - assert func(x) == func_temp(x)
384     - assert func_nograd(x) == func(x)[0]
385     - assert func_temp_nograd(x) == func(x)[0]
    384 + npt.assert_allclose(func(x)[0], func_temp(x)[0])
    385 + npt.assert_allclose(func(x)[1], func_temp(x)[1])
    386 +
    387 + npt.assert_allclose(func_nograd(x), func(x)[0])
    388 + npt.assert_allclose(func_temp_nograd(x), func(x)[0])
386 389
387 390   func_temp.set_weights(np.array([0.0], dtype=func.dtype))
388 391   func_temp_nograd.set_weights(np.array([0.0], dtype=func.dtype))
389     - npt.assert_allclose(func(x)[0], 2 * func_temp(x)[0])
    392 + npt.assert_allclose(func(x)[0], 2 * func_temp(x)[0] - 1)
390 393   npt.assert_allclose(func(x)[1], func_temp(x)[1])
391 394
392 395   npt.assert_allclose(func_nograd(x), func(x)[0])
393 396   npt.assert_allclose(func_temp_nograd(x), func_temp(x)[0])
394 397
395 398   func_temp.set_weights(np.array([0.5], dtype=func.dtype))
396 399   func_temp_nograd.set_weights(np.array([0.5], dtype=func.dtype))
397     - npt.assert_allclose(func(x)[0], 4 / 3 * func_temp(x)[0])
    400 + npt.assert_allclose(func(x)[0], 4 / 3 * (func_temp(x)[0] - 1 / 4))
398 401   npt.assert_allclose(func(x)[1], func_temp(x)[1])
399 402
400 403   npt.assert_allclose(func_nograd(x), func(x)[0])

0 commit comments

Comments
 (0)