 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()

+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True. scipy.optimize.minimize never uses both at the '
+            'same time. Setting "use_hess" to False.'
+        )
+        use_hess = False
+
     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
     use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
     use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

-    if use_hess and use_hessp:
-        use_hess = False
-
     return use_grad, use_hess, use_hessp

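For reference, here is a minimal, self-contained sketch (toy objective, not from this module) of the scipy behaviour the new warning points at: scipy.optimize.minimize only consumes one second-order callback per solve, so when both flags are requested the wrapper keeps only the Hessian-vector product.

import numpy as np
from scipy.optimize import minimize

def loss(x):
    return np.sum(x ** 2)

def grad(x):
    return 2 * x

def hessp(x, p):
    # Hessian of the toy loss is 2*I, so the Hessian-vector product is 2*p
    return 2 * p

# Only hessp is passed, mirroring the resolved flags (use_hess=False, use_hessp=True)
res = minimize(loss, x0=np.ones(3), jac=grad, hessp=hessp, method="trust-ncg")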
@@ -97,7 +101,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)


-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -152,7 +156,7 @@ def f_hess_jax(x):
     return f_loss_and_grad, f_hess, f_hessp


-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +181,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -193,19 +197,19 @@ def _compile_functions(
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]

     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)

     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

     return [f_loss_and_grad, f_hess, f_hessp]

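The ravel-and-concatenate step above is what turns per-input gradients into the single flat vector scipy expects. A standalone sketch of that pattern, using plain pytensor.function in place of pm.compile and a toy graph with assumed variable names:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
loss = (x ** 2).sum() + (y ** 4).sum()

# grad() with a list of inputs returns one gradient graph per input ...
grads = pytensor.gradient.grad(loss, [x, y])
# ... which are raveled and concatenated into a single 1-D gradient
flat_grad = pt.concatenate([g.ravel() for g in grads])

f_loss_and_grad = pytensor.function([x, y], [loss, flat_grad])
value, gradient = f_loss_and_grad(np.ones(2), np.ones(3))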
@@ -240,7 +244,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.

     Returns
     -------
@@ -285,7 +289,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients

-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +305,7 @@ def scipy_optimize_funcs_from_loss(

     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

     return f_loss, f_hess, f_hessp
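Finally, a hedged sketch of how the returned triple is typically handed to scipy.optimize.minimize; the stand-ins below only mimic the compiled functions' calling conventions (when gradients are compiled, f_loss returns a (loss, grad) pair, hence jac=True):

import numpy as np
from scipy.optimize import minimize

def f_loss(x):   # stand-in for the compiled loss-and-grad function
    return np.sum(x ** 2), 2 * x

def f_hess(x):   # stand-in for the compiled dense Hessian
    return 2 * np.eye(x.shape[0])

# jac=True tells scipy that the objective also returns its gradient
res = minimize(f_loss, x0=np.full(4, 3.0), jac=True, hess=f_hess, method="trust-exact")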