File tree Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Expand file tree Collapse file tree 1 file changed +2
-1
lines changed Original file line number Diff line number Diff line change 15
15
from pymc .initial_point import make_initial_point_fn
16
16
from pymc .model .transform .optimization import freeze_dims_and_data
17
17
from pymc .pytensorf import join_nonshared_inputs
18
- from pymc .sampling .jax import _replace_shared_variables
19
18
from pymc .util import get_default_varnames
20
19
from pytensor .compile import Function
21
20
from pytensor .compile .mode import Mode
@@ -306,6 +305,8 @@ def scipy_optimize_funcs_from_loss(
306
305
# computing jax gradients, we discard the function wrapper, so we can't handle shared variables --> rewrite them
307
306
# away.
308
307
if use_jax_gradients :
308
+ from pymc .sampling .jax import _replace_shared_variables
309
+
309
310
[loss ] = _replace_shared_variables ([loss ])
310
311
311
312
compute_grad = use_grad and not use_jax_gradients
You can’t perform that action at this time.
0 commit comments