Skip to content

Add guide on how to wrap a JAX function in a Aesara Op #302

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Apr 14, 2022

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 24, 2022

Closes #299

To make the merge process smoother we've provided some links and a checklist below.

We understand that PRs can sometimes be overwhelming, especially as the reviews start coming in.
Please let us know if the reviews are unclear or the recommended next step seems overly demanding,
if you would like help in addressing a reviewer's comments,
or if you have been waiting too long to hear back on your PR. -->

Helpful links

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@ricardoV94 ricardoV94 requested a review from twiecki March 24, 2022 13:52
@ricardoV94 ricardoV94 added the tree: advanced Internal use label label Mar 24, 2022
@ricardoV94 ricardoV94 requested a review from OriolAbril March 24, 2022 13:53
@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

canyon289 commented on 2022-03-24T14:46:24Z
----------------------------------------------------------------

Aeasara typo


@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

canyon289 commented on 2022-03-24T14:46:25Z
----------------------------------------------------------------

of our hmm


@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

canyon289 commented on 2022-03-24T14:49:09Z
----------------------------------------------------------------

Line #11.            # Aesara complains if the dtype of the returned output is not

Replace complains to the precise thing that happen, does it warn, raise and error something else?


@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

twiecki commented on 2022-03-24T14:49:38Z
----------------------------------------------------------------

* It also implements


@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

twiecki commented on 2022-03-24T14:49:39Z
----------------------------------------------------------------

*hour -> our


@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

twiecki commented on 2022-03-24T14:49:40Z
----------------------------------------------------------------

isn't there a compiled_logpt method?


ricardoV94 commented on 2022-03-24T15:15:01Z
----------------------------------------------------------------

Good point

ricardoV94 commented on 2022-03-24T17:31:17Z
----------------------------------------------------------------

Actually that doesn't accept the mode kwarg, but I simplified this to use model.compile_fn

@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

twiecki commented on 2022-03-24T14:49:41Z
----------------------------------------------------------------

writing graphs

might


@review-notebook-app
Copy link

review-notebook-app bot commented Mar 24, 2022

View / edit / reply to this conversation on ReviewNB

twiecki commented on 2022-03-24T14:49:42Z
----------------------------------------------------------------

Like JAX, Aesara has the goal of mimicking the NumPy and Scipy APIs, so that writting code in Aesara should feel very similar to how code is written in those libraries. So if you can, you should usually write your custom function using Aesara operations (which compile to JAX anyway) rather than JAX directly.

Or maybe move that note further up.


ricardoV94 commented on 2022-03-24T17:35:59Z
----------------------------------------------------------------

Moved the note above the bonus. Not sure about the stronger message. I actually had it even stronger at first, but then tuned it down to not sound too opinionated

@twiecki
Copy link
Member

twiecki commented Mar 24, 2022

This is awesome. I would more clearly motivate and delineate the two approaches you are showing and in which scenarios a user might use one over the other. I would put that early on.

It's definitely advanced and it would be great if we could simplify it. One would be to just make simpler wrappers so that less boilerplate code is required (long-term), the other would be a simpler example here. I think it works well and shows something more advanced which is good, but it makes the example even heavier. Why not use a dummy Op?

@ricardoV94
Copy link
Member Author

It's definitely advanced and it would be great if we could simplify it. One would be to just make simpler wrappers so that less boilerplate code is required (long-term), the other would be a simpler example here. I think it works well and shows something more advanced which is good, but it makes the example even heavier. Why not use a dummy Op?

I think that between the black-box likelihood and the Aesara [how to jaxify documentation] users have enough examples to figure out how to write their own wrappers, so I wanted to go a bit more in depth. We can also link to the this gist and/or add it in the docs directly (but that would not be enough for a pymc-examples imo): https://gist.github.com/dfm/a2db466f46ab931947882b08b2f21558

Copy link
Member Author

Good point


View entire conversation on ReviewNB

@ricardoV94
Copy link
Member Author

This is awesome. I would more clearly motivate and delineate the two approaches you are showing and in which scenarios a user might use one over the other

What two approaches?

The unwrapping? That requires having a Aesara Op in the first place.

Or do you mean the bonus part?

