Skip to content

Commit 8a7356c

Browse files
authored
Implement faster Multinomial JAX dispatch (#1316)
1 parent 2e9d502 commit 8a7356c

File tree

2 files changed

+48
-13
lines changed

2 files changed

+48
-13
lines changed

pytensor/link/jax/dispatch/random.py

+29-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from functools import singledispatch
22

33
import jax
4+
import jax.numpy as jnp
45
import numpy as np
56
from numpy.random import Generator
67
from numpy.random.bit_generator import ( # type: ignore[attr-defined]
@@ -394,16 +395,35 @@ def sample_fn(rng_key, size, dtype, n, p):
394395

395396
@jax_sample_fn.register(ptr.MultinomialRV)
396397
def jax_sample_fn_multinomial(op, node):
397-
if not numpyro_available:
398-
raise NotImplementedError(
399-
f"No JAX implementation for the given distribution: {op.name}. "
400-
"Implementation is available if NumPyro is installed."
401-
)
402-
403-
from numpyro.distributions.util import multinomial
404-
405398
def sample_fn(rng_key, size, dtype, n, p):
406-
sample = multinomial(key=rng_key, n=n, p=p, shape=size)
399+
if size is not None:
400+
n = jnp.broadcast_to(n, size)
401+
p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:])
402+
403+
else:
404+
broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
405+
n = jnp.broadcast_to(n, broadcast_shape)
406+
p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])
407+
408+
binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]
409+
sampling_rng = jax.random.split(rng_key, binom_p.shape[0])
410+
411+
def _binomial_sample_fn(carry, p_rng):
412+
s, rho = carry
413+
p, rng = p_rng
414+
samples = jax.random.binomial(rng, s, p / rho)
415+
s = s - samples
416+
rho = rho - p
417+
return ((s, rho), samples)
418+
419+
(remain, _), samples = jax.lax.scan(
420+
_binomial_sample_fn,
421+
(n.astype(np.float64), jnp.ones(binom_p.shape[1:])),
422+
(binom_p, sampling_rng),
423+
)
424+
sample = jnp.concatenate(
425+
[jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
426+
)
407427
return sample
408428

409429
return sample_fn

tests/link/jax/test_random.py

+19-4
Original file line numberDiff line numberDiff line change
@@ -703,21 +703,36 @@ def test_beta_binomial():
703703
)
704704

705705

706-
@pytest.mark.skipif(
707-
not numpyro_available, reason="Multinomial dispatch requires numpyro"
708-
)
709706
def test_multinomial():
710707
rng = shared(np.random.default_rng(123))
708+
709+
# test with 'size' argument and n.shape == p.shape[:-1]
711710
n = np.array([10, 40])
712711
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
713-
g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng)
712+
size = (10_000, 2)
713+
714+
g = pt.random.multinomial(n, p, size=size, rng=rng)
714715
g_fn = compile_random_function([], g, mode="JAX")
715716
samples = g_fn()
716717
np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1)
717718
np.testing.assert_allclose(
718719
samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1
719720
)
720721

722+
# test with no 'size' argument and no static shape
723+
n = np.broadcast_to(np.array([10, 40]), size)
724+
p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
725+
pt_n = pt.matrix("n")
726+
pt_p = pt.matrix("p")
727+
728+
g = pt.random.multinomial(pt_n, pt_p, rng=rng, size=None)
729+
g_fn = compile_random_function([pt_n, pt_p], g, mode="JAX")
730+
samples = g_fn(n, p)
731+
np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1)
732+
np.testing.assert_allclose(
733+
samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1
734+
)
735+
721736

722737
@pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro")
723738
def test_vonmises_mu_outside_circle():

0 commit comments

Comments
 (0)