Skip to content

Commit 9ab8df5

Browse files
Etienne DuchesnericardoV94
Etienne Duchesne
authored andcommitted
Simplify dispatch of JAX random variables
1 parent 95ce102 commit 9ab8df5

File tree

2 files changed

+58
-104
lines changed

2 files changed

+58
-104
lines changed

pytensor/link/jax/dispatch/random.py

+57-103
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
105105
assert_size_argument_jax_compatible(node)
106106

107107
def sample_fn(rng, size, *parameters):
108-
return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters)
108+
rng_key = rng["jax_state"]
109+
rng_key, sampling_key = jax.random.split(rng_key, 2)
110+
rng["jax_state"] = rng_key
111+
sample = jax_sample_fn(op, node=node)(
112+
sampling_key, size, out_dtype, *parameters
113+
)
114+
return (rng, sample)
109115

110116
else:
111117

112118
def sample_fn(rng, size, *parameters):
113-
return jax_sample_fn(op, node=node)(
114-
rng, static_size, out_dtype, *parameters
119+
rng_key = rng["jax_state"]
120+
rng_key, sampling_key = jax.random.split(rng_key, 2)
121+
rng["jax_state"] = rng_key
122+
sample = jax_sample_fn(op, node=node)(
123+
sampling_key, static_size, out_dtype, *parameters
115124
)
125+
return (rng, sample)
116126

117127
return sample_fn
118128

@@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node):
133143
name = op.name
134144
jax_op = getattr(jax.random, name)
135145

136-
def sample_fn(rng, size, dtype, *parameters):
137-
rng_key = rng["jax_state"]
138-
rng_key, sampling_key = jax.random.split(rng_key, 2)
139-
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
140-
rng["jax_state"] = rng_key
141-
return (rng, sample)
146+
def sample_fn(rng_key, size, dtype, *parameters):
147+
sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
148+
return sample
142149

143150
return sample_fn
144151

@@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node):
159166
name = op.name
160167
jax_op = getattr(jax.random, name)
161168

162-
def sample_fn(rng, size, dtype, *parameters):
163-
rng_key = rng["jax_state"]
164-
rng_key, sampling_key = jax.random.split(rng_key, 2)
169+
def sample_fn(rng_key, size, dtype, *parameters):
165170
loc, scale = parameters
166171
if size is None:
167172
size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
168-
sample = loc + jax_op(sampling_key, size, dtype) * scale
169-
rng["jax_state"] = rng_key
170-
return (rng, sample)
173+
sample = loc + jax_op(rng_key, size, dtype) * scale
174+
return sample
171175

172176
return sample_fn
173177

174178

175179
@jax_sample_fn.register(ptr.MvNormalRV)
176180
def jax_sample_mvnormal(op, node):
177-
def sample_fn(rng, size, dtype, mean, cov):
178-
rng_key = rng["jax_state"]
179-
rng_key, sampling_key = jax.random.split(rng_key, 2)
181+
def sample_fn(rng_key, size, dtype, mean, cov):
180182
sample = jax.random.multivariate_normal(
181-
sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
183+
rng_key, mean, cov, shape=size, dtype=dtype, method=op.method
182184
)
183-
rng["jax_state"] = rng_key
184-
return (rng, sample)
185+
return sample
185186

186187
return sample_fn
187188

@@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node):
191192
"""JAX implementation of `BernoulliRV`."""
192193

193194
# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
194-
def sample_fn(rng, size, dtype, p):
195-
rng_key = rng["jax_state"]
196-
rng_key, sampling_key = jax.random.split(rng_key, 2)
197-
sample = jax.random.bernoulli(sampling_key, p, shape=size)
198-
rng["jax_state"] = rng_key
199-
return (rng, sample)
195+
def sample_fn(rng_key, size, dtype, p):
196+
sample = jax.random.bernoulli(rng_key, p, shape=size)
197+
return sample
200198

201199
return sample_fn
202200

@@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node):
206204
"""JAX implementation of `CategoricalRV`."""
207205

