@@ -366,6 +366,7 @@ def test_tempered_logp_dlogp():
366
366
with pm .Model () as model :
367
367
pm .Normal ("x" )
368
368
pm .Normal ("y" , observed = 1 )
369
+ pm .Potential ("z" , at .constant (- 1.0 , dtype = aesara .config .floatX ))
369
370
370
371
func = model .logp_dlogp_function ()
371
372
func .set_extra_values ({})
@@ -380,21 +381,23 @@ def test_tempered_logp_dlogp():
380
381
func_temp_nograd .set_extra_values ({})
381
382
382
383
x = np .ones (1 , dtype = func .dtype )
383
- assert func (x ) == func_temp (x )
384
- assert func_nograd (x ) == func (x )[0 ]
385
- assert func_temp_nograd (x ) == func (x )[0 ]
384
+ npt .assert_allclose (func (x )[0 ], func_temp (x )[0 ])
385
+ npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
386
+
387
+ npt .assert_allclose (func_nograd (x ), func (x )[0 ])
388
+ npt .assert_allclose (func_temp_nograd (x ), func (x )[0 ])
386
389
387
390
func_temp .set_weights (np .array ([0.0 ], dtype = func .dtype ))
388
391
func_temp_nograd .set_weights (np .array ([0.0 ], dtype = func .dtype ))
389
- npt .assert_allclose (func (x )[0 ], 2 * func_temp (x )[0 ])
392
+ npt .assert_allclose (func (x )[0 ], 2 * func_temp (x )[0 ] - 1 )
390
393
npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
391
394
392
395
npt .assert_allclose (func_nograd (x ), func (x )[0 ])
393
396
npt .assert_allclose (func_temp_nograd (x ), func_temp (x )[0 ])
394
397
395
398
func_temp .set_weights (np .array ([0.5 ], dtype = func .dtype ))
396
399
func_temp_nograd .set_weights (np .array ([0.5 ], dtype = func .dtype ))
397
- npt .assert_allclose (func (x )[0 ], 4 / 3 * func_temp (x )[0 ])
400
+ npt .assert_allclose (func (x )[0 ], 4 / 3 * ( func_temp (x )[0 ] - 1 / 4 ) )
398
401
npt .assert_allclose (func (x )[1 ], func_temp (x )[1 ])
399
402
400
403
npt .assert_allclose (func_nograd (x ), func (x )[0 ])
0 commit comments