
Commit 14f89f8

Merge branch 'pymc-devs:main' into svd_graph_rewrite

2 parents: b999d68 + 8c157a2

File tree: 10 files changed, +729 -160 lines

pytensor/link/jax/dispatch/random.py (104 additions, 23 deletions)
```diff
@@ -93,8 +93,8 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     out_dtype = rv.type.dtype
     out_size = rv.type.shape
 
-    if op.ndim_supp > 0:
-        out_size = node.outputs[1].type.shape[: -op.ndim_supp]
+    batch_ndim = op.batch_ndim(node)
+    out_size = node.default_output().type.shape[:batch_ndim]
 
     # If one dimension has unknown size, either the size is determined
     # by a `Shape` operator in which case JAX will compile, or it is
```
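The rewritten branch reads the batch shape straight off the output variable: `op.batch_ndim(node)` counts the leading batch dimensions (output ndim minus `ndim_supp`), so slicing the static shape of `node.default_output()` up to that count covers univariate and multivariate RVs alike, instead of special-casing `ndim_supp > 0`. A minimal sketch of the relationship, assuming the PyTensor version from this commit (the Dirichlet shapes are only illustrative):

```python
import pytensor.tensor as pt

# Dirichlet has ndim_supp == 1: a (5, 3) alpha parameter yields draws with
# batch shape (5,) and core (support) shape (3,).
rv = pt.random.dirichlet(pt.ones((5, 3)))
node = rv.owner
op = node.op

batch_ndim = op.batch_ndim(node)  # output ndim (2) minus ndim_supp (1) -> 1
out_size = node.default_output().type.shape[:batch_ndim]  # static batch shape: (5,)
```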
```diff
@@ -106,18 +106,18 @@ def sample_fn(rng, size, dtype, *parameters):
             # PyTensor uses empty size to represent size = None
             if jax.numpy.asarray(size).shape == (0,):
                 size = None
-            return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
 
     else:
 
         def sample_fn(rng, size, dtype, *parameters):
-            return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
 
     return sample_fn
 
 
 @singledispatch
-def jax_sample_fn(op):
+def jax_sample_fn(op, node):
     name = op.name
     raise NotImplementedError(
         f"No JAX implementation for the given distribution: {name}"
```
```diff
@@ -128,7 +128,7 @@ def jax_sample_fn(op):
 @jax_sample_fn.register(ptr.DirichletRV)
 @jax_sample_fn.register(ptr.PoissonRV)
 @jax_sample_fn.register(ptr.MvNormalRV)
-def jax_sample_fn_generic(op):
+def jax_sample_fn_generic(op, node):
     """Generic JAX implementation of random variables."""
     name = op.name
     jax_op = getattr(jax.random, name)
@@ -149,7 +149,7 @@ def sample_fn(rng, size, dtype, *parameters):
 @jax_sample_fn.register(ptr.LogisticRV)
 @jax_sample_fn.register(ptr.NormalRV)
 @jax_sample_fn.register(ptr.StandardNormalRV)
-def jax_sample_fn_loc_scale(op):
+def jax_sample_fn_loc_scale(op, node):
     """JAX implementation of random variables in the loc-scale families.
 
     JAX only implements the standard version of random variables in the
@@ -174,7 +174,7 @@ def sample_fn(rng, size, dtype, *parameters):
 
 
 @jax_sample_fn.register(ptr.BernoulliRV)
-def jax_sample_fn_bernoulli(op):
+def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""
 
     # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
@@ -189,7 +189,7 @@ def sample_fn(rng, size, dtype, p):
 
 
 @jax_sample_fn.register(ptr.CategoricalRV)
-def jax_sample_fn_categorical(op):
+def jax_sample_fn_categorical(op, node):
     """JAX implementation of `CategoricalRV`."""
 
     # We need a separate dispatch because Categorical expects logits in JAX
@@ -208,7 +208,7 @@ def sample_fn(rng, size, dtype, p):
 @jax_sample_fn.register(ptr.RandIntRV)
 @jax_sample_fn.register(ptr.IntegersRV)
 @jax_sample_fn.register(ptr.UniformRV)
-def jax_sample_fn_uniform(op):
+def jax_sample_fn_uniform(op, node):
     """JAX implementation of random variables with uniform density.
 
     We need to pass the arguments as keyword arguments since the order
@@ -236,7 +236,7 @@ def sample_fn(rng, size, dtype, *parameters):
 
 @jax_sample_fn.register(ptr.ParetoRV)
 @jax_sample_fn.register(ptr.GammaRV)
-def jax_sample_fn_shape_scale(op):
+def jax_sample_fn_shape_scale(op, node):
     """JAX implementation of random variables in the shape-scale family.
 
     JAX only implements the standard version of random variables in the
@@ -259,7 +259,7 @@ def sample_fn(rng, size, dtype, shape, scale):
 
 
 @jax_sample_fn.register(ptr.ExponentialRV)
-def jax_sample_fn_exponential(op):
+def jax_sample_fn_exponential(op, node):
     """JAX implementation of `ExponentialRV`."""
 
     def sample_fn(rng, size, dtype, scale):
@@ -275,7 +275,7 @@ def sample_fn(rng, size, dtype, scale):
 
 
 @jax_sample_fn.register(ptr.StudentTRV)
-def jax_sample_fn_t(op):
+def jax_sample_fn_t(op, node):
     """JAX implementation of `StudentTRV`."""
 
     def sample_fn(rng, size, dtype, df, loc, scale):
```
```diff
@@ -290,38 +290,119 @@ def sample_fn(rng, size, dtype, df, loc, scale):
     return sample_fn
 
 
-@jax_sample_fn.register(ptr.ChoiceRV)
-def jax_funcify_choice(op):
+@jax_sample_fn.register(ptr.ChoiceWithoutReplacement)
+def jax_funcify_choice(op, node):
     """JAX implementation of `ChoiceRV`."""
 
+    batch_ndim = op.batch_ndim(node)
+    a, *p, core_shape = node.inputs[3:]
+    a_core_ndim, *p_core_ndim, _ = op.ndims_params
+
+    if batch_ndim and a_core_ndim == 0:
+        raise NotImplementedError(
+            "Batch dimensions are not supported for 0d arrays. "
+            "A default JAX rewrite should have materialized the implicit arange"
+        )
+
+    a_batch_ndim = a.type.ndim - a_core_ndim
+    if op.has_p_param:
+        [p] = p
+        [p_core_ndim] = p_core_ndim
+        p_batch_ndim = p.type.ndim - p_core_ndim
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
-        (a, p, replace) = parameters
-        smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
+
+        if op.has_p_param:
+            a, p, core_shape = parameters
+        else:
+            a, core_shape = parameters
+            p = None
+        core_shape = tuple(np.asarray(core_shape))
+
+        if batch_ndim == 0:
+            sample = jax.random.choice(
+                sampling_key, a, shape=core_shape, replace=False, p=p
+            )
+
+        else:
+            if size is None:
+                if p is None:
+                    size = a.shape[:a_batch_ndim]
+                else:
+                    size = jax.numpy.broadcast_shapes(
+                        a.shape[:a_batch_ndim],
+                        p.shape[:p_batch_ndim],
+                    )
+
+            a = jax.numpy.broadcast_to(a, size + a.shape[a_batch_ndim:])
+            if p is not None:
+                p = jax.numpy.broadcast_to(p, size + p.shape[p_batch_ndim:])
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+
+            # Ravel the batch dimensions because vmap only works along a single axis
+            raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
+            if p is None:
+                raveled_sample = jax.vmap(
+                    lambda key, a: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=None
+                    )
+                )(batch_sampling_keys, raveled_batch_a)
+            else:
+                raveled_batch_p = p.reshape((-1,) + p.shape[batch_ndim:])
+                raveled_sample = jax.vmap(
+                    lambda key, a, p: jax.random.choice(
+                        key, a, shape=core_shape, replace=False, p=p
+                    )
+                )(batch_sampling_keys, raveled_batch_a, raveled_batch_p)
+
+            # Reshape the batch dimensions
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+
         rng["jax_state"] = rng_key
-        return (rng, smpl_value)
+        return (rng, sample)
 
     return sample_fn
 
 
```
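`jax.random.choice` has no notion of batch dimensions, so the new implementation broadcasts the batched inputs, draws one subkey per batch element, runs the core draw under `jax.vmap` over a single raveled batch axis, and reshapes back. A self-contained sketch of that pattern, with illustrative shapes that are not taken from the diff:

```python
import jax
import jax.numpy as jnp
import numpy as np

key = jax.random.PRNGKey(0)
a = jnp.arange(4 * 10).reshape(4, 10)  # batch shape (4,), core shape (10,)
core_shape = (3,)                      # draw 3 items without replacement per batch element
batch_shape = a.shape[:1]

# One sampling key per batch element; the core draw is vmapped over the
# raveled batch axis and the batch shape is restored afterwards.
keys = jax.random.split(key, np.prod(batch_shape))
raveled_a = a.reshape((-1,) + a.shape[1:])
raveled_draws = jax.vmap(
    lambda k, row: jax.random.choice(k, row, shape=core_shape, replace=False)
)(keys, raveled_a)
draws = raveled_draws.reshape(batch_shape + raveled_draws.shape[1:])  # shape (4, 3)
```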
```diff
 @jax_sample_fn.register(ptr.PermutationRV)
-def jax_sample_fn_permutation(op):
+def jax_sample_fn_permutation(op, node):
     """JAX implementation of `PermutationRV`."""
 
+    batch_ndim = op.batch_ndim(node)
+    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
+
     def sample_fn(rng, size, dtype, *parameters):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)
         (x,) = parameters
-        sample = jax.random.permutation(sampling_key, x)
+        if batch_ndim:
+            # jax.random.permutation has no concept of batch dims
+            x_core_shape = x.shape[x_batch_ndim:]
+            if size is None:
+                size = x.shape[:x_batch_ndim]
+            else:
+                x = jax.numpy.broadcast_to(x, size + x_core_shape)
+
+            batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
+            raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
+            raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
+                batch_sampling_keys, raveled_batch_x
+            )
+            sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
+        else:
+            sample = jax.random.permutation(sampling_key, x)
+
         rng["jax_state"] = rng_key
         return (rng, sample)
 
     return sample_fn
 
 
```
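The same split-and-vmap trick covers `jax.random.permutation`, which likewise only shuffles a single array. An illustrative sketch (shapes are assumptions, not from the diff):

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x = jnp.arange(3 * 5).reshape(3, 5)                # batch shape (3,), core shape (5,)

keys = jax.random.split(key, x.shape[0])           # one key per batch element
perms = jax.vmap(jax.random.permutation)(keys, x)  # each row shuffled independently -> (3, 5)
```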
```diff
 @jax_sample_fn.register(ptr.BinomialRV)
-def jax_sample_fn_binomial(op):
+def jax_sample_fn_binomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -344,7 +425,7 @@ def sample_fn(rng, size, dtype, n, p):
 
 
 @jax_sample_fn.register(ptr.MultinomialRV)
-def jax_sample_fn_multinomial(op):
+def jax_sample_fn_multinomial(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
@@ -367,7 +448,7 @@ def sample_fn(rng, size, dtype, n, p):
 
 
 @jax_sample_fn.register(ptr.VonMisesRV)
-def jax_sample_fn_vonmises(op):
+def jax_sample_fn_vonmises(op, node):
     if not numpyro_available:
         raise NotImplementedError(
             f"No JAX implementation for the given distribution: {op.name}. "
```

pytensor/link/numba/dispatch/random.py (60 additions, 1 deletion)
```diff
@@ -210,7 +210,6 @@ def {sized_fn_name}({random_fn_input_names}):
 @numba_funcify.register(ptr.BinomialRV)
 @numba_funcify.register(ptr.MultinomialRV)
 @numba_funcify.register(ptr.RandIntRV)  # only the first two arguments are supported
-@numba_funcify.register(ptr.ChoiceRV)  # the `p` argument is not supported
 @numba_funcify.register(ptr.PermutationRV)
 def numba_funcify_RandomVariable(op, node, **kwargs):
     name = op.name
```
```diff
@@ -367,3 +366,63 @@ def dirichlet_rv(rng, size, dtype, alphas):
         return (rng, np.random.dirichlet(alphas, size))
 
     return dirichlet_rv
+
+
+@numba_funcify.register(ptr.ChoiceWithoutReplacement)
+def numba_funcify_choice_without_replacement(op, node, **kwargs):
+    batch_ndim = op.batch_ndim(node)
+    if batch_ndim:
+        # The code isn't too hard to write, but Numba doesn't support a with ndim > 1,
+        # and I don't want to change the batched tests for this
+        # We'll just raise an error for now
+        raise NotImplementedError(
+            "ChoiceWithoutReplacement with batch_ndim not supported in Numba backend"
+        )
+
+    [core_shape_len] = node.inputs[-1].type.shape
+
+    if op.has_p_param:
+
+        @numba_basic.numba_njit
+        def choice_without_replacement_rv(rng, size, dtype, a, p, core_shape):
+            core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
+            samples = np.random.choice(a, size=core_shape, replace=False, p=p)
+            return (rng, samples)
+    else:
+
+        @numba_basic.numba_njit
+        def choice_without_replacement_rv(rng, size, dtype, a, core_shape):
+            core_shape = numba_ndarray.to_fixed_tuple(core_shape, core_shape_len)
+            samples = np.random.choice(a, size=core_shape, replace=False)
+            return (rng, samples)
+
+    return choice_without_replacement_rv
+
+
```
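Inside the jitted function the symbolic `core_shape` vector is converted to a fixed-length tuple (Numba needs the tuple length, taken from the static type shape, at compile time) and the draw itself is delegated to NumPy. A plain-NumPy sketch of that core draw, with illustrative values:

```python
import numpy as np

a = np.arange(10)
p = np.full(10, 0.1)   # probabilities must sum to 1
core_shape = (3,)      # length known statically from node.inputs[-1].type.shape

# Un-batched draw of 3 distinct elements, weighted by p.
samples = np.random.choice(a, size=core_shape, replace=False, p=p)
print(samples.shape)   # (3,)
```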
```diff
+@numba_funcify.register(ptr.PermutationRV)
+def numba_funcify_permutation(op, node, **kwargs):
+    # PyTensor uses size=() to represent size=None
+    size_is_none = node.inputs[1].type.shape == (0,)
+    batch_ndim = op.batch_ndim(node)
+    x_batch_ndim = node.inputs[-1].type.ndim - op.ndims_params[0]
+
+    @numba_basic.numba_njit
+    def permutation_rv(rng, size, dtype, x):
+        if batch_ndim:
+            x_core_shape = x.shape[x_batch_ndim:]
+            if size_is_none:
+                size = x.shape[:batch_ndim]
+            else:
+                size = numba_ndarray.to_fixed_tuple(size, batch_ndim)
+                x = np.broadcast_to(x, size + x_core_shape)
+
+            samples = np.empty(size + x_core_shape, dtype=x.dtype)
+            for index in np.ndindex(size):
+                samples[index] = np.random.permutation(x[index])
+
+        else:
+            samples = np.random.permutation(x)
+
+        return (rng, samples)
+
+    return permutation_rv
```
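For the batched case the jitted permutation walks every batch index with `np.ndindex` and permutes each core slice with its own call to `np.random.permutation`. A Numba-free NumPy sketch of that loop (shapes are illustrative assumptions):

```python
import numpy as np

x = np.arange(4 * 5).reshape(4, 5)  # batch shape (4,), core shape (5,)
size = x.shape[:1]                  # batch dimensions to iterate over
x_core_shape = x.shape[1:]

samples = np.empty(size + x_core_shape, dtype=x.dtype)
for index in np.ndindex(size):      # (0,), (1,), (2,), (3,)
    samples[index] = np.random.permutation(x[index])
```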