Copy link
Member Author

Actually that doesn't accept the mode kwarg, but I simplified this to use model.compile_fn


View entire conversation on ReviewNB

Copy link
Member Author

Moved the note above the bonus. Not sure about the stronger message. I actually had it even stronger at first, but then tuned it down to not sound too opinionated


View entire conversation on ReviewNB

@OriolAbril
Copy link
Member

Haven't had time to review the notebook, but from skimming the comments I think that this is precisely one of the situations where diataxis helps. IIUC, this notebook is a mixture of how-to and explanation according to diataxis. Thomas was also saying we'd need a tutorial (diataxis separates how-to and tutorials, the descriptions of each type there are very good). We could aim for these 3 notebooks, one of each type that link to each other. i.e. the how-to is for jax and pymc users that already know they need or want to use jax, as diataxis says, it is a recipe, it should not be expected to be followable by someone who doesn't already know how to cook, so the intro could point them to the tutorial for the base case or explanation doc if they don't know whether they should use jax or if they want to understand the pymc-aesara-jax integration in depth.

@ricardoV94
Copy link
Member Author

ricardoV94 commented Mar 24, 2022

Haven't had time to review the notebook, but from skimming the comments I think that this is precisely one of the situations where diataxis helps. IIUC, this notebook is a mixture of how-to and explanation according to diataxis. Thomas was also saying we'd need a tutorial (diataxis separates how-to and tutorials, the descriptions of each type there are very good). We could aim for these 3 notebooks, one of each type that link to each other. i.e. the how-to is for jax and pymc users that already know they need or want to use jax, as diataxis says, it is a recipe, it should not be expected to be followable by someone who doesn't already know how to cook, so the intro could point them to the tutorial for the base case or explanation doc if they don't know whether they should use jax or if they want to understand the pymc-aesara-jax integration in depth.

Good points. I am not sure about the diataxis categories. Skimming through the website, I would think this notebook includes how-to and perhaps tutorial. I am curious what part of it you think fits as an explanation.

I am okay with restructuring or splitting the notebook if there's a good reason for it. Right now I am not yet convinced there is enough material for that.

@ricardoV94
Copy link
Member Author

ricardoV94 commented Mar 24, 2022

Good points. I am not sure about the diataxis categories. Skimming through the website, I would think this notebook includes how-to and perhaps tutorial. I am curious what part of it you think fits as an explanation.

Perhaps you were thinking of the "which one should you use, Aesara vs JAX"?

@ricardoV94 ricardoV94 force-pushed the wrap_jax_function branch 4 times, most recently from dfae082 to 8e98244 Compare March 25, 2022 09:46
@review-notebook-app
Copy link

review-notebook-app bot commented Mar 31, 2022

View / edit / reply to this conversation on ReviewNB

OriolAbril commented on 2022-03-31T22:39:26Z
----------------------------------------------------------------

I would add also jax as tag, and use hidden markov model without the acronym later in the tag. This should be added to the style guide too because it is not clear at all, but the how to category should be how-to with a dash


ricardoV94 commented on 2022-04-11T15:18:40Z
----------------------------------------------------------------

Thanks

@review-notebook-app
Copy link

review-notebook-app bot commented Mar 31, 2022

View / edit / reply to this conversation on ReviewNB

OriolAbril commented on 2022-03-31T22:39:32Z
----------------------------------------------------------------

similar comment to the one above. These explanations as a section towards the bottom of this notebook makes the content hard to find. I think it would be better off as an independent notebook even if short and even if it has no code (you can even write it in markdown if that were the case and you felt like it, if it has the post directive it will be parsed and added to the website as if it were a notebook)


ricardoV94 commented on 2022-04-11T16:00:11Z
----------------------------------------------------------------

I can remove this section altogether. The reason I added was to dissuade users from just putting everything inside JAX Op's, sometimes simply because they know that library better, or are unaware that most of the same things can be done with Aesara.

Copy link
Member

I totally understands Oriol's objections here, but this same issue has been raised before and hence we may need to discuss and agree about how are we going to organize de documentation. Or if that has been discussed the criteria should be stated somewhere.

