@@ -40,7 +40,7 @@
 from pymc.model import modelcontext
 from pymc.model.core import Point
 from pymc.pytensorf import (
-    compile_pymc,
+    compile,
     find_rng_nodes,
     reseed_rngs,
 )
@@ -77,7 +77,9 @@
 
 logger = logging.getLogger(__name__)
 _warnings.filterwarnings(
-    "ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
+    "ignore",
+    category=UserWarning,
+    message="The same einsum subscript is used for a broadcastable and non-broadcastable dimension",
 )
 
 REGULARISATION_TERM = 1e-8
@@ -142,7 +144,7 @@ def get_logp_dlogp_of_ravel_inputs(
         [model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
         model.value_vars,
     )
-    logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
+    logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
    logp_dlogp_fn.trust_input = True

    return logp_dlogp_fn
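Note: `compile` here is the renamed `compile_pymc` helper from `pymc.pytensorf` (the FutureWarning filter removed above referenced that rename). A minimal usage sketch, assuming a PyMC version that exposes the new name; the variable and values are illustrative only, not part of the patch:

import pytensor.tensor as pt
from pymc.pytensorf import compile

x = pt.dscalar("x")
square_fn = compile([x], [x**2])  # same call pattern as the old compile_pymc
square_fn(3.0)                    # expected to return [array(9.)]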
@@ -502,7 +504,7 @@ def bfgs_sample_dense(
 
     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
 
-    mu = x - pt.batched_dot(H_inv, g)
+    mu = x - pt.einsum("ijk,ik->ij", H_inv, g)
 
     phi = pt.matrix_transpose(
         # (L, N, 1)
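Note on the change above: `pt.einsum("ijk,ik->ij", H_inv, g)` is the batched matrix-vector product previously written with `pt.batched_dot`. A quick NumPy sketch with illustrative shapes (L draws, N parameters), not part of the patch:

import numpy as np

rng = np.random.default_rng(0)
L, N = 4, 3
H_inv = rng.normal(size=(L, N, N))  # batch of inverse-Hessian matrices
g = rng.normal(size=(L, N))         # batch of gradients

per_draw = np.stack([H_inv[i] @ g[i] for i in range(L)])  # explicit batched dot
via_einsum = np.einsum("ijk,ik->ij", H_inv, g)            # the einsum form
assert np.allclose(per_draw, via_einsum)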
@@ -571,15 +573,12 @@ def bfgs_sample_sparse(
     logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
     logdet += pt.sum(pt.log(alpha), axis=-1)
 
+    # inverse Hessian
+    # (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
+    H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
+
     # NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
-    mu = x - (
-        # (L, N), (L, N) -> (L, N)
-        pt.batched_dot(alpha_diag, g)
-        # beta @ gamma @ beta.T
-        # (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
-        # (L, N, N), (L, N) -> (L, N)
-        + pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
-    )
+    mu = x - pt.einsum("ijk,ik->ij", H_inv, g)
 
     phi = pt.matrix_transpose(
         # (L, N, 1)
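Note on the sparse change above: assembling `H_inv = alpha_diag + beta @ gamma @ beta.T` once lets the two removed `batched_dot` terms collapse into a single einsum, since the batched product distributes over the sum. A NumPy sketch with illustrative shapes (L draws, N parameters, J history pairs), not part of the patch:

import numpy as np

rng = np.random.default_rng(1)
L, N, J = 4, 5, 2
alpha_diag = np.stack([np.diag(np.exp(rng.normal(size=N))) for _ in range(L)])  # (L, N, N)
beta = rng.normal(size=(L, N, 2 * J))
gamma = rng.normal(size=(L, 2 * J, 2 * J))
g = rng.normal(size=(L, N))

low_rank = beta @ gamma @ np.swapaxes(beta, -1, -2)  # (L, N, N)
old = np.einsum("ijk,ik->ij", alpha_diag, g) + np.einsum("ijk,ik->ij", low_rank, g)
H_inv = alpha_diag + low_rank
new = np.einsum("ijk,ik->ij", H_inv, g)
assert np.allclose(old, new)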
@@ -853,7 +852,7 @@ def make_pathfinder_body(
 
     # return psi, logP_psi, logQ_psi, elbo_argmax
 
-    pathfinder_body_fn = compile_pymc(
+    pathfinder_body_fn = compile(
         [x_full, g_full],
         [psi, logP_psi, logQ_psi, elbo_argmax],
         **compile_kwargs,