This has been requested a couple of times. It is similar in nature to the black-box likelihood example, but applied to JAX graphs. It also demonstrates how to unwrap a wrapped Op so that the entire Aesara graph can be compiled to JAX.
I chose a cute HMM likelihood that makes use of `scan` and `vmap`. Here is a suggestion for a TOC:
And some motivating Intro I wrote so far:
PyMC uses the Aesara library to create and manipulate probabilistic graphs. Aesara is backend-agnostic, meaning it can make use of functions written in different languages or frameworks, including pure Python, C, Cython, Numba, and JAX. All that is needed is to encapsulate such a function in an Aesara Op that implements a specific API contract regarding how inputs and outputs should be handled, and possibly extra functionality like symbolic shape inference and gradient expressions. This is well covered in the Aesara documentation and the black-box likelihood pymc-example.
More recently, Aesara became capable of compiling directly to some of these languages/frameworks, meaning that we can convert a complete Aesara graph into a JAX or Numba jitted function, whereas traditionally graphs could only be converted to Python or C.
This has some interesting advantages, such as being able to sample models defined in PyMC using pure JAX samplers, like those implemented in NumPyro or BlackJAX.
This notebook illustrates how we can implement a new Aesara Op that wraps a JAX jitted function.
We start along a similar path to the black-box likelihood example, which wraps a Cython function in an Aesara Op, this time wrapping a JAX jitted function instead.
We then enable Aesara to "unwrap" the JAX function we just wrapped, so that the whole graph can be compiled to JAX. We make use of this to sample our PyMC model via the BlackJAX NUTS sampler.
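To make the starting point concrete, here is a minimal sketch of the kind of JAX function the notebook could wrap: an HMM log-likelihood computed with the forward algorithm, using `jax.lax.scan` over time steps and `jax.vmap` over independent sequences. The function name `hmm_logp`, its argument layout, and the assumption that emission log-probabilities are precomputed are all illustrative choices of mine, not the notebook's final code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def hmm_logp(emission_logp, log_init, log_trans):
    """Log marginal likelihood of one sequence via the forward algorithm.

    Hypothetical argument layout (an assumption for this sketch):
      emission_logp: (T, K) array, log p(obs_t | state k)
      log_init:      (K,) array, log p(state k at t = 0)
      log_trans:     (K, K) array, log p(state j at t+1 | state i at t)
    """

    def step(log_alpha, emis_t):
        # Marginalise over the previous state, then add the emission term.
        log_alpha = logsumexp(log_alpha[:, None] + log_trans, axis=0) + emis_t
        return log_alpha, None

    log_alpha0 = log_init + emission_logp[0]
    # scan runs the recursion over the remaining T - 1 time steps.
    log_alpha_T, _ = jax.lax.scan(step, log_alpha0, emission_logp[1:])
    return logsumexp(log_alpha_T)


# vmap maps over a leading batch axis of sequences while sharing the
# HMM parameters; jit compiles the batched function.
batched_hmm_logp = jax.jit(jax.vmap(hmm_logp, in_axes=(0, None, None)))
```

Because the function is pure JAX, `jax.grad(hmm_logp)` is available for free, which is what a wrapping Op would rely on for its gradient expression.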