Skip to content

Commit 3729614

Browse files
aseyboldt authored and ricardoV94 committed
Add warning about future change in hessian sign
1 parent 82eae9a commit 3729614

File tree

6 files changed

+56
-17
lines changed

6 files changed

+56
-17
lines changed

pymc/model/core.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,6 +658,7 @@ def compile_d2logp(
658658
self,
659659
vars: Variable | Sequence[Variable] | None = None,
660660
jacobian: bool = True,
661+
negate_output=True,
661662
**compile_kwargs,
662663
) -> PointFunc:
663664
"""Compiled log probability density hessian function.
@@ -670,7 +671,10 @@ def compile_d2logp(
670671
jacobian : bool
671672
Whether to include jacobian terms in logprob graph. Defaults to True.
672673
"""
673-
return self.compile_fn(self.d2logp(vars=vars, jacobian=jacobian), **compile_kwargs)
674+
return self.model.compile_fn(
675+
self.d2logp(vars=vars, jacobian=jacobian, negate_output=negate_output),
676+
**compile_kwargs,
677+
)
674678

675679
def logp(
676680
self,
@@ -794,6 +798,7 @@ def d2logp(
794798
self,
795799
vars: Variable | Sequence[Variable] | None = None,
796800
jacobian: bool = True,
801+
negate_output=True,
797802
) -> Variable:
798803
"""Hessian of the models log-probability w.r.t. ``vars``.
799804
@@ -827,7 +832,7 @@ def d2logp(
827832

828833
cost = self.logp(jacobian=jacobian)
829834
cost = rewrite_pregrad(cost)
830-
return hessian(cost, value_vars)
835+
return hessian(cost, value_vars, negate_output=negate_output)
831836

832837
@property
833838
def datalogp(self) -> Variable:

pymc/pytensorf.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -352,8 +352,17 @@ def grad_ii(i, f, x):
352352

353353

354354
@pytensor.config.change_flags(compute_test_value="ignore")
355-
def hessian(f, vars=None):
356-
return -jacobian(gradient(f, vars), vars)
355+
def hessian(f, vars=None, negate_output=True):
356+
res = jacobian(gradient(f, vars), vars)
357+
if negate_output:
358+
warnings.warn(
359+
"hessian will stop negating the output in a future version of PyMC.\n"
360+
"To suppress this warning set `negate_output=False`",
361+
FutureWarning,
362+
stacklevel=2,
363+
)
364+
res = -res
365+
return res
357366

358367

359368
@pytensor.config.change_flags(compute_test_value="ignore")
@@ -368,12 +377,21 @@ def hess_ii(i):
368377

369378

370379
@pytensor.config.change_flags(compute_test_value="ignore")
371-
def hessian_diag(f, vars=None):
380+
def hessian_diag(f, vars=None, negate_output=True):
372381
if vars is None:
373382
vars = cont_inputs(f)
374383

375384
if vars:
376-
return -pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
385+
res = pt.concatenate([hessian_diag1(f, v) for v in vars], axis=0)
386+
if negate_output:
387+
warnings.warn(
388+
"hessian_diag will stop negating the output in a future version of PyMC.\n"
389+
"To suppress this warning set `negate_output=False`",
390+
FutureWarning,
391+
stacklevel=2,
392+
)
393+
res = -res
394+
return res
377395
else:
378396
return empty_gradient
379397

pymc/sampling/mcmc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,7 @@ def init_nuts(
14611461
potential = quadpotential.QuadPotentialDiag(cov)
14621462
elif init == "map":
14631463
start = pm.find_MAP(include_transformed=True, seed=random_seed_list[0])
1464-
cov = pm.find_hessian(point=start)
1464+
cov = -pm.find_hessian(point=start, negate_output=False)
14651465
initial_points = [start] * chains
14661466
potential = quadpotential.QuadPotentialFull(cov)
14671467
elif init == "adapt_full":

pymc/tuning/scaling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def fixed_hessian(point, model=None):
4343
return rval
4444

4545

46-
def find_hessian(point, vars=None, model=None):
46+
def find_hessian(point, vars=None, model=None, negate_output=True):
4747
"""
4848
Returns Hessian of logp at the point passed.
4949
@@ -55,11 +55,11 @@ def find_hessian(point, vars=None, model=None):
5555
Variables for which Hessian is to be calculated.
5656
"""
5757
model = modelcontext(model)
58-
H = model.compile_d2logp(vars)
58+
H = model.compile_d2logp(vars, negate_output=negate_output)
5959
return H(Point(point, filter_model_vars=True, model=model))
6060

6161

62-
def find_hessian_diag(point, vars=None, model=None):
62+
def find_hessian_diag(point, vars=None, model=None, negate_output=True):
6363
"""
6464
Returns Hessian of logp at the point passed.
6565
@@ -71,14 +71,14 @@ def find_hessian_diag(point, vars=None, model=None):
7171
Variables for which Hessian is to be calculated.
7272
"""
7373
model = modelcontext(model)
74-
H = model.compile_fn(hessian_diag(model.logp(), vars))
74+
H = model.compile_fn(hessian_diag(model.logp(), vars, negate_output=negate_output))
7575
return H(Point(point, model=model))
7676

7777

7878
def guess_scaling(point, vars=None, model=None, scaling_bound=1e-8):
7979
model = modelcontext(model)
8080
try:
81-
h = find_hessian_diag(point, vars, model=model)
81+
h = -find_hessian_diag(point, vars, model=model, negate_output=False)
8282
except NotImplementedError:
8383
h = fixed_hessian(point, model=model)
8484
return adjust_scaling(h, scaling_bound)

tests/model/test_core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,16 +1012,16 @@ def test_model_d2logp(jacobian):
10121012
test_vals = np.array([0.0, -1.0])
10131013
state = {"x": test_vals, "y_log__": test_vals}
10141014

1015-
expected_x_d2logp = expected_y_d2logp = np.eye(2)
1015+
expected_x_d2logp = expected_y_d2logp = -np.eye(2)
10161016

1017-
dlogps = m.compile_d2logp(jacobian=jacobian)(state)
1017+
dlogps = m.compile_d2logp(jacobian=jacobian, negate_output=False)(state)
10181018
assert np.all(np.isclose(dlogps[:2, :2], expected_x_d2logp))
10191019
assert np.all(np.isclose(dlogps[2:, 2:], expected_y_d2logp))
10201020

1021-
x_dlogp2 = m.compile_d2logp(vars=[x], jacobian=jacobian)(state)
1021+
x_dlogp2 = m.compile_d2logp(vars=[x], jacobian=jacobian, negate_output=False)(state)
10221022
assert np.all(np.isclose(x_dlogp2, expected_x_d2logp))
10231023

1024-
y_dlogp2 = m.compile_d2logp(vars=[y], jacobian=jacobian)(state)
1024+
y_dlogp2 = m.compile_d2logp(vars=[y], jacobian=jacobian, negate_output=False)(state)
10251025
assert np.all(np.isclose(y_dlogp2, expected_y_d2logp))
10261026

10271027

tests/test_pytensorf.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from pytensor import scan, shared
2626
from pytensor.compile import UnusedInputError
2727
from pytensor.compile.builders import OpFromGraph
28-
from pytensor.graph.basic import Variable
28+
from pytensor.graph.basic import Variable, equal_computations
2929
from pytensor.tensor.random.basic import normal, uniform
3030
from pytensor.tensor.random.var import RandomStateSharedVariable
3131
from pytensor.tensor.subtensor import AdvancedIncSubtensor, AdvancedIncSubtensor1
@@ -43,6 +43,8 @@
4343
constant_fold,
4444
convert_observed_data,
4545
extract_obs_data,
46+
hessian,
47+
hessian_diag,
4648
replace_rng_nodes,
4749
replace_vars_in_graphs,
4850
reseed_rngs,
@@ -726,3 +728,17 @@ def test_replace_vars_in_graphs_nested_reference():
726728
assert np.abs(x.eval()) < 1
727729
# Confirm the original `y` variable is not changed in place
728730
assert np.abs(y.eval()) < 1
731+
732+
733+
@pytest.mark.filterwarnings("error")
734+
@pytest.mark.parametrize("func", (hessian, hessian_diag))
735+
def test_hessian_sign_change_warning(func):
736+
x = pt.vector("x")
737+
f = (x**2).sum()
738+
with pytest.warns(
739+
FutureWarning,
740+
match="will stop negating the output",
741+
):
742+
res_neg = func(f, vars=[x])
743+
res = func(f, vars=[x], negate_output=False)
744+
assert equal_computations([res_neg], [-res])

0 commit comments

Comments (0)