import logging

from collections.abc import Callable
+ from importlib.util import find_spec
from typing import Literal, cast, get_args

- import jax
import numpy as np
import pymc as pm
import pytensor
def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
    method_info = MINIMIZE_MODE_KWARGS[method].copy()

-   use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
-   use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
-   use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
-
    if use_hess and use_hessp:
+       _log.warning(
+           'Both "use_hess" and "use_hessp" are set to True, but scipy.optimize.minimize never uses both at the '
+           'same time. When possible "use_hessp" is preferred because it is computationally more efficient. '
+           'Setting "use_hess" to False.'
+       )
        use_hess = False

+   use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
+
+   if use_hessp is not None and use_hess is None:
+       use_hess = not use_hessp
+
+   elif use_hess is not None and use_hessp is None:
+       use_hessp = not use_hess
+
+   elif use_hessp is None and use_hess is None:
+       use_hessp = method_info["uses_hessp"]
+       use_hess = method_info["uses_hess"]
+       if use_hessp and use_hess:
+           # If a method could use either hess or hessp, we default to using hessp
+           use_hess = False
+
    return use_grad, use_hess, use_hessp

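To make the new defaulting rules concrete, here is a small standalone sketch. resolve_hess_flags is a hypothetical helper written only for illustration; the real function also resolves use_grad and consults MINIMIZE_MODE_KWARGS, whose entries are assumed here to allow both hess and hessp.

def resolve_hess_flags(use_hess, use_hessp, uses_hess=True, uses_hessp=True):
    # Mirrors the diff above: an explicit request for both prefers hessp.
    if use_hess and use_hessp:
        use_hess = False
    if use_hessp is not None and use_hess is None:
        use_hess = not use_hessp
    elif use_hess is not None and use_hessp is None:
        use_hessp = not use_hess
    elif use_hessp is None and use_hess is None:
        use_hessp, use_hess = uses_hessp, uses_hess
        if use_hessp and use_hess:
            use_hess = False  # methods that accept either default to hessp
    return use_hess, use_hessp

print(resolve_hess_flags(None, None))   # (False, True): hessp preferred by default
print(resolve_hess_flags(True, None))   # (True, False): explicit user choice respected
print(resolve_hess_flags(True, True))   # (False, True): hess dropped in favor of hessp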
@@ -59,7 +75,7 @@ def get_nearest_psd(A: np.ndarray) -> np.ndarray:
        The nearest positive semi-definite matrix to the input matrix.
    """
    C = (A + A.T) / 2
-   eigval, eigvec = np.linalg.eig(C)
+   eigval, eigvec = np.linalg.eigh(C)
    eigval[eigval < 0] = 0

    return eigvec @ np.diag(eigval) @ eigvec.T
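Since the full body is visible in this hunk, the behavior can be checked in isolation. np.linalg.eig on a symmetric matrix can return complex eigenpairs because of floating-point asymmetry, while eigh guarantees real output, which is why the commit switches to it:

import numpy as np

def nearest_psd(A):
    # Symmetrize, then clip negative eigenvalues to zero.
    C = (A + A.T) / 2
    eigval, eigvec = np.linalg.eigh(C)  # eigh: real, ascending eigenvalues for symmetric input
    eigval[eigval < 0] = 0
    return eigvec @ np.diag(eigval) @ eigvec.T

A = np.array([[1.0, 0.9], [0.9, -0.2]])      # indefinite "covariance" estimate
print(np.linalg.eigvalsh(nearest_psd(A)))    # all eigenvalues are now >= 0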
@@ -97,7 +113,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
    return f_untransform(posterior_draws)


- def _compile_jax_gradients(
+ def _compile_grad_and_hess_to_jax(
    f_loss: Function, use_hess: bool, use_hessp: bool
) -> tuple[Callable | None, Callable | None]:
    """
@@ -122,6 +138,8 @@ def _compile_jax_gradients(
    f_hessp: Callable | None
        The compiled hessian-vector product function, or None if use_hessp is False.
    """
+   import jax
+
    f_hess = None
    f_hessp = None
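The body of this function is not shown in the diff. As a generic illustration only (not this module's implementation), this is the usual way a JAX-traceable scalar loss is turned into the value-and-gradient and Hessian-vector-product callables that scipy expects:

import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum((x - 1.0) ** 2)

f_loss_and_grad = jax.value_and_grad(loss)                      # returns (loss, grad)
f_hessp = lambda x, p: jax.jvp(jax.grad(loss), (x,), (p,))[1]   # Hessian @ p

val, g = f_loss_and_grad(jnp.zeros(3))
hp = f_hessp(jnp.zeros(3), jnp.ones(3))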
@@ -152,7 +170,7 @@ def f_hess_jax(x):
    return f_loss_and_grad, f_hess, f_hessp


- def _compile_functions(
+ def _compile_functions_for_scipy_optimize(
    loss: TensorVariable,
    inputs: list[TensorVariable],
    compute_grad: bool,
@@ -177,7 +195,7 @@ def _compile_functions(
    compute_hessp: bool
        Whether to compile a function that computes the Hessian-vector product of the loss function.
    compile_kwargs: dict, optional
-       Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+       Additional keyword arguments to pass to the ``pm.compile`` function.

    Returns
    -------
@@ -193,19 +211,19 @@ def _compile_functions(
    if compute_grad:
        grads = pytensor.gradient.grad(loss, inputs)
        grad = pt.concatenate([grad.ravel() for grad in grads])
-       f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+       f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
    else:
-       f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+       f_loss = pm.compile(inputs, loss, **compile_kwargs)
        return [f_loss]

    if compute_hess:
        hess = pytensor.gradient.jacobian(grad, inputs)[0]
-       f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+       f_hess = pm.compile(inputs, hess, **compile_kwargs)

    if compute_hessp:
        p = pt.tensor("p", shape=inputs[0].type.shape)
        hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-       f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+       f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)

    return [f_loss_and_grad, f_hess, f_hessp]
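The pattern compiled here, a single function returning both the loss and its flattened gradient, is what lets scipy.optimize.minimize be called with jac=True. A minimal sketch using plain pytensor.function (the module itself goes through pm.compile, PyMC's wrapper around the same compilation machinery):

import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy.optimize import minimize

x = pt.vector("x")
loss = pt.sum((x - pt.arange(3)) ** 2)
grad = pytensor.gradient.grad(loss, x)

# One compiled callable returns [loss, grad], so scipy reuses a single evaluation.
f_loss_and_grad = pytensor.function([x], [loss, grad])
res = minimize(f_loss_and_grad, x0=np.zeros(3), jac=True, method="L-BFGS-B")
print(res.x)  # approximately [0, 1, 2]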
@@ -240,7 +258,7 @@ def scipy_optimize_funcs_from_loss(
    gradient_backend: str, default "pytensor"
        Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
    compile_kwargs:
-       Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+       Additional keyword arguments to pass to the ``pm.compile`` function.

    Returns
    -------
@@ -265,6 +283,8 @@
    )

    use_jax_gradients = (gradient_backend == "jax") and use_grad
+   if use_jax_gradients and not find_spec("jax"):
+       raise ImportError("JAX must be installed to use JAX gradients")

    mode = compile_kwargs.get("mode", None)
    if mode is None and use_jax_gradients:
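importlib.util.find_spec returns a ModuleSpec when the package is installed and None otherwise, so the new check detects JAX without importing it; the actual import is deferred to _compile_grad_and_hess_to_jax. The pattern in isolation:

from importlib.util import find_spec

if find_spec("jax") is None:  # None -> package not installed
    raise ImportError("JAX must be installed to use JAX gradients")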
@@ -285,7 +305,7 @@
    compute_hess = use_hess and not use_jax_gradients
    compute_hessp = use_hessp and not use_jax_gradients

-   funcs = _compile_functions(
+   funcs = _compile_functions_for_scipy_optimize(
        loss=loss,
        inputs=[flat_input],
        compute_grad=compute_grad,
@@ -301,7 +321,7 @@

    if use_jax_gradients:
        # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-       f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+       f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)

    return f_loss, f_hess, f_hessp
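The three callables returned here line up with scipy.optimize.minimize's fun/jac/hess/hessp arguments. A self-contained illustration with NumPy stand-ins for the compiled functions (illustrative only, not the module's actual output):

import numpy as np
from scipy.optimize import minimize

def f_loss(x):        # stands in for the compiled loss-and-grad callable, hence jac=True
    return np.sum((x - 1.0) ** 2), 2.0 * (x - 1.0)

def f_hessp(x, p):    # stands in for the compiled Hessian-vector product
    return 2.0 * p

# Only one of hess/hessp is typically non-None, matching the defaulting logic above.
res = minimize(f_loss, x0=np.zeros(4), jac=True, hessp=f_hessp, method="trust-ncg")
print(res.x)  # approximately all ones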