
Commit 1cef2e7

Address comments by @theorashid on PR pymc-devs#385
1 parent 9d47188 commit 1cef2e7

4 files changed: +50 -39 lines


pymc_extras/inference/find_map.py

+17-13
@@ -30,13 +30,17 @@
 def set_optimizer_function_defaults(method, use_grad, use_hess, use_hessp):
     method_info = MINIMIZE_MODE_KWARGS[method].copy()
 
+    if use_hess and use_hessp:
+        _log.warning(
+            'Both "use_hess" and "use_hessp" are set to True. scipy.optimize.minimize never uses both at the '
+            'same time. Setting "use_hess" to False.'
+        )
+        use_hess = False
+
     use_grad = use_grad if use_grad is not None else method_info["uses_grad"]
     use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
     use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]
 
-    if use_hess and use_hessp:
-        use_hess = False
-
     return use_grad, use_hess, use_hessp
 

@@ -97,7 +101,7 @@ def _create_transformed_draws(H_inv, slices, out_shapes, posterior_draws, model,
     return f_untransform(posterior_draws)
 
 
-def _compile_jax_gradients(
+def _compile_grad_and_hess_to_jax(
     f_loss: Function, use_hess: bool, use_hessp: bool
 ) -> tuple[Callable | None, Callable | None]:
     """
@@ -152,7 +156,7 @@ def f_hess_jax(x):
     return f_loss_and_grad, f_hess, f_hessp
 
 
-def _compile_functions(
+def _compile_functions_for_scipy_optimize(
     loss: TensorVariable,
     inputs: list[TensorVariable],
     compute_grad: bool,
@@ -177,7 +181,7 @@ def _compile_functions(
     compute_hessp: bool
         Whether to compile a function that computes the Hessian-vector product of the loss function.
     compile_kwargs: dict, optional
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -193,19 +197,19 @@
     if compute_grad:
         grads = pytensor.gradient.grad(loss, inputs)
         grad = pt.concatenate([grad.ravel() for grad in grads])
-        f_loss_and_grad = pm.compile_pymc(inputs, [loss, grad], **compile_kwargs)
+        f_loss_and_grad = pm.compile(inputs, [loss, grad], **compile_kwargs)
     else:
-        f_loss = pm.compile_pymc(inputs, loss, **compile_kwargs)
+        f_loss = pm.compile(inputs, loss, **compile_kwargs)
         return [f_loss]
 
     if compute_hess:
         hess = pytensor.gradient.jacobian(grad, inputs)[0]
-        f_hess = pm.compile_pymc(inputs, hess, **compile_kwargs)
+        f_hess = pm.compile(inputs, hess, **compile_kwargs)
 
     if compute_hessp:
         p = pt.tensor("p", shape=inputs[0].type.shape)
         hessp = pytensor.gradient.hessian_vector_product(loss, inputs, p)
-        f_hessp = pm.compile_pymc([*inputs, p], hessp[0], **compile_kwargs)
+        f_hessp = pm.compile([*inputs, p], hessp[0], **compile_kwargs)
 
     return [f_loss_and_grad, f_hess, f_hessp]
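
For context, ``pm.compile`` (used here in place of ``pm.compile_pymc``) takes a list of symbolic inputs and outputs and returns a compiled PyTensor function. A minimal sketch of the loss-and-gradient pattern from this hunk, using a toy quadratic loss rather than a model's log-posterior:

import pymc as pm
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
loss = (x**2).sum()                     # toy loss standing in for the negative log-posterior
grad = pytensor.gradient.grad(loss, x)  # symbolic gradient, as in the compute_grad branch

f_loss_and_grad = pm.compile([x], [loss, grad])
value, gradient = f_loss_and_grad([1.0, 2.0, 3.0])  # -> 14.0 and [2. 4. 6.]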

@@ -240,7 +244,7 @@ def scipy_optimize_funcs_from_loss(
     gradient_backend: str, default "pytensor"
         Which backend to use to compute gradients. Must be one of "jax" or "pytensor"
     compile_kwargs:
-        Additional keyword arguments to pass to the ``pm.compile_pymc`` function.
+        Additional keyword arguments to pass to the ``pm.compile`` function.
 
     Returns
     -------
@@ -285,7 +289,7 @@ def scipy_optimize_funcs_from_loss(
     compute_hess = use_hess and not use_jax_gradients
     compute_hessp = use_hessp and not use_jax_gradients
 
-    funcs = _compile_functions(
+    funcs = _compile_functions_for_scipy_optimize(
         loss=loss,
         inputs=[flat_input],
         compute_grad=compute_grad,
@@ -301,7 +305,7 @@
 
     if use_jax_gradients:
         # f_loss here is f_loss_and_grad; the name is unchanged to simplify the return values
-        f_loss, f_hess, f_hessp = _compile_jax_gradients(f_loss, use_hess, use_hessp)
+        f_loss, f_hess, f_hessp = _compile_grad_and_hess_to_jax(f_loss, use_hess, use_hessp)
 
     return f_loss, f_hess, f_hessp
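
Taken together, the changes to this file mean that explicitly requesting both ``use_hess=True`` and ``use_hessp=True`` now logs a warning and continues with the Hessian-vector product path, rather than silently dropping ``use_hess``. A minimal sketch with a toy model (keyword names taken from the tests below; assuming ``find_MAP`` picks the model up from the enclosing model context, as the tests do):

import numpy as np
import pymc as pm
from pymc_extras.inference.find_map import find_MAP

with pm.Model():
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.normal(loc=0.5, size=10))

    # Both flags requested: a warning is logged, use_hess is set to False,
    # and only the Hessian-vector product is compiled and handed to scipy.
    result = find_MAP(
        method="trust-ncg",
        use_grad=True,
        use_hess=True,
        use_hessp=True,
        progressbar=False,
    )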

pymc_extras/inference/laplace.py

+12-10
@@ -231,7 +231,7 @@ def add_data_to_inferencedata(
     return idata
 
 
-def fit_mvn_to_MAP(
+def fit_mvn_at_MAP(
     optimized_point: dict[str, np.ndarray],
     model: pm.Model | None = None,
     on_bad_cov: Literal["warn", "error", "ignore"] = "ignore",
@@ -344,8 +344,10 @@ def sample_laplace_posterior(
 
     Parameters
     ----------
-    mu
-    H_inv
+    mu: RaveledVars
+        The MAP estimate of the model parameters.
+    H_inv: np.ndarray
+        The inverse Hessian matrix of the log-posterior evaluated at the MAP estimate.
     model : Model
         A PyMC model
     chains : int
@@ -384,9 +386,7 @@ def sample_laplace_posterior(
             constrained_rvs, replace={unconstrained_vector: batched_values}
         )
 
-        f_constrain = pm.compile_pymc(
-            inputs=[batched_values], outputs=batched_rvs, **compile_kwargs
-        )
+        f_constrain = pm.compile(inputs=[batched_values], outputs=batched_rvs, **compile_kwargs)
         posterior_draws = f_constrain(posterior_draws)
 
     else:
@@ -472,15 +472,17 @@ def fit_laplace(
         and 1).
 
         .. warning::
-            This argumnet should be considered highly experimental. It has not been verified if this method produces
+            This argument should be considered highly experimental. It has not been verified if this method produces
             valid draws from the posterior. **Use at your own risk**.
 
     gradient_backend: str, default "pytensor"
         The backend to use for gradient computations. Must be one of "pytensor" or "jax".
     chains: int, default: 2
-        The number of sampling chains running in parallel.
+        The number of chain dimensions to sample. Note that this is *not* the number of chains to run in parallel,
+        because the Laplace approximation is not an MCMC method. This argument exists to ensure that outputs are
+        compatible with the ArviZ library.
     draws: int, default: 500
-        The number of samples to draw from the approximated posterior.
+        The number of samples to draw from the approximated posterior. Total samples will be chains * draws.
     on_bad_cov : str, one of 'ignore', 'warn', or 'error', default: 'ignore'
         What to do when ``H_inv`` (inverse Hessian) is not positive semi-definite.
         If 'ignore' or 'warn', the closest positive-semi-definite matrix to ``H_inv`` (in L1 norm) will be returned.
@@ -547,7 +549,7 @@ def fit_laplace(
         **optimizer_kwargs,
     )
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
        optimized_point=optimized_point,
        model=model,
        on_bad_cov=on_bad_cov,
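
The reworded ``chains`` and ``draws`` docstrings above describe the shape of the output rather than any parallelism. A sketch with a toy model (assuming, as in the tests, that ``fit_laplace`` picks the model up from the enclosing context and returns an ArviZ ``InferenceData``):

import numpy as np
import pymc as pm
from pymc_extras.inference.laplace import fit_laplace

with pm.Model():
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.normal(loc=0.3, size=25))

    idata = fit_laplace(chains=2, draws=500)

# No MCMC chains are run; the fitted normal approximation is sampled
# chains * draws times and laid out as the usual (chain, draw, ...) dimensions.
print(idata.posterior["mu"].shape)  # expected: (2, 500)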

tests/test_find_map.py

+19-14
@@ -54,24 +54,28 @@ def compute_z(x):
 
 
 @pytest.mark.parametrize(
-    "method, use_grad, use_hess",
+    "method, use_grad, use_hess, use_hessp",
     [
-        ("nelder-mead", False, False),
-        ("powell", False, False),
-        ("CG", True, False),
-        ("BFGS", True, False),
-        ("L-BFGS-B", True, False),
-        ("TNC", True, False),
-        ("SLSQP", True, False),
-        ("dogleg", True, True),
-        ("trust-ncg", True, True),
-        ("trust-exact", True, True),
-        ("trust-krylov", True, True),
-        ("trust-constr", True, True),
+        ("nelder-mead", False, False, False),
+        ("powell", False, False, False),
+        ("CG", True, False, False),
+        ("BFGS", True, False, False),
+        ("L-BFGS-B", True, False, False),
+        ("TNC", True, False, False),
+        ("SLSQP", True, False, False),
+        ("dogleg", True, True, False),
+        ("Newton-CG", True, True, False),
+        ("Newton-CG", True, False, True),
+        ("trust-ncg", True, True, False),
+        ("trust-ncg", True, False, True),
+        ("trust-exact", True, True, False),
+        ("trust-krylov", True, True, False),
+        ("trust-krylov", True, False, True),
+        ("trust-constr", True, True, False),
     ],
 )
 @pytest.mark.parametrize("gradient_backend", ["jax", "pytensor"], ids=str)
-def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend, rng):
+def test_JAX_map(method, use_grad, use_hess, use_hessp, gradient_backend: GradientBackend, rng):
     extra_kwargs = {}
     if method == "dogleg":
         # HACK -- dogleg requires that the hessian of the objective function is PSD, so we have to pick a point
@@ -88,6 +92,7 @@ def test_JAX_map(method, use_grad, use_hess, gradient_backend: GradientBackend,
         **extra_kwargs,
         use_grad=use_grad,
         use_hess=use_hess,
+        use_hessp=use_hessp,
         progressbar=False,
         gradient_backend=gradient_backend,
         compile_kwargs={"mode": "JAX"},
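
The new parametrizations exercise the ``use_hessp`` path for Newton-CG and the trust-region methods. Independently of this repo, the underlying scipy API accepts either a full Hessian via ``hess`` or a Hessian-vector product via ``hessp``; only one of the two is needed for these methods, which is what the warning added in ``find_map.py`` is about. A small standalone sketch:

import numpy as np
from scipy.optimize import minimize

def loss(x):
    return 0.5 * np.sum(x**2)

def grad(x):
    return x

def hessp(x, p):
    return p  # the Hessian of 0.5 * ||x||^2 is the identity, so H @ p == p

res = minimize(loss, x0=np.ones(3), jac=grad, hessp=hessp, method="trust-ncg")
print(res.x)  # converges to the minimum at the origin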

tests/test_laplace.py

+2-2
@@ -22,7 +22,7 @@
 from pymc_extras.inference.find_map import find_MAP
 from pymc_extras.inference.laplace import (
     fit_laplace,
-    fit_mvn_to_MAP,
+    fit_mvn_at_MAP,
     sample_laplace_posterior,
 )

@@ -137,7 +137,7 @@ def test_fit_laplace_coords(rng, transform_samples, mode):
     for value in optimized_point.values():
         assert value.shape == (3,)
 
-    mu, H_inv = fit_mvn_to_MAP(
+    mu, H_inv = fit_mvn_at_MAP(
         optimized_point=optimized_point,
         model=model,
         transform_samples=transform_samples,

0 commit comments
