Add new notebook showcasing how to (un)wrap a JAX Op in Aesara #299

Closed
ricardoV94 opened this issue Mar 23, 2022 · 0 comments · Fixed by #302
Labels: proposal (New notebook proposal still up for discussion)

ricardoV94 commented Mar 23, 2022

This has been requested a couple of times. It is similar in spirit to the black-box likelihood example, but applied to JAX graphs. It also demonstrates how to unwrap a wrapped Op so that the entire Aesara graph can be compiled to JAX.

I chose a cute HMM (hidden Markov model) likelihood that makes use of scan and vmap. Here is a suggestion for a TOC:

[image: proposed table of contents]
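
For context, an HMM log-likelihood that uses scan and vmap could look roughly like the sketch below. This is only an illustration under my own assumptions, not the notebook's final code: the names `hmm_logp` and `vec_hmm_logp`, the argument layout, and the choice to pass precomputed log emission probabilities are all hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_logp(log_emissions, log_init, log_trans):
    """Log-likelihood of one sequence via the forward algorithm in log space.

    Assumed (hypothetical) layout:
      log_emissions: (T, K), log p(y_t | z_t = k)
      log_init:      (K,),   log p(z_0 = k)
      log_trans:     (K, K), log p(z_t = k | z_{t-1} = j) at [j, k]
    """

    def step(log_alpha, log_emission_t):
        # Marginalise over the previous state j, then add the emission term.
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + log_emission_t
        return log_alpha, None

    log_alpha_0 = log_init + log_emissions[0]
    log_alpha_last, _ = jax.lax.scan(step, log_alpha_0, log_emissions[1:])
    return logsumexp(log_alpha_last)


# vmap over independent sequences that share the same initial and transition probabilities.
vec_hmm_logp = jax.jit(jax.vmap(hmm_logp, in_axes=(0, None, None)))
```

Here `jax.lax.scan` carries the forward variables through time and `jax.vmap` batches the computation over sequences, so `vec_hmm_logp` returns one log-likelihood per sequence.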

And here is the motivating intro I have written so far:

PyMC uses the Aesara library to create and manipulate probabilistic graphs. Aesara is backend-agnostic, meaning it can make use of functions written in different languages or frameworks, including pure Python, C, Cython, Numba, and JAX. All that is needed is to encapsulate such a function in an Aesara Op that implements a specific API contract regarding how inputs and outputs should be handled, and possibly extra functionality like symbolic shape inference and gradient expressions. This is well covered in the Aesara documentation and the black-box likelihood pymc-example.

More recently, Aesara became capable of compiling directly to some of these languages/frameworks, meaning that we can convert a complete Aesara graph into a JAX or Numba jitted function, whereas traditionally graphs could only be converted to Python or C.

This has some interesting advantages, such as being able to sample models defined in PyMC using pure JAX samplers, like those implemented in NumPyro or BlackJAX.

This notebook illustrates how we can implement a new Aesara Op that wraps a JAX jitted function.

  1. We start along a path similar to the one taken in the black-box likelihood example, which wraps a Cython function in an Aesara Op, this time wrapping a JAX jitted function instead.
  2. We then enable Aesara to "unwrap" the newly wrapped JAX function, so that the whole graph can be compiled to JAX. We make use of this to sample our PyMC model via the BlackJAX NUTS sampler.
ricardoV94 added the proposal label (New notebook proposal still up for discussion) on Mar 23, 2022
ricardoV94 self-assigned this on Mar 23, 2022