Skip to content

Commit ab39ab1

Browse files
committed
Fix failed test from PR pymc-devs#443 and replaced batched_dot for einsum
* Modified the calculation of in and to use instead of the deprecated function. * Added calculations for in to improve readability. * Renamed to in pathfinder functions to reflect API changes. * Updated warning filters to ignore UserWarnings related to einsum subscripts.
1 parent c0ddb8e commit ab39ab1

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

pymc_extras/inference/pathfinder/pathfinder.py

+12-13
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from pymc.model import modelcontext
4141
from pymc.model.core import Point
4242
from pymc.pytensorf import (
43-
compile_pymc,
43+
compile,
4444
find_rng_nodes,
4545
reseed_rngs,
4646
)
@@ -77,7 +77,9 @@
7777

7878
logger = logging.getLogger(__name__)
7979
_warnings.filterwarnings(
80-
"ignore", category=FutureWarning, message="compile_pymc was renamed to compile"
80+
"ignore",
81+
category=UserWarning,
82+
message="The same einsum subscript is used for a broadcastable and non-broadcastable dimension",
8183
)
8284

8385
REGULARISATION_TERM = 1e-8
@@ -142,7 +144,7 @@ def get_logp_dlogp_of_ravel_inputs(
142144
[model.logp(jacobian=jacobian), model.dlogp(jacobian=jacobian)],
143145
model.value_vars,
144146
)
145-
logp_dlogp_fn = compile_pymc([inputs], (logP, dlogP), **compile_kwargs)
147+
logp_dlogp_fn = compile([inputs], (logP, dlogP), **compile_kwargs)
146148
logp_dlogp_fn.trust_input = True
147149

148150
return logp_dlogp_fn
@@ -502,7 +504,7 @@ def bfgs_sample_dense(
502504

503505
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
504506

505-
mu = x - pt.batched_dot(H_inv, g)
507+
mu = x - pt.einsum("ijk,ik->ij", H_inv, g)
506508

507509
phi = pt.matrix_transpose(
508510
# (L, N, 1)
@@ -571,15 +573,12 @@ def bfgs_sample_sparse(
571573
logdet = 2.0 * pt.sum(pt.log(pt.abs(pt.diagonal(Lchol, axis1=-2, axis2=-1))), axis=-1)
572574
logdet += pt.sum(pt.log(alpha), axis=-1)
573575

576+
# inverse Hessian
577+
# (L, N, N) + (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
578+
H_inv = alpha_diag + (beta @ gamma @ pt.matrix_transpose(beta))
579+
574580
# NOTE: changed the sign from "x + " to "x -" of the expression to match Stan which differs from Zhang et al., (2022). same for dense version.
575-
mu = x - (
576-
# (L, N), (L, N) -> (L, N)
577-
pt.batched_dot(alpha_diag, g)
578-
# beta @ gamma @ beta.T
579-
# (L, N, 2J), (L, 2J, 2J), (L, 2J, N) -> (L, N, N)
580-
# (L, N, N), (L, N) -> (L, N)
581-
+ pt.batched_dot((beta @ gamma @ pt.matrix_transpose(beta)), g)
582-
)
581+
mu = x - pt.einsum("ijk,ik->ij", H_inv, g)
583582

584583
phi = pt.matrix_transpose(
585584
# (L, N, 1)
@@ -853,7 +852,7 @@ def make_pathfinder_body(
853852

854853
# return psi, logP_psi, logQ_psi, elbo_argmax
855854

856-
pathfinder_body_fn = compile_pymc(
855+
pathfinder_body_fn = compile(
857856
[x_full, g_full],
858857
[psi, logP_psi, logQ_psi, elbo_argmax],
859858
**compile_kwargs,

tests/test_pathfinder.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818
import pymc as pm
1919
import pytest
2020

21-
pytestmark = pytest.mark.filterwarnings("ignore:compile_pymc was renamed to compile:FutureWarning")
21+
pytestmark = pytest.mark.filterwarnings(
22+
"ignore:The same einsum subscript is used for a broadcastable and non-broadcastable dimension:UserWarning"
23+
)
2224

2325
import pymc_extras as pmx
2426

0 commit comments

Comments
 (0)