Implement faster Multinomial JAX dispatch #1316
Conversation
Replace the call to `numpyro.distributions.util.multinomial` with a custom function.
pytensor/link/jax/dispatch/random.py
Outdated
```
@@ -450,6 +443,42 @@ def sample_fn(rng, size, dtype, n, p):
    return sample_fn


def _jax_multinomial(n, p, shape=None, key=None):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
```
There's some redundancy here. If `size` (not `shape`, btw) is provided, we should just broadcast `n` to `size` and `p` to `size + p.shape[-1]`. Only if `size` is not provided should we broadcast `n` and `p[..., :-1]` together (basically finding the implicit size).
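A minimal sketch of the rule described above, using a hypothetical helper name (this is not the PR's code, just an illustration of the two branches):

```python
import jax
import jax.numpy as jnp


def _broadcast_multinomial_params(n, p, size=None):
    # Hypothetical helper, for illustration only.
    if size is None:
        # No size given: the implicit size is the broadcast of n's shape
        # with p's batch shape (everything except the trailing category axis).
        size = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
    # With size given (or inferred): n broadcasts to size, p to size + (k,).
    n = jnp.broadcast_to(n, size)
    p = jnp.broadcast_to(p, tuple(size) + (jnp.shape(p)[-1],))
    return n, p
```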
Yes, and besides, there is actually no need to broadcast `p` at all; `jax.random.binomial` broadcasts it for us.
About the name `shape`: I kept the signature of the numpyro function, but I can change it.
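For illustration (assuming a JAX version that ships `jax.random.binomial`), the broadcasting happens without any explicit `broadcast_to` on `p`:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n = jnp.array(10.0)              # scalar number of trials
p = jnp.array([0.2, 0.3, 0.5])   # per-category probabilities
# With shape=None, jax.random.binomial broadcasts n against p,
# so the draws already have p's shape.
draws = jax.random.binomial(key, n, p)
print(draws.shape)  # (3,)
```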
You might be right about `p`, although I guess it's more readable if you broadcast explicitly.
pytensor/link/jax/dispatch/random.py
Outdated
```
        return sample

    return sample_fn


def _jax_multinomial(n, p, size=None, key=None):
    if size is not None:
        broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size)
```
You don't need to broadcast `size` with `n`; that's not allowed by the `RandomVariable` semantics. If `size` is provided, `n` must be broadcastable to `size`, but not the other way around. You can't say `n=ones((5, 1)), size=(1, 3)`.
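To make the point concrete with that example (illustrative snippet, not PR code): combining the two shapes hides the problem, while broadcasting `n` to `size` surfaces it:

```python
import jax
import jax.numpy as jnp

n = jnp.ones((5, 1))
size = (1, 3)

# broadcast_shapes merges the two shapes and silently yields (5, 3),
# which is not what RandomVariable semantics allow for this n/size pair.
print(jax.lax.broadcast_shapes(jnp.shape(n), size))  # (5, 3)

# broadcast_to enforces "n must be broadcastable to size" and raises here.
try:
    jnp.broadcast_to(n, size)
except ValueError as err:
    print("n is not broadcastable to size:", err)
```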
Indeed, my bad
Codecov Report: All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@           Coverage Diff           @@
##             main    #1316   +/-   ##
=======================================
  Coverage   81.98%   81.98%
=======================================
  Files         188      188
  Lines       48475    48489     +14
  Branches     8673     8673
=======================================
+ Hits        39740    39756     +16
+ Misses       6583     6582      -1
+ Partials     2152     2151      -1
```
It seems the
It can be reached in some cases, for instance when you have `size=x.shape` and `x` is an input variable without a known static shape.
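A sketch of that situation (variable names are illustrative): `size` is only known symbolically here, so the dispatch cannot rely on a concrete size up front.

```python
import pytensor.tensor as pt

x = pt.matrix("x")   # static shape unknown at graph-construction time
n = pt.lscalar("n")
p = pt.vector("p")
# size depends on x's runtime shape and is only known symbolically.
g = pt.random.multinomial(n, p, size=x.shape)
```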
pytensor/link/jax/dispatch/random.py
Outdated
```
        return sample

    return sample_fn


def _jax_multinomial(n, p, size=None, key=None):
```
You can inline this in the dispatch for `MultinomialRV`. Also, `key=None` is not really valid, but we don't need default parameters anyway.
I'm sorry, I don't know how to update the test to cover this case. Locally I've tried this without success:
Sorry, what I said was nonsense. Here is a test with `size=None`:

```python
import pytensor.tensor as pt
from tests.link.jax.test_random import compile_random_function

n = pt.vector("n")
p = pt.vector("p")
g = pt.random.multinomial(n, p, size=None)
g_fn = compile_random_function([n, p], g, mode="JAX")
```
Thank you for your feedback!
Thank you for the feature!!!
Awesome
Description
Replace the call to `numpyro.distributions.util.multinomial` with a custom function written in JAX implementing sequential conditional binomial sampling.
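For reference, a self-contained sketch of what sequential conditional binomial sampling looks like in plain JAX (function name and details are illustrative and assume a JAX version with `jax.random.binomial`; the merged implementation lives in `pytensor/link/jax/dispatch/random.py`):

```python
import jax
import jax.numpy as jnp


def sequential_binomial_multinomial(key, n, p):
    """Illustrative sketch: Multinomial(n, p) via k - 1 conditional binomials.

    Category i receives Binomial(n_remaining, p_i / p_remaining) draws, where
    n_remaining and p_remaining are what is left after categories 0..i-1.
    """
    k = p.shape[-1]
    keys = jax.random.split(key, k - 1)
    remaining_n = jnp.asarray(n, dtype=p.dtype)
    remaining_p = jnp.ones_like(p[..., 0])
    counts = []
    for i in range(k - 1):
        # Conditional success probability; clip guards against rounding overshoot.
        cond_p = jnp.clip(p[..., i] / remaining_p, 0.0, 1.0)
        draw = jax.random.binomial(keys[i], remaining_n, cond_p)
        counts.append(draw)
        remaining_n = remaining_n - draw
        remaining_p = remaining_p - p[..., i]
    counts.append(remaining_n)  # the last category takes whatever is left
    # Casting counts to the RV's integer dtype is omitted in this sketch.
    return jnp.stack(counts, axis=-1)


# Quick check: draws are non-negative and sum to n for each batch element.
key = jax.random.PRNGKey(0)
n = jnp.array([10.0, 100.0])
p = jnp.array([[0.2, 0.3, 0.5], [0.1, 0.1, 0.8]])
sample = sequential_binomial_multinomial(key, n, p)
print(sample.sum(axis=-1))  # [ 10. 100.]
```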
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pytensor--1316.org.readthedocs.build/en/1316/