I see this notebook as advanced in the sense that knowing what Aesara or an Op are, is not something a Begginer Bayesian practitioner needs to know to perform Bayesian stats. Even an advanced modeller may not need to directly interact with Aesara for a successful modelling.

And then I see that this notebook is introductory in the sense that provides an intro to Aesara and Ops.


View entire conversation on ReviewNB

Copy link
Member

These are some of the questions we discussed at the documentation meetings, especially while Martina was working on GSoD.

this same issue has been raised before and hence we may need to discuss and agree about how are we going to organize de documentation. Or if that has been discussed the criteria should be stated somewhere

This was one of the main goals of GSoD last year. After going though the material and discussing it in some meetings, Martina proposed to use tags and categories instead of the hardcoded hierarchy used in v3 docs. My only contribution was to then set up sphinx to use tags and categories. Martina also created https://github.com/pymc-devs/pymc/wiki/Categories-and-Tags-for-PyMC-Examples as a temporal place to check for tag ideas mostly (which hasn't been updated as extra tags were used but that was the idea, the list of tags in the website should already serve this goal once we have enough notebooks with metadata).

In a later meeting we also decided to use 2 categories per notebook instead of only one: level and diataxis type.

I see this notebook as advanced in the sense that knowing what Aesara or an Op are, is not something a Begginer Bayesian practitioner needs to know to perform Bayesian stats. Even an advanced modeller may not need to directly interact with Aesara for a successful modelling.

And then I see that this notebook is introductory in the sense that provides an intro to Aesara and Ops.

The beginner/intermediate/advanced concepts are very complicated to use because everyone already has an idea about what each of them means, and even providing definitions for the meaning of the categories is of little use to readers who will probably not find it. To mitigate these effect we discussed alternative names like learning/using/mastering but weren't completely convinced with any of the options (that was the best one we came up with IIRC) and somehow tried the bird/dragon thing to avoid using any words and see how it looked and it ended up being well liked (within the docs meeting consensus which were often 3-4 people).

The general idea, bird or different name was to hide a little bit the beginner/intermediate/advanced concept from readers, focusing more on a complexity gradient than on some "absolute" levels. Therefore the levels are placed in a line to indicate a gradient of complexity/advancedness from left to right. Hopefully, a dragon icon doesn't come with too many baggage related to the difficulty of the notebook it categorizes and readers will form a more accurate idea by themselves by reading the examples, going "back" to the flying bird if the dragon notebooks go over their head... (the category text is still shown in the category archive pages and in page exerpts though, at least for now, we'll see if that is enough if it works and decide to go harder to the bird+dragon iconography or get back to plain words).

From the questions above, there is a 2nd probably even more important topic that transpires: the scope of the pymc docs. By pymc docs I mean the docs/source folder in pymc and pymc-examples. Mostly from a pragmatic point of view, because we can't teach everything ourselves nor maintain documentation for everything, we ended up agreeing that the scope of the pymc docs is teaching pymc. It is not teaching Bayesian statistics in the same way it is not teaching python. The "learning bayesian stats while/and learning pymc all at the same time" gap which is a bit of a grey area is covered by pymc-resources piggybacking on books that teach bayesian stats and providing the code examples in pymc, no deeper involvement.

In that light, using Aesara Ops explicitly is something only intermediate/advanced pymc users should do imo. The level should not come so much by the depth in which the topic of the notebook is covered but more by the level of that task within the "pymc use scale". As I mentioned in the comments, I envisioned this notebook as an advanced notebook for people who know what they are doing and want/need to interface directly with JAX (i.e. writing things directly in Aesara is nice, but might not be realistic if the model depends on external libraries written in jax), and that this notebook would be complemented by other notebooks on Aesara, Ops, the pros and cons of the different things that can be done with ops...


View entire conversation on ReviewNB

@twiecki
Copy link
Member

twiecki commented Apr 11, 2022

Do you still want to port this to use a BNN?

@ricardoV94
Copy link
Member Author

No, I changed my mind. I think adding a third party library just obfuscates the process (unless a reader is familiar with such library)

Copy link
Member Author

Thanks


View entire conversation on ReviewNB

Copy link
Member Author

I am happy to change the how-to / level tags to something more appropriate. I don't think that users who have a good mental understanding of Ops need this guide at all.

