import re
import warnings

- from collections import defaultdict
-
xla_flags = os.getenv("XLA_FLAGS", "").lstrip("--")
xla_flags = re.sub(r"xla_force_host_platform_device_count=.+\s", "", xla_flags).split()
os.environ["XLA_FLAGS"] = " ".join(["--xla_force_host_platform_device_count={}".format(100)] + xla_flags)

import aesara.graph.fg
+ import aesara.tensor as at
import arviz as az
import jax
import numpy as np
import pandas as pd

- from aesara.link.jax.jax_dispatch import jax_funcify
-
- import pymc3 as pm
+ from aesara.compile import SharedVariable
+ from aesara.graph.basic import Apply, Constant, clone, graph_inputs
+ from aesara.graph.fg import FunctionGraph
+ from aesara.graph.op import Op
+ from aesara.graph.opt import MergeOptimizer
+ from aesara.link.jax.dispatch import jax_funcify
+ from aesara.tensor.type import TensorType

from pymc3 import modelcontext

warnings.warn("This module is experimental.")

- # Disable C compilation by default
- # aesara.config.cxx = ""
- # This will make the JAX Linker the default
- # aesara.config.mode = "JAX"

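+ # NumPyroNUTS wraps a NumPyro NUTS run as an Aesara Op: its inputs are the per-chain
+ # initial points (plus any shared-variable inputs of the logp graph) and its outputs
+ # are the posterior samples for each input variable plus the number of leapfrog steps
+ # taken. The Op is only ever executed through the JAX backend registered below.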
+ class NumPyroNUTS(Op):
+     def __init__(
+         self,
+         inputs,
+         outputs,
+         target_accept=0.8,
+         draws=1000,
+         tune=1000,
+         chains=4,
+         seed=None,
+         progress_bar=True,
+     ):
+         self.draws = draws
+         self.tune = tune
+         self.chains = chains
+         self.target_accept = target_accept
+         self.progress_bar = progress_bar
+         self.seed = seed

- def sample_tfp_nuts(
-     draws=1000,
-     tune=1000,
-     chains=4,
-     target_accept=0.8,
-     random_seed=10,
-     model=None,
-     num_tuning_epoch=2,
-     num_compute_step_size=500,
- ):
-     import jax
+         self.inputs, self.outputs = clone(inputs, outputs, copy_inputs=False)
+         self.inputs_type = tuple([input.type for input in inputs])
+         self.outputs_type = tuple([output.type for output in outputs])
+         self.nin = len(inputs)
+         self.nout = len(outputs)
+         self.nshared = len([v for v in inputs if isinstance(v, SharedVariable)])
+         self.samples_bcast = [self.chains == 1, self.draws == 1]

-     from tensorflow_probability.substrates import jax as tfp
+         self.fgraph = FunctionGraph(self.inputs, self.outputs, clone=False)
+         MergeOptimizer().optimize(self.fgraph)

-     model = modelcontext(model)
+         super().__init__()

-     seed = jax.random.PRNGKey(random_seed)
+     def make_node(self, *inputs):

-     fgraph = model.logp.f.maker.fgraph
-     fns = jax_funcify(fgraph)
-     logp_fn_jax = fns[0]
+         # The samples for each variable
+         outputs = [
+             TensorType(v.dtype, self.samples_bcast + list(v.broadcastable))() for v in inputs
+         ]

-     rv_names = [rv.name for rv in model.free_RVs]
-     init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+         # The leapfrog statistics
+         outputs += [TensorType("int64", self.samples_bcast)()]

-     @jax.pmap
-     def _sample(init_state, seed):
-         def gen_kernel(step_size):
-             hmc = tfp.mcmc.NoUTurnSampler(target_log_prob_fn=logp_fn_jax, step_size=step_size)
-             return tfp.mcmc.DualAveragingStepSizeAdaptation(
-                 hmc, tune // num_tuning_epoch, target_accept_prob=target_accept
-             )
+         all_inputs = list(inputs)
+         if self.nshared > 0:
+             all_inputs += self.inputs[-self.nshared :]

-         def trace_fn(_, pkr):
-             return pkr.new_step_size
-
-         def get_tuned_stepsize(samples, step_size):
-             return step_size[-1] * jax.numpy.std(samples[-num_compute_step_size:])
-
-         step_size = jax.tree_map(jax.numpy.ones_like, init_state)
-         for i in range(num_tuning_epoch - 1):
-             tuning_hmc = gen_kernel(step_size)
-             init_samples, tuning_result, kernel_results = tfp.mcmc.sample_chain(
-                 num_results=tune // num_tuning_epoch,
-                 current_state=init_state,
-                 kernel=tuning_hmc,
-                 trace_fn=trace_fn,
-                 return_final_kernel_results=True,
-                 seed=seed,
-             )
+         return Apply(self, all_inputs, outputs)

-             step_size = jax.tree_multimap(get_tuned_stepsize, list(init_samples), tuning_result)
-             init_state = [x[-1] for x in init_samples]
-
-         # Run inference
-         sample_kernel = gen_kernel(step_size)
-         mcmc_samples, leapfrog_num = tfp.mcmc.sample_chain(
-             num_results=draws,
-             num_burnin_steps=tune // num_tuning_epoch,
-             current_state=init_state,
-             kernel=sample_kernel,
-             trace_fn=lambda _, pkr: pkr.inner_results.leapfrogs_taken,
-             seed=seed,
-         )
+     def do_constant_folding(self, *args):
+         return False

-         return mcmc_samples, leapfrog_num
+     def perform(self, node, inputs, outputs):
+         raise NotImplementedError()

-     print("Compiling...")
-     tic2 = pd.Timestamp.now()
-     map_seed = jax.random.split(seed, chains)
-     mcmc_samples, leapfrog_num = _sample(init_state_batched, map_seed)
-
-     # map_seed = jax.random.split(seed, chains)
-     # mcmc_samples = _sample(init_state_batched, map_seed)
-     # tic4 = pd.Timestamp.now()
-     # print("Sampling time = ", tic4 - tic3)
-
-     posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}

-     az_trace = az.from_dict(posterior=posterior)
-     tic3 = pd.Timestamp.now()
-     print("Compilation + sampling time = ", tic3 - tic2)
-     return az_trace  # , leapfrog_num, tic3 - tic2
-
-
- def sample_numpyro_nuts(
-     draws=1000,
-     tune=1000,
-     chains=4,
-     target_accept=0.8,
-     random_seed=10,
-     model=None,
-     progress_bar=True,
-     keep_untransformed=False,
- ):
+ @jax_funcify.register(NumPyroNUTS)
+ def jax_funcify_NumPyroNUTS(op, node, **kwargs):
    from numpyro.infer import MCMC, NUTS

-     from pymc3 import modelcontext
+     draws = op.draws
+     tune = op.tune
+     chains = op.chains
+     target_accept = op.target_accept
+     progress_bar = op.progress_bar
+     seed = op.seed
+
+     # Compile the "inner" log-likelihood function. This will have extra shared
+     # variable inputs as the last arguments
+     logp_fn = jax_funcify(op.fgraph, **kwargs)
+
+     if isinstance(logp_fn, (list, tuple)):
+         # This handles the new JAX backend, which always returns a tuple
+         logp_fn = logp_fn[0]
+
+     def _sample(*inputs):
+
+         if op.nshared > 0:
+             current_state = inputs[: -op.nshared]
+             shared_inputs = tuple(op.fgraph.inputs[-op.nshared :])
+         else:
+             current_state = inputs
+             shared_inputs = ()
+
+         def log_fn_wrap(x):
+             res = logp_fn(
+                 *(
+                     x
+                     # We manually obtain the shared values and add them
+                     # as arguments to our compiled "inner" function
+                     + tuple(
+                         v.get_value(borrow=True, return_internal_type=True) for v in shared_inputs
+                     )
+                 )
+             )

-     model = modelcontext(model)
+             if isinstance(res, (list, tuple)):
+                 # This handles the new JAX backend, which always returns a tuple
+                 res = res[0]

-     seed = jax.random.PRNGKey(random_seed)
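+             # NumPyro's potential_fn convention is the negative log-probability,
+             # hence the sign flip on the model's logp value below.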
+             return -res

-     fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, [model.logpt])
-     fns = jax_funcify(fgraph)
-     logp_fn_jax = fns[0]
-
-     rv_names = [rv.name for rv in model.free_RVs]
-     init_state = [model.initial_point[rv_name] for rv_name in rv_names]
-     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
-
-     @jax.jit
-     def _sample(current_state, seed):
-         step_size = jax.tree_map(jax.numpy.ones_like, init_state)
        nuts_kernel = NUTS(
-             potential_fn=lambda x: -logp_fn_jax(*x),
-             # model=model,
+             potential_fn=log_fn_wrap,
            target_accept_prob=target_accept,
            adapt_step_size=True,
            adapt_mass_matrix=True,
@@ -166,60 +149,87 @@ def _sample(current_state, seed):
        pmap_numpyro.run(seed, init_params=current_state, extra_fields=("num_steps",))
        samples = pmap_numpyro.get_samples(group_by_chain=True)
        leapfrogs_taken = pmap_numpyro.get_extra_fields(group_by_chain=True)["num_steps"]
-         return samples, leapfrogs_taken
-
-     print("Compiling...")
-     tic2 = pd.Timestamp.now()
-     map_seed = jax.random.split(seed, chains)
-     mcmc_samples, leapfrogs_taken = _sample(init_state_batched, map_seed)
-     # map_seed = jax.random.split(seed, chains)
-     # mcmc_samples = _sample(init_state_batched, map_seed)
-     # tic4 = pd.Timestamp.now()
-     # print("Sampling time = ", tic4 - tic3)
+         return tuple(samples) + (leapfrogs_taken,)

-     posterior = {k: v for k, v in zip(rv_names, mcmc_samples)}
-     tic3 = pd.Timestamp.now()
-     posterior = _transform_samples(posterior, model, keep_untransformed=keep_untransformed)
-     tic4 = pd.Timestamp.now()
+     return _sample

-     az_trace = az.from_dict(posterior=posterior)
-     print("Compilation + sampling time = ", tic3 - tic2)
-     print("Transformation time = ", tic4 - tic3)

-     return az_trace  # , leapfrogs_taken, tic3 - tic2
+ def sample_numpyro_nuts(
+     draws=1000,
+     tune=1000,
+     chains=4,
+     target_accept=0.8,
+     random_seed=10,
+     model=None,
+     progress_bar=True,
+     keep_untransformed=False,
+ ):
+     model = modelcontext(model)

+     seed = jax.random.PRNGKey(random_seed)

- def _transform_samples(samples, model, keep_untransformed=False):
+     rv_names = [rv.name for rv in model.value_vars]
+     init_state = [model.initial_point[rv_name] for rv_name in rv_names]
+     init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
+     init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]

-     # Find out which RVs we need to compute:
-     free_rv_names = {x.name for x in model.free_RVs}
-     unobserved_names = {x.name for x in model.unobserved_RVs}
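+     # Sort the graph inputs so that SharedVariables come last; NumPyroNUTS assumes
+     # that any shared inputs are the trailing arguments of its logp graph.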
+     nuts_inputs = sorted(
+         [v for v in graph_inputs([model.logpt]) if not isinstance(v, Constant)],
+         key=lambda x: isinstance(x, SharedVariable),
+     )
+     map_seed = jax.random.split(seed, chains)
+     numpyro_samples = NumPyroNUTS(
+         nuts_inputs,
+         [model.logpt],
+         target_accept=target_accept,
+         draws=draws,
+         tune=tune,
+         chains=chains,
+         seed=map_seed,
+         progress_bar=progress_bar,
+     )(*init_state_batched_at)
+
+     # Un-transform the transformed variables in JAX
+     sample_outputs = []
+     for i, (value_var, rv_samples) in enumerate(zip(model.value_vars, numpyro_samples[:-1])):
+         rv = model.values_to_rvs[value_var]
+         transform = getattr(value_var.tag, "transform", None)
+         if transform is not None:
+             untrans_value_var = transform.backward(rv, rv_samples)
+             untrans_value_var.name = rv.name
+             sample_outputs.append(untrans_value_var)
+
+             if keep_untransformed:
+                 rv_samples.name = value_var.name
+                 sample_outputs.append(rv_samples)
+         else:
+             rv_samples.name = rv.name
+             sample_outputs.append(rv_samples)

-     names_to_compute = unobserved_names - free_rv_names
-     ops_to_compute = [x for x in model.unobserved_RVs if x.name in names_to_compute]
+     print("Compiling...")

-     # Create function graph for these:
-     fgraph = aesara.graph.fg.FunctionGraph(model.free_RVs, ops_to_compute)
+     tic1 = pd.Timestamp.now()
+     _sample = aesara.function(
+         [],
+         sample_outputs + [numpyro_samples[-1]],
+         allow_input_downcast=True,
+         on_unused_input="ignore",
+         accept_inplace=True,
+         mode="JAX",
+     )
+     tic2 = pd.Timestamp.now()

-     # Jaxify, which returns a list of functions, one for each op
-     jax_fns = jax_funcify(fgraph)
+     print("Compilation time = ", tic2 - tic1)

-     # Put together the inputs
-     inputs = [samples[x.name] for x in model.free_RVs]
+     print("Sampling...")

-     for cur_op, cur_jax_fn in zip(ops_to_compute, jax_fns):
+     *mcmc_samples, leapfrogs_taken = _sample()
+     tic3 = pd.Timestamp.now()

-         # We need a function taking a single argument to run vmap, while the
-         # jax_fn takes a list, so:
-         result = jax.vmap(jax.vmap(cur_jax_fn))(*inputs)
+     print("Sampling time = ", tic3 - tic2)

-         # Add to sample dict
-         samples[cur_op.name] = result
+     posterior = {k.name: v for k, v in zip(sample_outputs, mcmc_samples)}

-     # Discard unwanted transformed variables, if desired:
-     vars_to_keep = set(
-         pm.util.get_default_varnames(list(samples.keys()), include_transformed=keep_untransformed)
-     )
-     samples = {x: y for x, y in samples.items() if x in vars_to_keep}
+     az_trace = az.from_dict(posterior=posterior)

-     return samples
+     return az_trace
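
A minimal usage sketch of the new entry point (illustrative only, not part of the diff;
the toy model and variable names below are made up):

    import numpy as np
    import pymc3 as pm
    import pymc3.sampling_jax

    # Hypothetical observed data
    y = np.random.randn(100) + 2.0

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 10.0)
        sigma = pm.HalfNormal("sigma", 5.0)
        pm.Normal("obs", mu=mu, sigma=sigma, observed=y)

        # Runs NUTS via NumPyro/JAX and returns an ArviZ trace built with az.from_dict
        trace = pymc3.sampling_jax.sample_numpyro_nuts(draws=500, tune=500, chains=2)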