
Commit 541dec4

ricardoV94 authored and twiecki committed
Optimize sample_numpyro_nuts potential fn
Fixes regression caused by #5092
1 parent 0a172c8 commit 541dec4
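
In brief, a sketch of the change assembled from the diff below (the helper names jaxify_before and jaxify_after are made up for illustration; only the aesara calls they contain appear in the commit): previously the logp graph was only passed through MergeOptimizer before being translated to JAX, so none of Aesara's stabilizing rewrites ran; now the FunctionGraph is optimized with the "fast_run" rewrites (minus the C- and BLAS-specific ones) before jax_funcify is called.

from aesara.graph import optimize_graph
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt import MergeOptimizer
from aesara.link.jax.dispatch import jax_funcify


def jaxify_before(logpt):
    # Old path: only duplicate nodes were merged before jaxifying.
    fgraph = FunctionGraph(outputs=[logpt], clone=False)
    MergeOptimizer().optimize(fgraph)
    return jax_funcify(fgraph)


def jaxify_after(logpt):
    # New path: run the "fast_run" rewrites (excluding C/BLAS-only ones),
    # such as the stabilization exercised by the new test, before jaxifying.
    fgraph = FunctionGraph(outputs=[logpt], clone=False)
    optimize_graph(fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
    return jax_funcify(fgraph)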

File tree: pymc/sampling_jax.py, pymc/tests/test_sampling_jax.py

2 files changed: +54 −24 lines changed


pymc/sampling_jax.py

Lines changed: 37 additions & 22 deletions
@@ -4,6 +4,11 @@
 import sys
 import warnings
 
+from typing import Callable, List
+
+from aesara.graph import optimize_graph
+from aesara.tensor import TensorVariable
+
 xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
 xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
 os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"])
@@ -18,10 +23,9 @@
 from aesara.compile import SharedVariable
 from aesara.graph.basic import clone_replace, graph_inputs
 from aesara.graph.fg import FunctionGraph
-from aesara.graph.opt import MergeOptimizer
 from aesara.link.jax.dispatch import jax_funcify
 
-from pymc import modelcontext
+from pymc import Model, modelcontext
 from pymc.aesaraf import compile_rv_inplace
 
 warnings.warn("This module is experimental.")
@@ -39,7 +43,7 @@ def assert_fn(value, *inps):
     return assert_fn
 
 
-def replace_shared_variables(graph):
+def replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
     """Replace shared variables in graph by their constant values
 
     Raises
@@ -62,6 +66,34 @@ def replace_shared_variables(graph):
     return new_graph
 
 
+def get_jaxified_logp(model: Model) -> Callable:
+    """Compile model.logpt into an optimized jax function"""
+
+    logpt = replace_shared_variables([model.logpt])[0]
+
+    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
+    optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
+
+    # We now jaxify the optimized fgraph
+    logp_fn = jax_funcify(logpt_fgraph)
+
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]
+
+    def logp_fn_wrap(x):
+        res = logp_fn(*x)
+
+        if isinstance(res, (list, tuple)):
+            # This handles the new JAX backend, which always returns a tuple
+            res = res[0]
+
+        # Jax expects a potential with the opposite sign of model.logpt
+        return -res
+
+    return logp_fn_wrap
+
+
 def sample_numpyro_nuts(
     draws=1000,
     tune=1000,
@@ -83,27 +115,10 @@ def sample_numpyro_nuts(
     init_state = [model.initial_point[rv_name] for rv_name in rv_names]
     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
 
-    logpt = replace_shared_variables([model.logpt])[0]
-    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
-    MergeOptimizer().optimize(logpt_fgraph)
-    logp_fn = jax_funcify(logpt_fgraph)
-
-    if isinstance(logp_fn, (list, tuple)):
-        # This handles the new JAX backend, which always returns a tuple
-        logp_fn = logp_fn[0]
-
-    def logp_fn_wrap(x):
-        res = logp_fn(*x)
-
-        if isinstance(res, (list, tuple)):
-            # This handles the new JAX backend, which always returns a tuple
-            res = res[0]
-
-        # Jax expects a potential with the opposite sign of model.logpt
-        return -res
+    logp_fn = get_jaxified_logp(model)
 
     nuts_kernel = NUTS(
-        potential_fn=logp_fn_wrap,
+        potential_fn=logp_fn,
         target_accept_prob=target_accept,
         adapt_step_size=True,
         adapt_mass_matrix=True,
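
A minimal usage sketch of the new helper outside sample_numpyro_nuts (hypothetical example: the model and the evaluation point are made up, and only get_jaxified_logp comes from this commit):

import numpy as np
import pymc as pm

from pymc.sampling_jax import get_jaxified_logp

with pm.Model() as model:
    x = pm.Normal("x", 0.0, 1.0)
    y = pm.Normal("y", 0.0, 1.0)

# The returned callable is a numpyro-style potential: it takes a sequence of
# raw values, one per free variable, and returns -model.logpt at those values.
potential_fn = get_jaxified_logp(model)
print(potential_fn((np.array(0.0), np.array(0.0))))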

pymc/tests/test_sampling_jax.py

Lines changed: 17 additions & 2 deletions
@@ -1,4 +1,5 @@
 import aesara
+import aesara.tensor as at
 import numpy as np
 import pytest
 
@@ -7,7 +8,11 @@
 
 import pymc as pm
 
-from pymc.sampling_jax import replace_shared_variables, sample_numpyro_nuts
+from pymc.sampling_jax import (
+    get_jaxified_logp,
+    replace_shared_variables,
+    sample_numpyro_nuts,
+)
 
 
 def test_transform_samples():
@@ -40,7 +45,6 @@ def test_transform_samples():
 
 
 def test_replace_shared_variables():
-
     x = aesara.shared(5, name="shared_x")
 
     new_x = replace_shared_variables([x])
@@ -50,3 +54,14 @@ def test_replace_shared_variables():
     x.default_update = x + 1
     with pytest.raises(ValueError, match="shared variables with default_update"):
         replace_shared_variables([x])
+
+
+def test_get_jaxified_logp():
+    with pm.Model() as m:
+        x = pm.Flat("x")
+        y = pm.Flat("y")
+        pm.Potential("pot", at.log(at.exp(x) + at.exp(y)))
+
+    jax_fn = get_jaxified_logp(m)
+    # This would underflow if not optimized
+    assert not np.isinf(jax_fn((np.array(5000.0), np.array(5000.0))))
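
To see what the new test guards against, a standalone numerical sketch (not part of the commit): evaluated literally, log(exp(5000) + exp(5000)) produces inf in float64 because the intermediate exp overflows, whereas the rewritten logaddexp form stays finite.

import numpy as np

x = y = 5000.0
naive = np.log(np.exp(x) + np.exp(y))  # np.exp(5000.0) overflows to inf, so this is inf
stable = np.logaddexp(x, y)            # 5000 + log(2), finite
print(naive, stable)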