I do think that many users (will) hear about the fact that you can use JAX functions in PyMC and that PyMC models can be compiled to JAX and what to do that. I don't think these users will be necessarily interested in learning about Ops, and will want to learn about them without some applied context.

Part of the reason why I went a bit more in detail here, is that I see many beginner users trying to define Ops in the "shorteset" possible way, (e.g., by just defining the itypes/ otypes or using the helper decorator@as_op), and then quickly stumbling whenever they (or PyMC) asks for gradients or the inevitable shape/type errors crop up.

The previous notebooks did not show how to simply debug the Op outside of a PyMC model context, and I don't really see how you can get away without this knowledge. Again, if users know how to do this, they are unlikely to need this guide just for the recipe step of: to use jax put it in the "perform" method of an Op


View entire conversation on ReviewNB

Copy link
Member Author

Good catch


View entire conversation on ReviewNB

Copy link
Member Author

I think less than 5 in this notebook. I wouldn't add it yet. Maybe if a second notebook using JAX crops up


View entire conversation on ReviewNB

Copy link
Member Author

Yeah, I had removed it already, but missed the sentence!


View entire conversation on ReviewNB

Copy link
Member Author

Agree. I tried to make that point a bit more obvious. Let me know if you have a specific improvement in mind


View entire conversation on ReviewNB

Copy link
Member Author

Didn't know about that one, thanks. I have a bit more parameters now...


View entire conversation on ReviewNB

Copy link
Member Author

I can remove this section altogether. The reason I added was to dissuade users from just putting everything inside JAX Op's, sometimes simply because they know that library better, or are unaware that most of the same things can be done with Aesara.


View entire conversation on ReviewNB

Copy link
Member

@OriolAbril OriolAbril left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The extra installs thing is not working, I think it is because the docs I wrote on them originally were wrong and said to use myst_substitutions key in metadata json. It should be substitutions only (I updated the docs a couple days ago too)

@OriolAbril OriolAbril merged commit 66ea234 into pymc-devs:main Apr 14, 2022
OriolAbril added a commit that referenced this pull request Apr 24, 2022
* updated notebook

* run wit latest v2 version

* fixed graphviz redering

* partial resolution of pr comments

* Add guide on how to wrap a JAX function in a Aesara Op (#302)

* Add guide on how to wrap a JAX function in a Aesara Op

* Fix typos and reorder last sections

* Fix reference paths

* Simplify model and be more verbose in Op creation

* Address review comments

* Address more review comments

* Fix more Aesara docs references

* Use reference to Aesara index page

* Update myst_nbs/case_studies/wrapping_jax_function.myst.md

Co-authored-by: Oriol Abril-Pla <[email protected]>

* update twitter link (#314)

* update Gaussian Mixture Model example with `pm.NormalMixture` (#310)

* create truncated regression example

* delete truncated regression example from main branch

* create truncated regression example

* delete truncated regression example from main branch

* create truncated regression example

* delete truncated regression example from main branch

* initial commit

* update link to pull request in Authors section

* add tag `classification`

* update authorship verbs

* plot through xarray, using XrContinuousRV

* add x axis labels

* add xarray_einstats to watermark, and fix 'classification' as a tag, not a category

Co-authored-by: Benjamin T. Vincent <[email protected]>

* add `*.DS_Store` to `.gitignore` (#315)

Add *.DS_Store to .gitignore

Co-authored-by: Benjamin T. Vincent <[email protected]>

* updated notebook

* run wit latest v2 version

* fixed graphviz redering

* partial resolution of pr comments

* reverted plot labels

* restored plt labels

* fixed arviz labels

* fixed sampling warning

* fixed authors

Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Oriol Abril-Pla <[email protected]>
Co-authored-by: Abhishek K M <[email protected]>
Co-authored-by: Benjamin T. Vincent <[email protected]>
Co-authored-by: Benjamin T. Vincent <[email protected]>
@ricardoV94 ricardoV94 deleted the wrap_jax_function branch June 17, 2022 16:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
tree: advanced Internal use label
Projects
Development

Successfully merging this pull request may close these issues.

Add new notebook showcasing how to (un)wrap a JAX Op in Aesara
4 participants