@@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
105
105
assert_size_argument_jax_compatible (node )
106
106
107
107
def sample_fn (rng , size , * parameters ):
108
- return jax_sample_fn (op , node = node )(rng , size , out_dtype , * parameters )
108
+ rng_key = rng ["jax_state" ]
109
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
110
+ rng ["jax_state" ] = rng_key
111
+ sample = jax_sample_fn (op , node = node )(
112
+ sampling_key , size , out_dtype , * parameters
113
+ )
114
+ return (rng , sample )
109
115
110
116
else :
111
117
112
118
def sample_fn (rng , size , * parameters ):
113
- return jax_sample_fn (op , node = node )(
114
- rng , static_size , out_dtype , * parameters
119
+ rng_key = rng ["jax_state" ]
120
+ rng_key , sampling_key = jax .random .split (rng_key , 2 )
121
+ rng ["jax_state" ] = rng_key
122
+ sample = jax_sample_fn (op , node = node )(
123
+ sampling_key , static_size , out_dtype , * parameters
115
124
)
125
+ return (rng , sample )
116
126
117
127
return sample_fn
118
128
@@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node):
133
143
name = op .name
134
144
jax_op = getattr (jax .random , name )
135
145
136
- def sample_fn (rng , size , dtype , * parameters ):
137
- rng_key = rng ["jax_state" ]
138
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
139
- sample = jax_op (sampling_key , * parameters , shape = size , dtype = dtype )
140
- rng ["jax_state" ] = rng_key
141
- return (rng , sample )
146
+ def sample_fn (rng_key , size , dtype , * parameters ):
147
+ sample = jax_op (rng_key , * parameters , shape = size , dtype = dtype )
148
+ return sample
142
149
143
150
return sample_fn
144
151
@@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node):
159
166
name = op .name
160
167
jax_op = getattr (jax .random , name )
161
168
162
- def sample_fn (rng , size , dtype , * parameters ):
163
- rng_key = rng ["jax_state" ]
164
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
169
+ def sample_fn (rng_key , size , dtype , * parameters ):
165
170
loc , scale = parameters
166
171
if size is None :
167
172
size = jax .numpy .broadcast_arrays (loc , scale )[0 ].shape
168
- sample = loc + jax_op (sampling_key , size , dtype ) * scale
169
- rng ["jax_state" ] = rng_key
170
- return (rng , sample )
173
+ sample = loc + jax_op (rng_key , size , dtype ) * scale
174
+ return sample
171
175
172
176
return sample_fn
173
177
174
178
175
179
@jax_sample_fn .register (ptr .MvNormalRV )
176
180
def jax_sample_mvnormal (op , node ):
177
- def sample_fn (rng , size , dtype , mean , cov ):
178
- rng_key = rng ["jax_state" ]
179
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
181
+ def sample_fn (rng_key , size , dtype , mean , cov ):
180
182
sample = jax .random .multivariate_normal (
181
- sampling_key , mean , cov , shape = size , dtype = dtype , method = op .method
183
+ rng_key , mean , cov , shape = size , dtype = dtype , method = op .method
182
184
)
183
- rng ["jax_state" ] = rng_key
184
- return (rng , sample )
185
+ return sample
185
186
186
187
return sample_fn
187
188
@@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node):
191
192
"""JAX implementation of `BernoulliRV`."""
192
193
193
194
# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
194
- def sample_fn (rng , size , dtype , p ):
195
- rng_key = rng ["jax_state" ]
196
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
197
- sample = jax .random .bernoulli (sampling_key , p , shape = size )
198
- rng ["jax_state" ] = rng_key
199
- return (rng , sample )
195
+ def sample_fn (rng_key , size , dtype , p ):
196
+ sample = jax .random .bernoulli (rng_key , p , shape = size )
197
+ return sample
200
198
201
199
return sample_fn
202
200
@@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node):
206
204
"""JAX implementation of `CategoricalRV`."""
207
205
208
206
# We need a separate dispatch because Categorical expects logits in JAX
209
- def sample_fn (rng , size , dtype , p ):
210
- rng_key = rng ["jax_state" ]
211
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
212
-
207
+ def sample_fn (rng_key , size , dtype , p ):
213
208
logits = jax .scipy .special .logit (p )
214
- sample = jax .random .categorical (sampling_key , logits = logits , shape = size )
215
- rng ["jax_state" ] = rng_key
216
- return (rng , sample )
209
+ sample = jax .random .categorical (rng_key , logits = logits , shape = size )
210
+ return sample
217
211
218
212
return sample_fn
219
213
@@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node):
233
227
name = "randint"
234
228
jax_op = getattr (jax .random , name )
235
229
236
- def sample_fn (rng , size , dtype , * parameters ):
237
- rng_key = rng ["jax_state" ]
238
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
230
+ def sample_fn (rng_key , size , dtype , * parameters ):
239
231
minval , maxval = parameters
240
- sample = jax_op (
241
- sampling_key , shape = size , dtype = dtype , minval = minval , maxval = maxval
242
- )
243
- rng ["jax_state" ] = rng_key
244
- return (rng , sample )
232
+ sample = jax_op (rng_key , shape = size , dtype = dtype , minval = minval , maxval = maxval )
233
+ return sample
245
234
246
235
return sample_fn
247
236
@@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node):
258
247
name = op .name
259
248
jax_op = getattr (jax .random , name )
260
249
261
- def sample_fn (rng , size , dtype , shape , scale ):
262
- rng_key = rng ["jax_state" ]
263
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
250
+ def sample_fn (rng_key , size , dtype , shape , scale ):
264
251
if size is None :
265
252
size = jax .numpy .broadcast_arrays (shape , scale )[0 ].shape
266
- sample = jax_op (sampling_key , shape , size , dtype ) * scale
267
- rng ["jax_state" ] = rng_key
268
- return (rng , sample )
253
+ sample = jax_op (rng_key , shape , size , dtype ) * scale
254
+ return sample
269
255
270
256
return sample_fn
271
257
@@ -274,14 +260,11 @@ def sample_fn(rng, size, dtype, shape, scale):
274
260
def jax_sample_fn_exponential (op , node ):
275
261
"""JAX implementation of `ExponentialRV`."""
276
262
277
- def sample_fn (rng , size , dtype , scale ):
278
- rng_key = rng ["jax_state" ]
279
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
263
+ def sample_fn (rng_key , size , dtype , scale ):
280
264
if size is None :
281
265
size = jax .numpy .asarray (scale ).shape
282
- sample = jax .random .exponential (sampling_key , size , dtype ) * scale
283
- rng ["jax_state" ] = rng_key
284
- return (rng , sample )
266
+ sample = jax .random .exponential (rng_key , size , dtype ) * scale
267
+ return sample
285
268
286
269
return sample_fn
287
270
@@ -290,14 +273,11 @@ def sample_fn(rng, size, dtype, scale):
290
273
def jax_sample_fn_t (op , node ):
291
274
"""JAX implementation of `StudentTRV`."""
292
275
293
- def sample_fn (rng , size , dtype , df , loc , scale ):
294
- rng_key = rng ["jax_state" ]
295
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
276
+ def sample_fn (rng_key , size , dtype , df , loc , scale ):
296
277
if size is None :
297
278
size = jax .numpy .broadcast_arrays (df , loc , scale )[0 ].shape
298
- sample = loc + jax .random .t (sampling_key , df , size , dtype ) * scale
299
- rng ["jax_state" ] = rng_key
300
- return (rng , sample )
279
+ sample = loc + jax .random .t (rng_key , df , size , dtype ) * scale
280
+ return sample
301
281
302
282
return sample_fn
303
283
@@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
315
295
"A default JAX rewrite should have materialized the implicit arange"
316
296
)
317
297
318
- def sample_fn (rng , size , dtype , * parameters ):
319
- rng_key = rng ["jax_state" ]
320
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
321
-
298
+ def sample_fn (rng_key , size , dtype , * parameters ):
322
299
if op .has_p_param :
323
300
a , p , core_shape = parameters
324
301
else :
@@ -327,9 +304,7 @@ def sample_fn(rng, size, dtype, *parameters):
327
304
core_shape = tuple (np .asarray (core_shape )[(0 ,) * batch_ndim ])
328
305
329
306
if batch_ndim == 0 :
330
- sample = jax .random .choice (
331
- sampling_key , a , shape = core_shape , replace = False , p = p
332
- )
307
+ sample = jax .random .choice (rng_key , a , shape = core_shape , replace = False , p = p )
333
308
334
309
else :
335
310
if size is None :
@@ -345,7 +320,7 @@ def sample_fn(rng, size, dtype, *parameters):
345
320
if p is not None :
346
321
p = jax .numpy .broadcast_to (p , size + p .shape [batch_ndim :])
347
322
348
- batch_sampling_keys = jax .random .split (sampling_key , np .prod (size ))
323
+ batch_sampling_keys = jax .random .split (rng_key , np .prod (size ))
349
324
350
325
# Ravel the batch dimensions because vmap only works along a single axis
351
326
raveled_batch_a = a .reshape ((- 1 ,) + a .shape [batch_ndim :])
@@ -366,8 +341,7 @@ def sample_fn(rng, size, dtype, *parameters):
366
341
# Reshape the batch dimensions
367
342
sample = raveled_sample .reshape (size + raveled_sample .shape [1 :])
368
343
369
- rng ["jax_state" ] = rng_key
370
- return (rng , sample )
344
+ return sample
371
345
372
346
return sample_fn
373
347
@@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node):
378
352
379
353
batch_ndim = op .batch_ndim (node )
380
354
381
- def sample_fn (rng , size , dtype , * parameters ):
382
- rng_key = rng ["jax_state" ]
383
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
355
+ def sample_fn (rng_key , size , dtype , * parameters ):
384
356
(x ,) = parameters
385
357
if batch_ndim :
386
358
# jax.random.permutation has no concept of batch dims
@@ -389,17 +361,16 @@ def sample_fn(rng, size, dtype, *parameters):
389
361
else :
390
362
x = jax .numpy .broadcast_to (x , size + x .shape [batch_ndim :])
391
363
392
- batch_sampling_keys = jax .random .split (sampling_key , np .prod (size ))
364
+ batch_sampling_keys = jax .random .split (rng_key , np .prod (size ))
393
365
raveled_batch_x = x .reshape ((- 1 ,) + x .shape [batch_ndim :])
394
366
raveled_sample = jax .vmap (lambda key , x : jax .random .permutation (key , x ))(
395
367
batch_sampling_keys , raveled_batch_x
396
368
)
397
369
sample = raveled_sample .reshape (size + raveled_sample .shape [1 :])
398
370
else :
399
- sample = jax .random .permutation (sampling_key , x )
371
+ sample = jax .random .permutation (rng_key , x )
400
372
401
- rng ["jax_state" ] = rng_key
402
- return (rng , sample )
373
+ return sample
403
374
404
375
return sample_fn
405
376
@@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node):
414
385
415
386
from numpyro .distributions .util import binomial
416
387
417
- def sample_fn (rng , size , dtype , n , p ):
418
- rng_key = rng ["jax_state" ]
419
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
420
-
421
- sample = binomial (key = sampling_key , n = n , p = p , shape = size )
422
-
423
- rng ["jax_state" ] = rng_key
424
-
425
- return (rng , sample )
388
+ def sample_fn (rng_key , size , dtype , n , p ):
389
+ sample = binomial (key = rng_key , n = n , p = p , shape = size )
390
+ return sample
426
391
427
392
return sample_fn
428
393
@@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node):
437
402
438
403
from numpyro .distributions .util import multinomial
439
404
440
- def sample_fn (rng , size , dtype , n , p ):
441
- rng_key = rng ["jax_state" ]
442
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
443
-
444
- sample = multinomial (key = sampling_key , n = n , p = p , shape = size )
445
-
446
- rng ["jax_state" ] = rng_key
447
-
448
- return (rng , sample )
405
+ def sample_fn (rng_key , size , dtype , n , p ):
406
+ sample = multinomial (key = rng_key , n = n , p = p , shape = size )
407
+ return sample
449
408
450
409
return sample_fn
451
410
@@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node):
460
419
461
420
from numpyro .distributions .util import von_mises_centered
462
421
463
- def sample_fn (rng , size , dtype , mu , kappa ):
464
- rng_key = rng ["jax_state" ]
465
- rng_key , sampling_key = jax .random .split (rng_key , 2 )
466
-
422
+ def sample_fn (rng_key , size , dtype , mu , kappa ):
467
423
sample = von_mises_centered (
468
- key = sampling_key , concentration = kappa , shape = size , dtype = dtype
424
+ key = rng_key , concentration = kappa , shape = size , dtype = dtype
469
425
)
470
426
sample = (sample + mu + np .pi ) % (2.0 * np .pi ) - np .pi
471
427
472
- rng ["jax_state" ] = rng_key
473
-
474
- return (rng , sample )
428
+ return sample
475
429
476
430
return sample_fn
0 commit comments