
Commit d80c0bf

Fix Choice and Permutation not respecting the RandomVariable contract
These two RVs don't fit the traditional RandomVariable contract because they don't have a fixed concept of `batch_ndim`: the hard-coded `ndims_params` and `ndim_supp` were wrong and need to be determined for every node.

* `ChoiceRV` was removed in favor of `ChoiceWithoutReplacementRV`, which handles the cases without replacement. Cases with replacement can be trivially implemented with other basic RVs (see the sketch after this list).
* Both `Permutation` and `ChoiceWithoutReplacement` now support batch ndims.
* Avoid materializing the implicit arange.
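A plain-NumPy sketch of the last two points (illustrative only; this is not the PyTensor rewrite itself, and `a` and the shapes are made up):

import numpy as np

rng = np.random.default_rng(0)
a = np.array([10.0, 20.0, 30.0, 40.0])

# Choice *with* replacement is just indexing `a` with an integer draw
# (a weighted version would index with a categorical draw), so it needs no dedicated Op.
idx = rng.integers(0, len(a), size=(2, 3))
manual = a[idx]                                     # equal in distribution to the next line
direct = rng.choice(a, size=(2, 3), replace=True)

# The "implicit arange": choice(n, ...) means choice(np.arange(n), ...).
# Per the commit message, the new Op avoids building that arange explicitly.
print(rng.choice(5, size=3, replace=False))         # draws 3 distinct values from np.arange(5)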
1 parent e2e8757 commit d80c0bf

File tree

8 files changed: +680 -126 lines changed

pytensor/link/jax/dispatch/random.py (+104 -23)
@@ -93,8 +93,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     out_dtype = rv.type.dtype
     out_size = rv.type.shape
 
-    if op.ndim_supp > 0:
-        out_size = node.outputs[1].type.shape[: -op.ndim_supp]
+    batch_ndim = op.batch_ndim(node)
+    out_size = node.default_output().type.shape[:batch_ndim]
 
     # If one dimension has unknown size, either the size is determined
     # by a `Shape` operator in which case JAX will compile, or it is
@@ -106,18 +106,18 @@ def sample_fn(rng, size, dtype, *parameters):
             # PyTensor uses empty size to represent size = None
             if jax.numpy.asarray(size).shape == (0,):
                 size = None
-            return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
 
     else:
 
         def sample_fn(rng, size, dtype, *parameters):
-            return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
 
     return sample_fn
 
 
 @singledispatch
-def jax_sample_fn(op):
+def jax_sample_fn(op, node):
     name = op.name
     raise NotImplementedError(
         f"No JAX implementation for the given distribution: {name}"
@@ -128,7 +128,7 @@ def jax_sample_fn(op):
 @jax_sample_fn.register(ptr.DirichletRV)
 @jax_sample_fn.register(ptr.PoissonRV)
 @jax_sample_fn.register(ptr.MvNormalRV)
-def jax_sample_fn_generic(op):
+def jax_sample_fn_generic(op, node):
     """Generic JAX implementation of random variables."""
     name = op.name
     jax_op = getattr(jax.random, name)
@@ -149,7 +149,7 @@ def sample_fn(rng, size, dtype, *parameters):
 @jax_sample_fn.register(ptr.LogisticRV)
 @jax_sample_fn.register(ptr.NormalRV)
 @jax_sample_fn.register(ptr.StandardNormalRV)
-def jax_sample_fn_loc_scale(op):
+def jax_sample_fn_loc_scale(op, node):
     """JAX implementation of random variables in the loc-scale families.
 
     JAX only implements the standard version of random variables in the
@@ -174,7 +174,7 @@ def sample_fn(rng, size, dtype, *parameters):
 
 
 @jax_sample_fn.register(ptr.BernoulliRV)
-def jax_sample_fn_bernoulli(op):
+def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""
 
     # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
@@ -189,7 +189,7 @@ def sample_fn(rng, size, dtype, p):
 
 
 @jax_sample_fn.register(ptr.CategoricalRV)
-def jax_sample_fn_categorical(op):
+def jax_sample_fn_categorical(op, node):
     """JAX implementation of `CategoricalRV`."""
 
     # We need a separate dispatch because Categorical expects logits in JAX
@@ -208,7 +208,7 @@ def sample_fn(rng, size, dtype, p):
 @jax_sample_fn.register(ptr.RandIntRV)
 @jax_sample_fn.register(ptr.IntegersRV)
 @jax_sample_fn.register(ptr.UniformRV)
-def jax_sample_fn_uniform(op):
+def jax_sample_fn_uniform(op, node):
     """JAX implementation of random variables with uniform density.
 
     We need to pass the arguments as keyword arguments since the order
@@ -236,7 +236,7 @@ def sample_fn(rng, size, dtype, *parameters):
 
 @jax_sample_fn.register(ptr.ParetoRV)
 @jax_sample_fn.register(ptr.GammaRV)
-def jax_sample_fn_shape_scale(op):
+def jax_sample_fn_shape_scale(op, node):
     """JAX implementation of random variables in the shape-scale family.
 
     JAX only implements the standard version of random variables in the
@@ -259,7 +259,7 @@ def sample_fn(rng, size, dtype, shape, scale):
 
 
 @jax_sample_fn.register(ptr.ExponentialRV)
-def jax_sample_fn_exponential(op):
+def jax_sample_fn_exponential(op, node):
     """JAX implementation of `ExponentialRV`."""
 
     def sample_fn(rng, size, dtype, scale):
@@ -275,7 +275,7 @@ def sample_fn(rng, size, dtype, scale):
 
 
 @jax_sample_fn.register(ptr.StudentTRV)
-def jax_sample_fn_t(op):
+def jax_sample_fn_t(op, node):
     """JAX implementation of `StudentTRV`."""
 
     def sample_fn(rng, size, dtype, df, loc, scale):
@@ -290,38 +290,119 @@ def sample_fn(rng, size, dtype, df, loc, scale):
     return sample_fn
 
 
-@jax_sample_fn.register(ptr.ChoiceRV)
-def jax_funcify_choice(op):
+@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
+def jax_funcify_choice(op, node):
     """JAX implementation of `ChoiceRV`."""
 
+    batch_ndim = op.batch_ndim(node)
+    a, *p, core_shape = node.inputs[3:]
+    a_core_ndim, *p_core_ndim, _ = op.ndims_params
+
+    if batch_ndim and a_core_ndim == 0:
+        raise NotImplementedError(
+            "Batch dimensions are not supported for 0d arrays. "
+            "A default JAX rewrite should have materialized the implicit arange"
+        )
+
+    a_batch_ndim = a.type.ndim - a_core_ndim
+    if op.has_p_param:
+        [p] = p
+        [p_core_ndim] = p_core_ndim
+        p_batch_ndim = p.type.ndim - p_core_ndim
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
-        (a, p, replace) = parameters
-        smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
+
+        if op.has_p_param:
+            a, p, core_shape = parameters
+        else:
+            a, core_shape = parameters
+            p = None
+        core_shape = tuple(np.asarray(core_shape))
+
+        if batch_ndim == 0:
+            sample = jax.random.choice(
+                sampling_key, a, shape=core_shape, replace=False, p=p
+            )
+
+        else:
+            if size is None:
+                if p is None:
+                    size = a.shape[:a_batch_ndim]
+                else:
+                    size = jax.numpy.broadcast_shapes(
+                        a.shape[:a_batch_ndim],
+                        p.shape[:p_batch_ndim],
+                    )
+
+            a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
+            if p is not None:
+                p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+
+            # Ravel the batch dimensions because vmap only works along a single axis
+            raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
+            if p is None:
+                raveled_sample = jax.vmap(
+                    lambda key, a: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=None
+                    )
+                )(batch_sampling_keys, raveled_batch_a)
+            else:
+                raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
+                raveled_sample = jax.vmap(
+                    lambda key, a, p: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=p
+                    )
+                )(batch_sampling_keys, raveled_batch_a, raveled_batch_p)
+
+            # Reshape the batch dimensions
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+
         rng["jax_state"] = rng_key
-        return (rng, smpl_value)
+        return (rng, sample)
 
     return sample_fn
 
 
 @jax_sample_fn.register(ptr.PermutationRV)
-def jax_sample_fn_permutation(op):
+def jax_sample_fn_permutation(op, node):
     """JAX implementation of `PermutationRV`."""
 
+    batch_ndim = op.batch_ndim(node)
+    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (x,) = parameters
-        sample = jax.random.permutation(sampling_key, x)
+        if batch_ndim:
+            # jax.random.permutation has no concept of batch dims
+            x_core_shape = x.shape[x_batch_ndim:]
+            if size is None:
+                size = x.shape[:x_batch_ndim]
+            else:
+                x = jax.numpy.broadcast_to(x, size + x_core_shape)
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
+            raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
+                batch_sampling_keys, raveled_batch_x
+            )
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+        else:
+            sample = jax.random.permutation(sampling_key, x)
+
         rng["jax_state"] = rng_key
         return (rng, sample)
 
     return sample_fn
 
 
 @jax_sample_fn.register(ptr.BinomialRV)
-def jax_sample_fn_binomial(op):
+def jax_sample_fn_binomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -344,7 +425,7 @@ def sample_fn(rng, size, dtype, n, p):
 
 
 @jax_sample_fn.register(ptr.MultinomialRV)
-def jax_sample_fn_multinomial(op):
+def jax_sample_fn_multinomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -367,7 +448,7 @@ def sample_fn(rng, size, dtype, n, p):
 
 
 @jax_sample_fn.register(ptr.VonMisesRV)
-def jax_sample_fn_vonmises(op):
+def jax_sample_fn_vonmises(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "

pytensor/link/numba/dispatch/random.py (+60 -1)
@@ -210,7 +210,6 @@ def {sized_fn_name}({random_fn_input_names}):
 @numba_funcify.register(ptr.BinomialRV)
 @numba_funcify.register(ptr.MultinomialRV)
 @numba_funcify.register(ptr.RandIntRV)  # only the first two arguments are supported
-@numba_funcify.register(ptr.ChoiceRV)  # the `p` argument is not supported
 @numba_funcify.register(ptr.PermutationRV)
 def numba_funcify_RandomVariable(op, node, **kwargs):
     name = op.name
@@ -367,3 +366,63 @@ def dirichlet_rv(rng, size, dtype, alphas):
         return (rng, np.random.dirichlet(alphas, size))
 
     return dirichlet_rv
+
+
+@numba_funcify.register(ptr.ChoiceWithoutReplacement)
+def numba_funcify_choice_without_replacement(op, node, **kwargs):
+    batch_ndim = op.batch_ndim(node)
+    if batch_ndim:
+        # The code isn't too hard to write, but Numba doesn't support a with ndim > 1,
+        # and I don't want to change the batched tests for this
+        # We'll just raise an error for now
+        raise NotImplementedError(
+            "ChoiceWithoutReplacement with batch_ndim not supported in Numba backend"
+        )
+
+    [core_shape_len] = node.inputs[-1].type.shape
+
+    if op.has_p_param:
+
+        @numba_basic.numba_njit
+        def choice_without_replacement_rv(rng, size, dtype, a, p, core_shape):
+            core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
+            samples = np.random.choice(a, size=core_shape, replace=False, p=p)
+            return (rng, samples)
+    else:
+
+        @numba_basic.numba_njit
+        def choice_without_replacement_rv(rng, size, dtype, a, core_shape):
+            core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
+            samples = np.random.choice(a, size=core_shape, replace=False)
+            return (rng, samples)
+
+    return choice_without_replacement_rv
+
+
+@numba_funcify.register(ptr.PermutationRV)
+def numba_funcify_permutation(op, node, **kwargs):
+    # PyTensor uses size=() to represent size=None
+    size_is_none = node.inputs[1].type.shape == (0,)
+    batch_ndim = op.batch_ndim(node)
+    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
+
+    @numba_basic.numba_njit
+    def permutation_rv(rng, size, dtype, x):
+        if batch_ndim:
+            x_core_shape = x.shape[x_batch_ndim:]
+            if size_is_none:
+                size = x.shape[:batch_ndim]
+            else:
+                size = numba_ndarray.to_fixed_tuple(size, batch_ndim)
+                x = np.broadcast_to(x, size + x_core_shape)
+
+            samples = np.empty(size + x_core_shape, dtype=x.dtype)
+            for index in np.ndindex(size):
+                samples[index] = np.random.permutation(x[index])
+
+        else:
+            samples = np.random.permutation(x)
+
+        return (rng, samples)
+
+    return permutation_rv
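The np.ndindex loop in permutation_rv plays the same role as jax.vmap in the JAX backend: one independent permutation per batch index, written as a plain loop that Numba can compile. A NumPy-only sketch of that loop (shapes are made up; the real kernel runs under numba_njit):

import numpy as np

x = np.arange(2 * 3 * 5).reshape(2, 3, 5)   # batch shape (2, 3), core shape (5,)
size = x.shape[:2]
x_core_shape = x.shape[2:]

samples = np.empty(size + x_core_shape, dtype=x.dtype)
for index in np.ndindex(size):
    # np.ndindex(size) yields (0, 0), (0, 1), ..., (1, 2): one draw per batch entry
    samples[index] = np.random.permutation(x[index])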
