import sys
import warnings

+from typing import Callable, List
+
+from aesara.graph import optimize_graph
+from aesara.tensor import TensorVariable
+
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join([f"--xla_force_host_platform_device_count={100}"])
from aesara.compile import SharedVariable
from aesara.graph.basic import clone_replace, graph_inputs
from aesara.graph.fg import FunctionGraph
-from aesara.graph.opt import MergeOptimizer
from aesara.link.jax.dispatch import jax_funcify

-from pymc import modelcontext
+from pymc import Model, modelcontext
from pymc.aesaraf import compile_rv_inplace

warnings.warn("This module is experimental.")
@@ -39,7 +43,7 @@ def assert_fn(value, *inps):
    return assert_fn


-def replace_shared_variables(graph):
+def replace_shared_variables(graph: List[TensorVariable]) -> List[TensorVariable]:
    """Replace shared variables in graph by their constant values

    Raises
@@ -62,6 +66,34 @@ def replace_shared_variables(graph):
    return new_graph


+def get_jaxified_logp(model: Model) -> Callable:
+    """Compile model.logpt into an optimized jax function"""
+
+    logpt = replace_shared_variables([model.logpt])[0]
+
+    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
+    optimize_graph(logpt_fgraph, include=["fast_run"], exclude=["cxx_only", "BlasOpt"])
+
+    # We now jaxify the optimized fgraph
+    logp_fn = jax_funcify(logpt_fgraph)
+
+    if isinstance(logp_fn, (list, tuple)):
+        # This handles the new JAX backend, which always returns a tuple
+        logp_fn = logp_fn[0]
+
+    def logp_fn_wrap(x):
+        res = logp_fn(*x)
+
+        if isinstance(res, (list, tuple)):
+            # This handles the new JAX backend, which always returns a tuple
+            res = res[0]
+
+        # Jax expects a potential with the opposite sign of model.logpt
+        return -res
+
+    return logp_fn_wrap
+
+
def sample_numpyro_nuts(
    draws=1000,
    tune=1000,
@@ -83,27 +115,10 @@ def sample_numpyro_nuts(
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)

-    logpt = replace_shared_variables([model.logpt])[0]
-    logpt_fgraph = FunctionGraph(outputs=[logpt], clone=False)
-    MergeOptimizer().optimize(logpt_fgraph)
-    logp_fn = jax_funcify(logpt_fgraph)
-
-    if isinstance(logp_fn, (list, tuple)):
-        # This handles the new JAX backend, which always returns a tuple
-        logp_fn = logp_fn[0]
-
-    def logp_fn_wrap(x):
-        res = logp_fn(*x)
-
-        if isinstance(res, (list, tuple)):
-            # This handles the new JAX backend, which always returns a tuple
-            res = res[0]
-
-        # Jax expects a potential with the opposite sign of model.logpt
-        return -res
+    logp_fn = get_jaxified_logp(model)

    nuts_kernel = NUTS(
-        potential_fn=logp_fn_wrap,
+        potential_fn=logp_fn,
        target_accept_prob=target_accept,
        adapt_step_size=True,
        adapt_mass_matrix=True,
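
A minimal usage sketch of the new `get_jaxified_logp` helper, for trying it outside the sampler. The toy model, the `pymc.sampling_jax` import path, and the use of `model.value_vars` are illustrative assumptions, not part of this diff; the initial-point handling mirrors the `sample_numpyro_nuts` body above.

```python
# Illustrative sketch only; assumes get_jaxified_logp is importable from
# pymc.sampling_jax as added in this PR, and that model.value_vars lists
# the model's value variables.
import pymc as pm

from pymc.sampling_jax import get_jaxified_logp

with pm.Model() as model:  # toy model, purely illustrative
    x = pm.Normal("x", 0.0, 1.0)

logp_fn = get_jaxified_logp(model)

# Mirror the sampler body: one value per value variable, taken from the initial point.
rv_names = [rv.name for rv in model.value_vars]
init_state = [model.initial_point[rv_name] for rv_name in rv_names]

# Returns the *negative* of model.logpt, i.e. the potential that NumPyro's NUTS expects.
print(logp_fn(init_state))
```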