208206
# We need a separate dispatch because Categorical expects logits in JAX
209-
def sample_fn(rng, size, dtype, p):
210-
rng_key = rng["jax_state"]
211-
rng_key, sampling_key = jax.random.split(rng_key, 2)
212-
207+
def sample_fn(rng_key, size, dtype, p):
213208
logits = jax.scipy.special.logit(p)
214-
sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
215-
rng["jax_state"] = rng_key
216-
return (rng, sample)
209+
sample = jax.random.categorical(rng_key, logits=logits, shape=size)
210+
return sample
217211

218212
return sample_fn
219213

@@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node):
233227
name = "randint"
234228
jax_op = getattr(jax.random, name)
235229

236-
def sample_fn(rng, size, dtype, *parameters):
237-
rng_key = rng["jax_state"]
238-
rng_key, sampling_key = jax.random.split(rng_key, 2)
230+
def sample_fn(rng_key, size, dtype, *parameters):
239231
minval, maxval = parameters
240-
sample = jax_op(
241-
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
242-
)
243-
rng["jax_state"] = rng_key
244-
return (rng, sample)
232+
sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
233+
return sample
245234

246235
return sample_fn
247236

@@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node):
258247
name = op.name
259248
jax_op = getattr(jax.random, name)
260249

261-
def sample_fn(rng, size, dtype, shape, scale):
262-
rng_key = rng["jax_state"]
263-
rng_key, sampling_key = jax.random.split(rng_key, 2)
250+
def sample_fn(rng_key, size, dtype, shape, scale):
264251
if size is None:
265252
size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
266-
sample = jax_op(sampling_key, shape, size, dtype) * scale
267-
rng["jax_state"] = rng_key
268-
return (rng, sample)
253+
sample = jax_op(rng_key, shape, size, dtype) * scale
254+
return sample
269255

270256
return sample_fn
271257

@@ -274,14 +260,11 @@ def sample_fn(rng, size, dtype, shape, scale):
274260
def jax_sample_fn_exponential(op, node):
275261
"""JAX implementation of `ExponentialRV`."""
276262

277-
def sample_fn(rng, size, dtype, scale):
278-
rng_key = rng["jax_state"]
279-
rng_key, sampling_key = jax.random.split(rng_key, 2)
263+
def sample_fn(rng_key, size, dtype, scale):
280264
if size is None:
281265
size = jax.numpy.asarray(scale).shape
282-
sample = jax.random.exponential(sampling_key, size, dtype) * scale
283-
rng["jax_state"] = rng_key
284-
return (rng, sample)
266+
sample = jax.random.exponential(rng_key, size, dtype) * scale
267+
return sample
285268

286269
return sample_fn
287270

@@ -290,14 +273,11 @@ def sample_fn(rng, size, dtype, scale):
290273
def jax_sample_fn_t(op, node):
291274
"""JAX implementation of `StudentTRV`."""
292275

293-
def sample_fn(rng, size, dtype, df, loc, scale):
294-
rng_key = rng["jax_state"]
295-
rng_key, sampling_key = jax.random.split(rng_key, 2)
276+
def sample_fn(rng_key, size, dtype, df, loc, scale):
296277
if size is None:
297278
size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
298-
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
299-
rng["jax_state"] = rng_key
300-
return (rng, sample)
279+
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
280+
return sample
301281

302282
return sample_fn
303283

@@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
315295
"A default JAX rewrite should have materialized the implicit arange"
316296
)
317297

318-
def sample_fn(rng, size, dtype, *parameters):
319-
rng_key = rng["jax_state"]
320-
rng_key, sampling_key = jax.random.split(rng_key, 2)
321-
298+
def sample_fn(rng_key, size, dtype, *parameters):
322299
if op.has_p_param:
323300
a, p, core_shape = parameters
324301
else:
@@ -327,9 +304,7 @@ def sample_fn(rng, size, dtype, *parameters):
327304
core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim])
328305

329306
if batch_ndim == 0:
330-
sample = jax.random.choice(
331-
sampling_key, a, shape=core_shape, replace=False, p=p
332-
)
307+
sample = jax.random.choice(rng_key, a, shape=core_shape, replace=False, p=p)
333308

