@@ -93,8 +93,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     out_dtype = rv.type.dtype
     out_size = rv.type.shape
 
-    if op.ndim_supp > 0:
-        out_size = node.outputs[1].type.shape[: -op.ndim_supp]
+    batch_ndim = op.batch_ndim(node)
+    out_size = node.default_output().type.shape[:batch_ndim]
 
     # If one dimension has unknown size, either the size is determined
     # by a `Shape` operator in which case JAX will compile, or it is
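For intuition, `op.batch_ndim(node)` is the rank of the output minus the rank of a single core draw (`ndim_supp`), so the static `out_size` is just the batch prefix of the output's type shape. A minimal sketch of that arithmetic (plain Python, values are illustrative):

```python
# Illustrative only: mirrors the out_size computation above.
ndim_supp = 1                      # e.g. each Dirichlet draw is a 1d vector
out_shape = (5, 3, 4)              # static type shape of the output draw
batch_ndim = len(out_shape) - ndim_supp
out_size = out_shape[:batch_ndim]  # the batch ("size") dimensions
assert out_size == (5, 3)
```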
@@ -106,18 +106,18 @@ def sample_fn(rng, size, dtype, *parameters):
             # PyTensor uses empty size to represent size = None
             if jax.numpy.asarray(size).shape == (0,):
                 size = None
-            return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
 
     else:
 
         def sample_fn(rng, size, dtype, *parameters):
-            return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
 
     return sample_fn
 
 
 @singledispatch
-def jax_sample_fn(op):
+def jax_sample_fn(op, node):
     name = op.name
     raise NotImplementedError(
         f"No JAX implementation for the given distribution: {name}"
@@ -128,7 +128,7 @@ def jax_sample_fn(op):
 @jax_sample_fn.register(ptr.DirichletRV)
 @jax_sample_fn.register(ptr.PoissonRV)
 @jax_sample_fn.register(ptr.MvNormalRV)
-def jax_sample_fn_generic(op):
+def jax_sample_fn_generic(op, node):
     """Generic JAX implementation of random variables."""
     name = op.name
     jax_op = getattr(jax.random, name)
@@ -149,7 +149,7 @@ def sample_fn(rng, size, dtype, *parameters):
 @jax_sample_fn.register(ptr.LogisticRV)
 @jax_sample_fn.register(ptr.NormalRV)
 @jax_sample_fn.register(ptr.StandardNormalRV)
-def jax_sample_fn_loc_scale(op):
+def jax_sample_fn_loc_scale(op, node):
     """JAX implementation of random variables in the loc-scale families.
 
     JAX only implements the standard version of random variables in the
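The loc-scale trick the docstring refers to: since JAX only ships the standard variants, a draw with location `loc` and scale `scale` is obtained by shifting and rescaling a standard draw. A minimal sketch for the normal case:

```python
import jax

key = jax.random.PRNGKey(0)
loc, scale = 2.0, 3.0

# If z ~ Normal(0, 1), then loc + scale * z ~ Normal(loc, scale).
z = jax.random.normal(key, shape=(1000,))
sample = loc + scale * z
```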
@@ -174,7 +174,7 @@ def sample_fn(rng, size, dtype, *parameters):
 @jax_sample_fn.register(ptr.BernoulliRV)
-def jax_sample_fn_bernoulli(op):
+def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""
 
     # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
@@ -189,7 +189,7 @@ def sample_fn(rng, size, dtype, p):
 @jax_sample_fn.register(ptr.CategoricalRV)
-def jax_sample_fn_categorical(op):
+def jax_sample_fn_categorical(op, node):
     """JAX implementation of `CategoricalRV`."""
 
     # We need a separate dispatch because Categorical expects logits in JAX
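As the comment says, `jax.random.categorical` consumes logits rather than probabilities, so the probabilities PyTensor carries must be logged first. A minimal sketch:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
p = jnp.array([0.2, 0.3, 0.5])

# jax.random.categorical expects (unnormalized) log-probabilities.
draws = jax.random.categorical(key, logits=jnp.log(p), shape=(10,))
```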
@@ -208,7 +208,7 @@ def sample_fn(rng, size, dtype, p):
 @jax_sample_fn.register(ptr.RandIntRV)
 @jax_sample_fn.register(ptr.IntegersRV)
 @jax_sample_fn.register(ptr.UniformRV)
-def jax_sample_fn_uniform(op):
+def jax_sample_fn_uniform(op, node):
     """JAX implementation of random variables with uniform density.
 
     We need to pass the arguments as keyword arguments since the order
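The keyword-argument caveat exists because the positional order of the bound parameters differs between these JAX samplers; passing `minval`/`maxval` by name sidesteps the mismatch. A sketch:

```python
import jax

key = jax.random.PRNGKey(0)

# Same bounds, different positional signatures; keywords keep it safe.
u = jax.random.uniform(key, shape=(4,), minval=-1.0, maxval=1.0)
i = jax.random.randint(key, shape=(4,), minval=0, maxval=10)
```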
@@ -236,7 +236,7 @@ def sample_fn(rng, size, dtype, *parameters):
 @jax_sample_fn.register(ptr.ParetoRV)
 @jax_sample_fn.register(ptr.GammaRV)
-def jax_sample_fn_shape_scale(op):
+def jax_sample_fn_shape_scale(op, node):
     """JAX implementation of random variables in the shape-scale family.
 
     JAX only implements the standard version of random variables in the
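Analogous to the loc-scale case: JAX only implements the standard (unit-scale) version, so a scaled draw is one multiplication away. A minimal sketch for the gamma case:

```python
import jax

key = jax.random.PRNGKey(0)
shape_param, scale = 2.0, 3.0

# jax.random.gamma draws Gamma(shape_param, scale=1);
# multiplying by `scale` gives Gamma(shape_param, scale).
sample = scale * jax.random.gamma(key, shape_param, shape=(1000,))
```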
@@ -259,7 +259,7 @@ def sample_fn(rng, size, dtype, shape, scale):
 @jax_sample_fn.register(ptr.ExponentialRV)
-def jax_sample_fn_exponential(op):
+def jax_sample_fn_exponential(op, node):
     """JAX implementation of `ExponentialRV`."""
 
     def sample_fn(rng, size, dtype, scale):
@@ -275,7 +275,7 @@ def sample_fn(rng, size, dtype, scale):
 @jax_sample_fn.register(ptr.StudentTRV)
-def jax_sample_fn_t(op):
+def jax_sample_fn_t(op, node):
     """JAX implementation of `StudentTRV`."""
 
     def sample_fn(rng, size, dtype, df, loc, scale):
@@ -290,38 +290,119 @@ def sample_fn(rng, size, dtype, df, loc, scale):
     return sample_fn
 
 
-@jax_sample_fn.register(ptr.ChoiceRV)
-def jax_funcify_choice(op):
+@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
+def jax_funcify_choice(op, node):
     """JAX implementation of `ChoiceRV`."""
 
+    batch_ndim = op.batch_ndim(node)
+    a, *p, core_shape = node.inputs[3:]
+    a_core_ndim, *p_core_ndim, _ = op.ndims_params
+
+    if batch_ndim and a_core_ndim == 0:
+        raise NotImplementedError(
+            "Batch dimensions are not supported for 0d arrays. "
+            "A default JAX rewrite should have materialized the implicit arange"
+        )
+
+    a_batch_ndim = a.type.ndim - a_core_ndim
+    if op.has_p_param:
+        [p] = p
+        [p_core_ndim] = p_core_ndim
+        p_batch_ndim = p.type.ndim - p_core_ndim
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
-        (a, p, replace) = parameters
-        smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
+
+        if op.has_p_param:
+            a, p, core_shape = parameters
+        else:
+            a, core_shape = parameters
+            p = None
+        core_shape = tuple(np.asarray(core_shape))
+
+        if batch_ndim == 0:
+            sample = jax.random.choice(
+                sampling_key, a, shape=core_shape, replace=False, p=p
+            )
+
+        else:
+            if size is None:
+                if p is None:
+                    size = a.shape[:a_batch_ndim]
+                else:
+                    size = jax.numpy.broadcast_shapes(
+                        a.shape[:a_batch_ndim],
+                        p.shape[:p_batch_ndim],
+                    )
+
+            a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
+            if p is not None:
+                p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+
+            # Ravel the batch dimensions because vmap only works along a single axis
+            raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
+            if p is None:
+                raveled_sample = jax.vmap(
+                    lambda key, a: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=None
+                    )
+                )(batch_sampling_keys, raveled_batch_a)
+            else:
+                raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
+                raveled_sample = jax.vmap(
+                    lambda key, a, p: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=p
+                    )
+                )(batch_sampling_keys, raveled_batch_a, raveled_batch_p)
+
+            # Reshape the batch dimensions
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+
         rng["jax_state"] = rng_key
-        return (rng, smpl_value)
+        return (rng, sample)
 
     return sample_fn
 
 
 @jax_sample_fn.register(ptr.PermutationRV)
-def jax_sample_fn_permutation(op):
+def jax_sample_fn_permutation(op, node):
     """JAX implementation of `PermutationRV`."""
 
+    batch_ndim = op.batch_ndim(node)
+    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (x,) = parameters
-        sample = jax.random.permutation(sampling_key, x)
+        if batch_ndim:
+            # jax.random.permutation has no concept of batch dims
+            x_core_shape = x.shape[x_batch_ndim:]
+            if size is None:
+                size = x.shape[:x_batch_ndim]
+            else:
+                x = jax.numpy.broadcast_to(x, size + x_core_shape)
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
+            raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
+                batch_sampling_keys, raveled_batch_x
+            )
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+        else:
+            sample = jax.random.permutation(sampling_key, x)
 
         rng["jax_state"] = rng_key
         return (rng, sample)
 
     return sample_fn
 
 
 @jax_sample_fn.register(ptr.BinomialRV)
-def jax_sample_fn_binomial(op):
+def jax_sample_fn_binomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -344,7 +425,7 @@ def sample_fn(rng, size, dtype, n, p):
 @jax_sample_fn.register(ptr.MultinomialRV)
-def jax_sample_fn_multinomial(op):
+def jax_sample_fn_multinomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -367,7 +448,7 @@ def sample_fn(rng, size, dtype, n, p):
 @jax_sample_fn.register(ptr.VonMisesRV)
-def jax_sample_fn_vonmises(op):
+def jax_sample_fn_vonmises(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "