
Commit 413a4cb

Fix bug in fit_MAP when shared variables are used in graph (#468)
* Fix bug in `fit_MAP` when shared variables are used in graph
* Delay jax import
* Try version pinning dask
* Pin dask version in setup.py
1 parent 7fb87b4 commit 413a4cb
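In user terms, the regression surfaced when a model's graph contained a shared variable (for example, data wrapped with pytensor.shared or pm.Data) and find_MAP was asked to compute gradients through JAX. A minimal sketch of the call pattern that used to fail, with an illustrative model and an assumed import path (the regression test added below exercises the same path):

import numpy as np
import pymc as pm

from pymc_extras.inference.find_map import find_MAP  # import path assumed

with pm.Model():
    # pm.Data stores the array as a shared variable inside the graph
    x = pm.Data("x", np.random.normal(loc=3, scale=1.5, size=100))
    mu = pm.Normal("mu")
    sigma = pm.Exponential("sigma", 1)
    pm.Normal("y_hat", mu=mu, sigma=sigma, observed=x)

    # Before this commit, the JAX gradient path failed on graphs with shared variables
    point = find_MAP(
        method="L-BFGS-B",
        use_grad=True,
        use_hess=False,
        use_hessp=False,
        progressbar=False,
        gradient_backend="jax",
        compile_kwargs={"mode": "JAX"},
    )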

5 files changed (+35 −4 lines changed)


conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ dependencies:
   - pymc>=5.21
   - pytest-cov>=2.5
   - pytest>=3.0
-  - dask
+  - dask<2025.1.1
   - xhistogram
   - statsmodels
   - numba<=0.60.0

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@ dependencies:
   - pip
   - pytest-cov>=2.5
   - pytest>=3.0
-  - dask
+  - dask<2025.1.1
   - xhistogram
   - statsmodels
   - numba<=0.60.0

pymc_extras/inference/find_map.py

Lines changed: 9 additions & 1 deletion
@@ -146,7 +146,7 @@ def _compile_grad_and_hess_to_jax(
     orig_loss_fn = f_loss.vm.jit_fn
 
     @jax.jit
-    def loss_fn_jax_grad(x, *shared):
+    def loss_fn_jax_grad(x):
         return jax.value_and_grad(lambda x: orig_loss_fn(x)[0])(x)
 
     f_loss_and_grad = loss_fn_jax_grad

@@ -301,6 +301,14 @@ def scipy_optimize_funcs_from_loss(
         point=initial_point_dict, outputs=[loss], inputs=inputs
     )
 
+    # If we use pytensor gradients, we will use the pytensor function wrapper that handles shared variables. When
+    # computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
+    # away.
+    if use_jax_gradients:
+        from pymc.sampling.jax import _replace_shared_variables
+
+        [loss] = _replace_shared_variables([loss])
+
     compute_grad = use_grad and not use_jax_gradients
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
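The helper imported above, pymc.sampling.jax._replace_shared_variables, rewrites SharedVariable inputs into constants so that the loss graph depends only on its explicit inputs and can be traced by JAX without the pytensor function wrapper. A minimal sketch of the same idea using pytensor's graph_replace; the names shared_x, beta, and loss are illustrative, not from the commit:

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.graph.replace import graph_replace

shared_x = pytensor.shared(np.arange(3.0), name="shared_x")
beta = pt.dscalar("beta")
loss = ((beta * shared_x) ** 2).sum()

# Freeze the shared variable's current value into the graph as a constant,
# so the rewritten graph has no hidden (implicit) inputs.
[loss_no_shared] = graph_replace(
    [loss], {shared_x: pt.constant(shared_x.get_value(), name="shared_x")}
)

# The rewritten graph depends only on `beta`, so it can be compiled or traced
# (e.g. to JAX) without passing shared values as extra arguments.
f = pytensor.function([beta], loss_no_shared)
print(f(2.0))  # 20.0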

setup.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@
 
 
 extras_require = dict(
-    dask_histogram=["dask[complete]", "xhistogram"],
+    dask_histogram=["dask[complete]<2025.1.1", "xhistogram"],
     histogram=["xhistogram"],
 )
 extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values())))

tests/test_find_map.py

Lines changed: 23 additions & 0 deletions
@@ -1,5 +1,6 @@
 import numpy as np
 import pymc as pm
+import pytensor
 import pytensor.tensor as pt
 import pytest
 

@@ -101,3 +102,25 @@ def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: Gradie
 
     assert np.isclose(mu_hat, 3, atol=0.5)
     assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
+
+
+def test_JAX_map_shared_variables():
+    with pm.Model() as m:
+        data = pytensor.shared(np.random.normal(loc=3, scale=1.5, size=100), name="shared_data")
+        mu = pm.Normal("mu")
+        sigma = pm.Exponential("sigma", 1)
+        y_hat = pm.Normal("y_hat", mu=mu, sigma=sigma, observed=data)
+
+        optimized_point = find_MAP(
+            method="L-BFGS-B",
+            use_grad=True,
+            use_hess=False,
+            use_hessp=False,
+            progressbar=False,
+            gradient_backend="jax",
+            compile_kwargs={"mode": "JAX"},
+        )
+    mu_hat, log_sigma_hat = optimized_point["mu"], optimized_point["sigma_log__"]
+
+    assert np.isclose(mu_hat, 3, atol=0.5)
+    assert np.isclose(np.exp(log_sigma_hat), 1.5, atol=0.5)
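To run just the new regression test locally (assuming JAX and the project's test dependencies are installed):

pytest tests/test_find_map.py::test_JAX_map_shared_variables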
