diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 9af17b0a68..678b2dc486 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -1,6 +1,7 @@ from functools import singledispatch import jax +import jax.numpy as jnp import numpy as np from numpy.random import Generator from numpy.random.bit_generator import ( # type: ignore[attr-defined] @@ -394,16 +395,35 @@ def sample_fn(rng_key, size, dtype, n, p): @jax_sample_fn.register(ptr.MultinomialRV) def jax_sample_fn_multinomial(op, node): - if not numpyro_available: - raise NotImplementedError( - f"No JAX implementation for the given distribution: {op.name}. " - "Implementation is available if NumPyro is installed." - ) - - from numpyro.distributions.util import multinomial - def sample_fn(rng_key, size, dtype, n, p): - sample = multinomial(key=rng_key, n=n, p=p, shape=size) + if size is not None: + n = jnp.broadcast_to(n, size) + p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:]) + + else: + broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) + n = jnp.broadcast_to(n, broadcast_shape) + p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) + + binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...] + sampling_rng = jax.random.split(rng_key, binom_p.shape[0]) + + def _binomial_sample_fn(carry, p_rng): + s, rho = carry + p, rng = p_rng + samples = jax.random.binomial(rng, s, p / rho) + s = s - samples + rho = rho - p + return ((s, rho), samples) + + (remain, _), samples = jax.lax.scan( + _binomial_sample_fn, + (n.astype(np.float64), jnp.ones(binom_p.shape[1:])), + (binom_p, sampling_rng), + ) + sample = jnp.concatenate( + [jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1 + ) return sample return sample_fn diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index fb2f6d9bb9..183b629f79 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -703,14 +703,15 @@ def test_beta_binomial(): ) -@pytest.mark.skipif( - not numpyro_available, reason="Multinomial dispatch requires numpyro" -) def test_multinomial(): rng = shared(np.random.default_rng(123)) + + # test with 'size' argument and n.shape == p.shape[:-1] n = np.array([10, 40]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) - g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) + size = (10_000, 2) + + g = pt.random.multinomial(n, p, size=size, rng=rng) g_fn = compile_random_function([], g, mode="JAX") samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) @@ -718,6 +719,20 @@ def test_multinomial(): samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1 ) + # test with no 'size' argument and no static shape + n = np.broadcast_to(np.array([10, 40]), size) + p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) + pt_n = pt.matrix("n") + pt_p = pt.matrix("p") + + g = pt.random.multinomial(pt_n, pt_p, rng=rng, size=None) + g_fn = compile_random_function([pt_n, pt_p], g, mode="JAX") + samples = g_fn(n, p) + np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1) + np.testing.assert_allclose( + samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1 + ) + @pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") def test_vonmises_mu_outside_circle():