
Implement faster Multinomial JAX dispatch #1316


Merged
7 commits merged into pymc-devs:main on Mar 25, 2025

Conversation

@educhesne (Contributor) commented Mar 23, 2025

Description

Replace the call to numpyro.distributions.util.multinomial with a custom function written in JAX that implements sequential conditional binomial sampling.
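
For context, a minimal sketch of sequential conditional binomial sampling in JAX (an illustration of the idea only, not the PR's actual code; the helper name and the scan-based loop are assumptions, and edge cases such as zero-probability categories are ignored):

import jax
import jax.numpy as jnp

def multinomial_seq_binomial(key, n, p):
    # Sequential conditional binomials: category k receives a
    # Binomial(n_left, p_k / (p_k + ... + p_K)) share of the trials
    # left over after categories 1..k-1.
    remaining = jnp.cumsum(p[::-1])[::-1]         # tail sums p_k + ... + p_K
    cond_p = p / jnp.clip(remaining, 1e-12)       # conditional success probs
    keys = jax.random.split(key, p.shape[0] - 1)  # one key per binomial draw

    def step(n_left, inputs):
        key_k, q_k = inputs
        draw = jax.random.binomial(key_k, n_left, q_k)
        return n_left - draw, draw                # carry trials left, emit count

    # Loop over all but the last category; the last takes whatever remains.
    n_last, draws = jax.lax.scan(step, jnp.asarray(n, float), (keys, cond_p[:-1]))
    return jnp.append(draws, n_last)

key = jax.random.PRNGKey(0)
print(multinomial_seq_binomial(key, 10, jnp.array([0.2, 0.3, 0.5])))  # counts sum to 10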

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pytensor--1316.org.readthedocs.build/en/1316/

Etienne Duchesne added 2 commits March 23, 2025 14:09
Replace the call to numpyro.distributions.util.multinomial by a custom function
@@ -450,6 +443,42 @@ def sample_fn(rng, size, dtype, n, p):
    return sample_fn


def _jax_multinomial(n, p, shape=None, key=None):
    if jnp.shape(n) != jnp.shape(p)[:-1]:
@ricardoV94 (Member)

There's some redundancy here. If size (not shape, btw) is provided, we should just broadcast n to size and p to size + p.shape[-1:]. Only if size is not provided should we broadcast n and p's batch shape p.shape[:-1] together (basically finding the implicit size).
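
A hypothetical helper spelling out that rule (the function name and structure are mine, not the PR's):

import jax
import jax.numpy as jnp

def broadcast_multinomial_params(n, p, size=None):
    if size is not None:
        # size given: n must broadcast to size, p to size plus the category axis
        n = jnp.broadcast_to(n, size)
        p = jnp.broadcast_to(p, tuple(size) + jnp.shape(p)[-1:])
    else:
        # size not given: find the implicit size from n and p's batch dims
        size = jax.lax.broadcast_shapes(jnp.shape(n), jnp.shape(p)[:-1])
        n = jnp.broadcast_to(n, size)
        p = jnp.broadcast_to(p, size + jnp.shape(p)[-1:])
    return n, p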

@educhesne (Contributor, Author) replied Mar 24, 2025

Yes. Besides, there is actually no need to broadcast p at all; jax.random.binomial broadcasts it for us.
About the name shape: I kept the signature of the numpyro function, but I can change it.
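
A quick check of that claim (my own snippet; it relies on jax.random.binomial broadcasting n against p when shape is None):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
n = jnp.array([10.0, 40.0])           # shape (2,)
p = jnp.array([[0.2], [0.5], [0.9]])  # shape (3, 1)
draws = jax.random.binomial(key, n, p)
print(draws.shape)                    # (3, 2): n and p broadcast together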

@ricardoV94 (Member)

You might be right about p, although I guess it's more readable if you broadcast explicitly.

        return sample

    return sample_fn


def _jax_multinomial(n, p, size=None, key=None):
    if size is not None:
        broadcast_shape_n = jax.lax.broadcast_shapes(jnp.shape(n), size)
@ricardoV94 (Member)

You don't need to broadcast size with n; that's not allowed by the RandomVariable semantics. If size is provided, n must be broadcastable to size, but not the other way around. You can't say n=ones((5, 1)), size=(1, 3).

@educhesne (Contributor, Author)

Indeed, my bad
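
For illustration, a small check of the corrected rule (a hypothetical helper, not part of the PR): n must be broadcastable to size, but size must never have to grow to fit n.

import numpy as np

def n_broadcasts_to_size(n_shape, size):
    # Broadcasting must succeed AND must not enlarge size itself.
    try:
        return np.broadcast_shapes(n_shape, size) == tuple(size)
    except ValueError:
        return False

print(n_broadcasts_to_size((5, 1), (5, 3)))  # True
print(n_broadcasts_to_size((5, 1), (1, 3)))  # False: would force size to (5, 3)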

codecov bot commented Mar 24, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 81.98%. Comparing base (2e9d502) to head (81f0b12).
Report is 1 commit behind head on main.


@@           Coverage Diff           @@
##             main    #1316   +/-   ##
=======================================
  Coverage   81.98%   81.98%           
=======================================
  Files         188      188           
  Lines       48475    48489   +14     
  Branches     8673     8673           
=======================================
+ Hits        39740    39756   +16     
+ Misses       6583     6582    -1     
+ Partials     2152     2151    -1     
Files with missing lines                Coverage Δ
pytensor/link/jax/dispatch/random.py    94.01% <100.00%> (+1.28%) ⬆️

@educhesne (Contributor, Author)

It seems the sample_fn function is never called with size=None, even though pt.random.multinomial(..., size=None) is in the test. Could it be that the None size is caught somewhere in between? (There is a static_shape computation in jax_funcify_RandomVariable that I don't fully understand.)
Should I remove the part which is not reached by the coverage test?

@ricardoV94 (Member)

It can be reached in some cases, for instance when you have size=x.shape and x is an input variable without a known shape.

        return sample

    return sample_fn


def _jax_multinomial(n, p, size=None, key=None):
@ricardoV94 (Member)

You can inline this in the dispatch for MultinomialRV. Also, key=None is not really valid, but we don't need default parameters anyway.

@educhesne (Contributor, Author)

> It can be reached in some cases, for instance when you have size=x.shape and x is an input variable without a known shape.

I'm sorry, I don't know how to update the test to cover this case.

Locally I've tried this without success:

import numpy as np
import pytensor
import pytensor.tensor as pt
from tests.link.jax.test_random import compile_random_function

rng = pytensor.shared(np.random.default_rng(123))  # assumed setup; not shown in the original

p = np.array([[0.3, 0.7, 0.0], [0.1, 0.4, 0.5]])
size = (10_000, 2)
n = np.broadcast_to(np.array([10, 40]), size)
x = pt.matrix('x')

g = pt.random.multinomial(n, p, rng=rng, size=x.shape)
g_fn = compile_random_function([x], g, mode="JAX")
samples = g_fn(np.zeros(size))

@ricardoV94 (Member) commented Mar 25, 2025

Sorry, what I said was nonsense. Here is a test with size=None:

import pytensor.tensor as pt

from tests.link.jax.test_random import compile_random_function

n = pt.vector("n")
p = pt.vector("p")

g = pt.random.multinomial(n, p, size=None)
g_fn = compile_random_function([n, p], g, mode="JAX")

size is None, and it can't be inferred from n because n doesn't have a static shape.
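
To see why nothing can be inferred here (my own illustration, not from the thread): symbolic vectors carry no static length, so the JAX backend has to keep the size=None branch.

import pytensor.tensor as pt

n = pt.vector("n")
print(n.type.shape)  # (None,): length unknown at compile time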

@educhesne (Contributor, Author)

Thank you for your feedback!

@ricardoV94 (Member)

Thank you for the feature!!!

@ricardoV94 merged commit 8a7356c into pymc-devs:main on Mar 25, 2025
71 checks passed
@ricardoV94 (Member)

Awesome

@educhesne deleted the multinomial_jax_dispatch branch on March 25, 2025 at 16:28

Merging this pull request closed the issue: Implement faster Multinomial JAX dispatch.

2 participants