
Commit fa49b43

Fix bug in fit_MAP when shared variables are used in graph
1 parent 7fb87b4 · commit fa49b43

File tree: 2 files changed (+31, −1 lines)

pymc_extras/inference/find_map.py
tests/test_find_map.py

pymc_extras/inference/find_map.py

Lines changed: 8 additions & 1 deletion
@@ -15,6 +15,7 @@
 from pymc.initial_point import make_initial_point_fn
 from pymc.model.transform.optimization import freeze_dims_and_data
 from pymc.pytensorf import join_nonshared_inputs
+from pymc.sampling.jax import _replace_shared_variables
 from pymc.util import get_default_varnames
 from pytensor.compile import Function
 from pytensor.compile.mode import Mode
@@ -146,7 +147,7 @@ def _compile_grad_and_hess_to_jax(
     orig_loss_fn = f_loss.vm.jit_fn
 
     @jax.jit
-    def loss_fn_jax_grad(x, *shared):
+    def loss_fn_jax_grad(x):
         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     f_loss_and_grad = loss_fn_jax_grad
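
For context on the wrapper being edited here: find_MAP's JAX path wraps the compiled loss with jax.value_and_grad so that scipy receives the objective value and its gradient from a single call. Below is a minimal, self-contained sketch of that pattern, not the library's code; toy_loss and scipy_fun are hypothetical stand-ins for the compiled model loss and the callable handed to the optimizer.

import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize


def toy_loss(x):
    # Hypothetical stand-in for the compiled model loss (orig_loss_fn); minimum at x = 3.
    return jnp.sum((x - 3.0) ** 2)


# Wrap the loss so one call yields both the value and the gradient, then JIT it.
toy_loss_and_grad = jax.jit(jax.value_and_grad(toy_loss))


def scipy_fun(x):
    value, grad = toy_loss_and_grad(x)
    # Convert JAX arrays to plain Python/NumPy objects for scipy.
    return float(value), np.asarray(grad)


# jac=True tells scipy that scipy_fun returns (value, gradient) in one call.
result = minimize(scipy_fun, x0=np.zeros(2), jac=True, method="L-BFGS-B")
print(result.x)  # approximately [3., 3.]
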
@@ -301,6 +302,12 @@ def scipy_optimize_funcs_from_loss(
         point=initial_point_dict, outputs=[loss], inputs=inputs
     )
 
+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        [loss] = _replace_shared_variables([loss])
+
     compute_grad = use_grad and not use_jax_gradients
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
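
As the new comment explains, pytensor's function wrapper feeds shared variables automatically, but the raw jitted JAX callable used for gradients does not, so the loss graph is rewritten to have no shared inputs before it reaches JAX. A minimal sketch of what that rewrite does, assuming a toy graph (the data/mu/loss names are illustrative, not from the library):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.sampling.jax import _replace_shared_variables
from pytensor.compile import SharedVariable
from pytensor.graph.basic import graph_inputs

# Toy "loss" graph that closes over a shared variable, mimicking how observed
# data stored in a shared container ends up in the model's logp graph.
data = pytensor.shared(np.array([1.0, 2.0, 3.0]), name="data")
mu = pt.scalar("mu")
loss = ((data - mu) ** 2).sum()

# Before the rewrite, the shared variable is an implicit input of the graph.
assert any(isinstance(v, SharedVariable) for v in graph_inputs([loss]))

# After the rewrite, the shared value is baked in as a constant, so the graph
# depends only on the explicit input `mu` and no longer needs the pytensor
# function wrapper that would normally supply shared variables.
[loss_no_shared] = _replace_shared_variables([loss])
assert not any(isinstance(v, SharedVariable) for v in graph_inputs([loss_no_shared]))

f = pytensor.function([mu], loss_no_shared)
print(f(0.0))  # sum((data - 0)^2) = 14.0
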

tests/test_find_map.py

Lines changed: 23 additions & 0 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 import pymc as pm
+import pytensor
 import pytensor.tensor as pt
 import pytest

@@ -101,3 +102,25 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend):
 
     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+def test_JAX_map_shared_variables():
+    with pm.Model() as m:
+        data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data")
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data)
+
+        optimized_point = find_MAP(
+            method="L-BFGS-B",
+            use_grad=True,
+            use_hess=False,
+            use_hessp=False,
+            progressbar=False,
+            gradient_backend="jax",
+            compile_kwargs={"mode": "JAX"},
+        )
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
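
If useful, the new test can presumably be run on its own from a pymc-extras checkout with the standard pytest node-id syntax:

pytest tests/test_find_map.py::test_JAX_map_shared_variables
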
