|
1 | 1 | from functools import singledispatch
|
2 | 2 |
|
3 | 3 | import jax
|
| 4 | +import jax.numpy as jnp |
4 | 5 | import numpy as np
|
5 | 6 | from numpy.random import Generator
|
6 | 7 | from numpy.random.bit_generator import ( # type: ignore[attr-defined]
|
@@ -394,16 +395,35 @@ def sample_fn(rng_key, size, dtype, n, p):
|
394 | 395 |
|
395 | 396 | @jax_sample_fn.register(ptr.MultinomialRV)
|
396 | 397 | def jax_sample_fn_multinomial(op, node):
|
397 |
| - if not numpyro_available: |
398 |
| - raise NotImplementedError( |
399 |
| - f"No JAX implementation for the given distribution: {op.name}. " |
400 |
| - "Implementation is available if NumPyro is installed." |
401 |
| - ) |
402 |
| - |
403 |
| - from numpyro.distributions.util import multinomial |
404 |
| - |
405 | 398 | def sample_fn(rng_key, size, dtype, n, p):
|
406 |
| - sample = multinomial(key=rng_key, n=n, p=p, shape=size) |
| 399 | + if size is not None: |
| 400 | + n = jnp.broadcast_to(n, size) |
| 401 | + p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:]) |
| 402 | + |
| 403 | + else: |
| 404 | + broadcast_shape = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1]) |
| 405 | + n = jnp.broadcast_to(n, broadcast_shape) |
| 406 | + p = jnp.broadcast_to(p, broadcast_shape + jnp.shape(p)[-1:]) |
| 407 | + |
| 408 | + binom_p = jnp.moveaxis(p, -1, 0)[:-1, ...] |
| 409 | + sampling_rng = jax.random.split(rng_key, binom_p.shape[0]) |
| 410 | + |
| 411 | + def _binomial_sample_fn(carry, p_rng): |
| 412 | + s, rho = carry |
| 413 | + p, rng = p_rng |
| 414 | + samples = jax.random.binomial(rng, s, p / rho) |
| 415 | + s = s - samples |
| 416 | + rho = rho - p |
| 417 | + return ((s, rho), samples) |
| 418 | + |
| 419 | + (remain, _), samples = jax.lax.scan( |
| 420 | + _binomial_sample_fn, |
| 421 | + (n.astype(np.float64), jnp.ones(binom_p.shape[1:])), |
| 422 | + (binom_p, sampling_rng), |
| 423 | + ) |
| 424 | + sample = jnp.concatenate( |
| 425 | + [jnp.moveaxis(samples, 0, -1), jnp.expand_dims(remain, -1)], axis=-1 |
| 426 | + ) |
407 | 427 | return sample
|
408 | 428 |
|
409 | 429 | return sample_fn
|
|
0 commit comments