Skip to content

Commit 8974377

Browse files
Delay jax import
1 parent fa49b43 commit 8974377

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

pymc_extras/inference/find_map.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from pymc.initial_point import make_initial_point_fn
1616
from pymc.model.transform.optimization import freeze_dims_and_data
1717
from pymc.pytensorf import join_nonshared_inputs
18-
from pymc.sampling.jax import _replace_shared_variables
1918
from pymc.util import get_default_varnames
2019
from pytensor.compile import Function
2120
from pytensor.compile.mode import Mode
@@ -306,6 +305,8 @@ def scipy_optimize_funcs_from_loss(
306305
# computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
307306
# away.
308307
if use_jax_gradients:
308+
from pymc.sampling.jax import _replace_shared_variables
309+
309310
[loss] = _replace_shared_variables([loss])
310311

311312
compute_grad = use_grad and not use_jax_gradients

0 commit comments

Comments
 (0)