From af9803cca720c75761f0516ce07e4cfa251b4a8d Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Sun, 23 Mar 2025 14:09:23 +0100 Subject: [PATCH 1/6] Implement multinomial JAX dispatch directly in jax Replace the call to numpyro.distributions.util.multinomial by a custom function --- pytensor/link/jax/dispatch/random.py | 45 ++++++++++++++++++++++------ 1 file changed, 36 insertions(+), 9 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 8a33dfac13..d1a838e128 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -1,6 +1,7 @@ from functools import singledispatch import jax +import jax.numpy as jnp import numpy as np from numpy.random import Generator from numpy.random.bit_generator import ( # type: ignore[attr-defined] @@ -429,19 +430,11 @@ def sample_fn(rng, size, dtype, n, p): @jax_sample_fn.register(ptr.MultinomialRV) def jax_sample_fn_multinomial(op, node): - if not numpyro_available: - raise NotImplementedError( - f"No JAX implementation for the given distribution: {op.name}. " - "Implementation is available if NumPyro is installed." 
- ) - - from numpyro.distributions.util import multinomial - def sample_fn(rng, size, dtype, n, p): rng_key = rng["jax_state"] rng_key, sampling_key = jax.random.split(rng_key, 2) - sample = multinomial(key=sampling_key, n=n, p=p, shape=size) + sample = _jax_multinomial(key=sampling_key, n=n, p=p, shape=size) rng["jax_state"] = rng_key @@ -450,6 +443,40 @@ def sample_fn(rng, size, dtype, n, p): return sample_fn +def _jax_multinomial(n, p, shape=None, key=None): + if jnp.shape(n) != jnp.shape(p)[:-1]: + broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) + n = jnp.broadcast_to(n, broadcast_shape) + p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) + if shape is not None: + broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), shape) + n = jnp.broadcast_to(n, broadcast_shape) + else: + shape = shape or p.shape[:-1] + + p = p / jnp.sum(p, axis=-1, keepdims=True) + binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...] + + sampling_rng = jax.random.split(key, binom_p.shape[0]) + + def _binomial_sample_fn(carry, p_rng): + s, rho = carry + p, rng = p_rng + samples = jax.random.binomial(rng, s, p / rho, shape) + s = s - samples + rho = rho - p + return ((s, rho), samples) + + (remain, _), samples = jax.lax.scan( + _binomial_sample_fn, + (n.astype("float"), jnp.ones(binom_p.shape[1:])), + (binom_p, sampling_rng), + ) + return jnp.concatenate( + [jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1 + ) + + @jax_sample_fn.register(ptr.VonMisesRV) def jax_sample_fn_vonmises(op, node): if not numpyro_available: From b551daf8bea9bec32782182733f4b1c1266ed73a Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Sun, 23 Mar 2025 19:12:11 +0100 Subject: [PATCH 2/6] improve test coverage on multinomial jax dispatch --- pytensor/link/jax/dispatch/random.py | 6 ++++-- tests/link/jax/test_random.py | 20 ++++++++++++++++---- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py 
b/pytensor/link/jax/dispatch/random.py index d1a838e128..d37cc90c54 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -448,11 +448,13 @@ def _jax_multinomial(n, p, shape=None, key=None): broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) n = jnp.broadcast_to(n, broadcast_shape) p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) + if shape is not None: broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), shape) n = jnp.broadcast_to(n, broadcast_shape) + else: - shape = shape or p.shape[:-1] + shape = p.shape[:-1] p = p / jnp.sum(p, axis=-1, keepdims=True) binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...] @@ -469,7 +471,7 @@ def _binomial_sample_fn(carry, p_rng): (remain, _), samples = jax.lax.scan( _binomial_sample_fn, - (n.astype("float"), jnp.ones(binom_p.shape[1:])), + (n.astype(np.float64), jnp.ones(binom_p.shape[1:])), (binom_p, sampling_rng), ) return jnp.concatenate( diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 2a6ebca0af..baf85da73e 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -703,14 +703,15 @@ def test_beta_binomial(): ) -@pytest.mark.skipif( - not numpyro_available, reason="Multinomial dispatch requires numpyro" -) def test_multinomial(): rng = shared(np.random.default_rng(123)) + + # test with 'size' argument and n.shape == p.shape[:-1] n = np.array([10, 40]) p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) - g = pt.random.multinomial(n, p, size=(10_000, 2), rng=rng) + size = (10_000, 2) + + g = pt.random.multinomial(n, p, size=size, rng=rng) g_fn = compile_random_function([], g, mode="JAX") samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n[..., None] * p, rtol=0.1) @@ -718,6 +719,17 @@ def test_multinomial(): samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1 ) + # test with no 'size' argument and n.shape != p.shape[:-1] + n = np.broadcast_to(np.array([10, 40]), size) + + 
g = pt.random.multinomial(n, p, rng=rng) + g_fn = compile_random_function([], g, mode="JAX") + samples = g_fn() + np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1) + np.testing.assert_allclose( + samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1 + ) + @pytest.mark.skipif(not numpyro_available, reason="VonMises dispatch requires numpyro") def test_vonmises_mu_outside_circle(): From 9e6ac0b2beb304eef3aa4ea57299878c0adf2684 Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Mon, 24 Mar 2025 13:45:44 +0100 Subject: [PATCH 3/6] broadcast explicitly the arguments in mulinomial jax dispatch --- pytensor/link/jax/dispatch/random.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 05917e9394..9a5d6c4c84 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -402,28 +402,26 @@ def sample_fn(rng_key, size, dtype, n, p): return sample_fn -def _jax_multinomial(n, p, shape=None, key=None): - if jnp.shape(n) != jnp.shape(p)[:-1]: - broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) - n = jnp.broadcast_to(n, broadcast_shape) - p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) +def _jax_multinomial(n, p, size=None, key=None): + if size is not None: + broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size) + n = jnp.broadcast_to(n, broadcast_shape_n) - if shape is not None: - broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), shape) - n = jnp.broadcast_to(n, broadcast_shape) + broadcast_shape_p = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) + p = jnp.broadcast_to(p, broadcast_shape_p + jnp.shape(p)[-1:]) else: - shape = p.shape[:-1] + broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) + n = jnp.broadcast_to(n, broadcast_shape) + p = jnp.broadcast_to(p, broadcast_shape + 
jnp.shape(p)[-1:]) - p = p / jnp.sum(p, axis=-1, keepdims=True) binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...] - sampling_rng = jax.random.split(key, binom_p.shape[0]) def _binomial_sample_fn(carry, p_rng): s, rho = carry p, rng = p_rng - samples = jax.random.binomial(rng, s, p / rho, shape) + samples = jax.random.binomial(rng, s, p / rho) s = s - samples rho = rho - p return ((s, rho), samples) From 417275249e351e42dcdff70d0fa9f8e9450aa118 Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Mon, 24 Mar 2025 14:01:44 +0100 Subject: [PATCH 4/6] fix typo error in multinomial jax dispatch --- pytensor/link/jax/dispatch/random.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 9a5d6c4c84..c8aa3fdc05 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -396,7 +396,7 @@ def sample_fn(rng_key, size, dtype, n, p): @jax_sample_fn.register(ptr.MultinomialRV) def jax_sample_fn_multinomial(op, node): def sample_fn(rng_key, size, dtype, n, p): - sample = _jax_multinomial(key=rng_key, n=n, p=p, shape=size) + sample = _jax_multinomial(key=rng_key, n=n, p=p, size=size) return sample return sample_fn @@ -407,7 +407,7 @@ def _jax_multinomial(n, p, size=None, key=None): broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size) n = jnp.broadcast_to(n, broadcast_shape_n) - broadcast_shape_p = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) + broadcast_shape_p = jax.lax.broadcast_shapes(jnp.shape(p)[:-1], size) p = jnp.broadcast_to(p, broadcast_shape_p + jnp.shape(p)[-1:]) else: From aef0d4bd5505a9bd124e41f81d97272b60b410cd Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Mon, 24 Mar 2025 17:55:01 +0100 Subject: [PATCH 5/6] fix broadcast in multinomial jax dispatch --- pytensor/link/jax/dispatch/random.py | 7 ++----- tests/link/jax/test_random.py | 2 +- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git 
a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index c8aa3fdc05..f6e84d04ac 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -404,11 +404,8 @@ def sample_fn(rng_key, size, dtype, n, p): def _jax_multinomial(n, p, size=None, key=None): if size is not None: - broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size) - n = jnp.broadcast_to(n, broadcast_shape_n) - - broadcast_shape_p = jax.lax.broadcast_shapes(jnp.shape(p)[:-1], size) - p = jnp.broadcast_to(p, broadcast_shape_p + jnp.shape(p)[-1:]) + n = jnp.broadcast_to(n, size) + p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:]) else: broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 12b9c8a3e3..6f0ffc0630 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -722,7 +722,7 @@ def test_multinomial(): # test with no 'size' argument and n.shape != p.shape[:-1] n = np.broadcast_to(np.array([10, 40]), size) - g = pt.random.multinomial(n, p, rng=rng) + g = pt.random.multinomial(n, p, rng=rng, size=None) g_fn = compile_random_function([], g, mode="JAX") samples = g_fn() np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1) From 81f0b12aec15273966407d1d9e704515bc405688 Mon Sep 17 00:00:00 2001 From: Etienne Duchesne Date: Tue, 25 Mar 2025 16:34:22 +0100 Subject: [PATCH 6/6] improve test coverage in multinomial jax dispatch and inline auxiliary function --- pytensor/link/jax/dispatch/random.py | 60 +++++++++++++--------------- tests/link/jax/test_random.py | 11 +++-- 2 files changed, 35 insertions(+), 36 deletions(-) diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index f6e84d04ac..678b2dc486 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -396,43 +396,39 @@ def sample_fn(rng_key, size, dtype, n, 
@jax_sample_fn.register(ptr.MultinomialRV)
def jax_sample_fn_multinomial(op, node):
    """JAX sampler for :class:`MultinomialRV`.

    JAX has no native multinomial sampler, so counts are generated through
    the conditional-binomial decomposition: each category except the last is
    drawn as ``Binomial(s, p_i / rho)`` where ``s`` is the count not yet
    assigned and ``rho`` the probability mass not yet consumed; the last
    category takes whatever remains, so rows always sum to ``n``.
    """

    def sample_fn(rng_key, size, dtype, n, p):
        if size is not None:
            # Explicit size fixes the batch shape for both parameters.
            n = jnp.broadcast_to(n, size)
            p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:])
        else:
            # Otherwise broadcast n against the batch dimensions of p.
            broadcast_shape = jax.lax.broadcast_shapes(
                jnp.shape(n), jnp.shape(p)[:-1]
            )
            n = jnp.broadcast_to(n, broadcast_shape)
            p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:])

        # One binomial draw per category except the last, which takes the rest.
        binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...]
        sampling_rng = jax.random.split(rng_key, binom_p.shape[0])

        def _binomial_sample_fn(carry, p_rng):
            s, rho = carry
            p_i, rng = p_rng
            # rho == 0 means every count has already been assigned (s == 0);
            # guard the division so a 0/0 -> NaN never reaches the sampler,
            # and clip away float round-off that could exceed 1.
            ratio = jnp.where(rho > 0, p_i / rho, 0.0)
            samples = jax.random.binomial(rng, s, jnp.clip(ratio, 0.0, 1.0))
            return ((s - samples, rho - p_i), samples)

        (remain, _), samples = jax.lax.scan(
            _binomial_sample_fn,
            # Carry must be inexact for jax.random.binomial; without x64
            # enabled JAX will hold this as float32.
            (n.astype(np.float64), jnp.ones(binom_p.shape[1:])),
            (binom_p, sampling_rng),
        )
        sample = jnp.concatenate(
            [jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1
        )
        # The binomial chain yields floats; cast back to the op's declared
        # dtype so the JAX output matches the PyTensor output type (the
        # numpyro implementation this replaced returned integer counts).
        return sample.astype(dtype)

    return sample_fn
- sampling_rng = jax.random.split(key, binom_p.shape[0]) - - def _binomial_sample_fn(carry, p_rng): - s, rho = carry - p, rng = p_rng - samples = jax.random.binomial(rng, s, p / rho) - s = s - samples - rho = rho - p - return ((s, rho), samples) - - (remain, _), samples = jax.lax.scan( - _binomial_sample_fn, - (n.astype(np.float64), jnp.ones(binom_p.shape[1:])), - (binom_p, sampling_rng), - ) - return jnp.concatenate( - [jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1 - ) - - @jax_sample_fn.register(ptr.VonMisesRV) def jax_sample_fn_vonmises(op, node): if not numpyro_available: diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 6f0ffc0630..183b629f79 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -719,12 +719,15 @@ def test_multinomial(): samples.std(axis=0), np.sqrt(n[..., None] * p * (1 - p)), rtol=0.1 ) - # test with no 'size' argument and n.shape != p.shape[:-1] + # test with no 'size' argument and no static shape n = np.broadcast_to(np.array([10, 40]), size) + p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]]) + pt_n = pt.matrix("n") + pt_p = pt.matrix("p") - g = pt.random.multinomial(n, p, rng=rng, size=None) - g_fn = compile_random_function([], g, mode="JAX") - samples = g_fn() + g = pt.random.multinomial(pt_n, pt_p, rng=rng, size=None) + g_fn = compile_random_function([pt_n, pt_p], g, mode="JAX") + samples = g_fn(n, p) np.testing.assert_allclose(samples.mean(axis=0), n[0, :, None] * p, rtol=0.1) np.testing.assert_allclose( samples.std(axis=0), np.sqrt(n[0, :, None] * p * (1 - p)), rtol=0.1