334309
else:
335310
if size is None:
@@ -345,7 +320,7 @@ def sample_fn(rng, size, dtype, *parameters):
345320
if p is not None:
346321
p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:])
347322

348-
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
323+
batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
349324

350325
# Ravel the batch dimensions because vmap only works along a single axis
351326
raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:])
@@ -366,8 +341,7 @@ def sample_fn(rng, size, dtype, *parameters):
366341
# Reshape the batch dimensions
367342
sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
368343

369-
rng["jax_state"] = rng_key
370-
return (rng, sample)
344+
return sample
371345

372346
return sample_fn
373347

@@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node):
378352

379353
batch_ndim = op.batch_ndim(node)
380354

381-
def sample_fn(rng, size, dtype, *parameters):
382-
rng_key = rng["jax_state"]
383-
rng_key, sampling_key = jax.random.split(rng_key, 2)
355+
def sample_fn(rng_key, size, dtype, *parameters):
384356
(x,) = parameters
385357
if batch_ndim:
386358
# jax.random.permutation has no concept of batch dims
@@ -389,17 +361,16 @@ def sample_fn(rng, size, dtype, *parameters):
389361
else:
390362
x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:])
391363

392-
batch_sampling_keys = jax.random.split(sampling_key, np.prod(size))
364+
batch_sampling_keys = jax.random.split(rng_key, np.prod(size))
393365
raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:])
394366
raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))(
395367
batch_sampling_keys, raveled_batch_x
396368
)
397369
sample = raveled_sample.reshape(size + raveled_sample.shape[1:])
398370
else:
399-
sample = jax.random.permutation(sampling_key, x)
371+
sample = jax.random.permutation(rng_key, x)
400372

401-
rng["jax_state"] = rng_key
402-
return (rng, sample)
373+
return sample
403374

404375
return sample_fn
405376

@@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node):
414385

415386
from numpyro.distributions.util import binomial
416387

417-
def sample_fn(rng, size, dtype, n, p):
418-
rng_key = rng["jax_state"]
419-
rng_key, sampling_key = jax.random.split(rng_key, 2)
420-
421-
sample = binomial(key=sampling_key, n=n, p=p, shape=size)
422-
423-
rng["jax_state"] = rng_key
424-
425-
return (rng, sample)
388+
def sample_fn(rng_key, size, dtype, n, p):
389+
sample = binomial(key=rng_key, n=n, p=p, shape=size)
390+
return sample
426391

427392
return sample_fn
428393

@@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node):
437402

438403
from numpyro.distributions.util import multinomial
439404

440-
def sample_fn(rng, size, dtype, n, p):
441-
rng_key = rng["jax_state"]
442-
rng_key, sampling_key = jax.random.split(rng_key, 2)
443-
444-
sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
445-
446-
rng["jax_state"] = rng_key
447-
448-
return (rng, sample)
405+
def sample_fn(rng_key, size, dtype, n, p):
406+
sample = multinomial(key=rng_key, n=n, p=p, shape=size)
407+
return sample
449408

450409
return sample_fn
451410

@@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node):
460419

461420
from numpyro.distributions.util import von_mises_centered
462421

463-
def sample_fn(rng, size, dtype, mu, kappa):
464-
rng_key = rng["jax_state"]
465-
rng_key, sampling_key = jax.random.split(rng_key, 2)
466-
422+
def sample_fn(rng_key, size, dtype, mu, kappa):
467423
sample = von_mises_centered(
468-
key=sampling_key, concentration=kappa, shape=size, dtype=dtype
424+
key=rng_key, concentration=kappa, shape=size, dtype=dtype
469425
)
470426
sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
471427

472-
rng["jax_state"] = rng_key
473-
474-
return (rng, sample)
428+
return sample
475429

476430
return sample_fn

tests/link/jax/test_random.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -796,7 +796,7 @@ def rng_fn(cls, rng, size):
796796
@jax_sample_fn.register(CustomRV)
797797
def jax_sample_fn_custom(op, node):
798798
def sample_fn(rng, size, dtype, *parameters):
799-
return (rng, 0)
799+
return 0
800800

801801
return sample_fn
802802

0 commit comments

Comments
 (0)