Commit 5ec4dbb

Move Aesara samplers further below.
1 parent: dab5862

1 file changed (+18, -16 lines)

blog/v4_announcement.ipynb

Lines changed: 18 additions & 16 deletions
@@ -406,21 +406,7 @@
 "source": [
 "That's a 3x speed-up -- for a single-line code change (although we've seen speed-ups much more impressive than that, in the 20x range)! And this is just running things on the CPU; we can just as easily run this on the GPU, where we saw even more impressive speed-ups (especially as we scale the data).\n",
 "\n",
-"Again, for a more proper benchmark that also compares this to Stan, see [this blog post](https://martiningram.github.io/mcmc-comparison/).\n",
-"\n",
-"#### The Future: Samplers written in `aesara`\n",
-"\n",
-"While this current approach is already quite exciting, we can take it one step further. The setup we showed above takes the model logp graph (represented in `aesara`) and compiles it to `JAX`. The resulting `JAX` function can then be called from a sampler written directly in `JAX` (e.g. `numpyro` or `blackjax`).\n",
-"\n",
-"While lightning fast, this is suboptimal for two reasons:\n",
-"1. For new backends, like `numba`, we would need to rewrite the sampler in `numba` as well.\n",
-"2. While we get low-level optimizations from `JAX` on the logp+sampler JAX graph, we do not get any high-level optimizations, which is what `aesara` is great at, because `aesara` does not see the sampler.\n",
-"\n",
-"With [`aehmc`](https://www.github.com/aesara-devs/aehmc) and [`aemcmc`](https://www.github.com/aesara-devs/aemcmc), the `aesara` devs are developing a library of samplers *written in `aesara`*. That way, our model logp, consisting of `aesara` `Ops`, can be combined with the sampler logic, now also consisting of `aesara` `Ops`, to form one big `aesara` graph.\n",
-"\n",
-"On that big graph containing model *and* sampler, `aesara` can then do high-level optimizations to get a more efficient graph representation. In a next step, it can compile that graph to whatever backend we want: `JAX`, `numba`, `C`, or whatever other backend we add in the future.\n",
-"\n",
-"If you think this is interesting, definitely check out these packages and consider contributing; this is where the next round of innovation will come from!"
+"Again, for a more proper benchmark that also compares this to Stan, see [this blog post](https://martiningram.github.io/mcmc-comparison/)."
 ]
 },
 {
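The context line above refers to switching PyMC over to the JAX-based samplers with a one-line change. As an illustration only (not part of this commit), here is a minimal sketch, assuming the v4-era `pymc.sampling_jax` module and its `sample_numpyro_nuts` function; the toy model and variable names are made up for the example, and the module path has shifted between releases:

```python
# Hedged sketch, assuming the v4-era pymc.sampling_jax API (requires numpyro).
import numpy as np
import pymc as pm
import pymc.sampling_jax  # not imported automatically by `import pymc`

rng = np.random.default_rng(0)
x = rng.normal(size=100)
y_obs = 2.0 * x + rng.normal(size=100)

with pm.Model():
    # Toy regression model; PyMC builds its logp graph in aesara.
    beta = pm.Normal("beta", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)
    pm.Normal("y", mu=beta * x, sigma=sigma, observed=y_obs)

    idata_cpu = pm.sample()  # default Python-based NUTS sampler

    # The "single-line code change": compile the logp graph to JAX and
    # sample with numpyro's NUTS instead.
    idata_jax = pymc.sampling_jax.sample_numpyro_nuts()
```

Both calls sample the same aesara-built logp graph; only the sampler backend differs.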
@@ -996,7 +982,7 @@
 },
 {
 "cell_type": "markdown",
-"id": "795c5687",
+"id": "fadec72c",
 "metadata": {},
 "source": [
 "## New website\n",
@@ -1011,6 +997,22 @@
 "source": [
 "## A Look Towards the Future\n",
 "\n",
+"### Samplers written in Aesara\n",
+"\n",
+"Above we described how, with the new JAX backend, we can run the model *and* the sampler as one big JAX graph, without any Python call overhead. While this is already quite exciting, we can take it one step further. The setup we showed above takes the model logp graph (represented in `aesara`) and compiles it to `JAX`. The resulting `JAX` function can then be called from a sampler written directly in `JAX` (e.g. `numpyro` or `blackjax`).\n",
+"\n",
+"While lightning fast, this is suboptimal for two reasons:\n",
+"1. For new backends, like `numba`, we would need to rewrite the sampler in `numba` as well.\n",
+"2. While we get low-level optimizations from `JAX` on the logp+sampler JAX graph, we do not get any high-level optimizations, which is what `aesara` is great at, because `aesara` does not see the sampler.\n",
+"\n",
+"With [`aehmc`](https://www.github.com/aesara-devs/aehmc) and [`aemcmc`](https://www.github.com/aesara-devs/aemcmc), the `aesara` devs are developing a library of samplers *written in `aesara`*. That way, our model logp, consisting of `aesara` `Ops`, can be combined with the sampler logic, now also consisting of `aesara` `Ops`, to form one big `aesara` graph.\n",
+"\n",
+"On that big graph containing model *and* sampler, `aesara` can then do high-level optimizations to get a more efficient graph representation. In a next step, it can compile that graph to whatever backend we want: `JAX`, `numba`, `C`, or whatever other backend we add in the future.\n",
+"\n",
+"If you think this is interesting, definitely check out these packages and consider contributing; this is where the next round of innovation will come from!\n",
+"\n",
+"### Automatic model reparameterizations\n",
+"\n",
 "As mentioned in the beginning, `aesara` is a unique library in the PyData ecosystem, as it is the only one that provides a static, mutable computation graph. Having direct access to this computation graph allows for many interesting features:\n",
 "* graph optimizations like `log(exp(x)) -> x`\n",
 "* symbolic rewrites like `N(0, 1) + a` -> `N(a, 1)`\n",
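To make the first bullet in the added text concrete, here is a small sketch (not part of this commit) of observing such a graph rewrite with aesara directly. The API calls (`at.vector`, `aesara.function`, `aesara.printing.debugprint`) are standard aesara; the exact printed output, and which rewrite set fires, depends on the aesara version:

```python
# Illustrative sketch: the log(exp(x)) -> x simplification mentioned above.
import aesara
import aesara.tensor as at

x = at.vector("x")
y = at.log(at.exp(x))  # built as two Ops in the symbolic graph

# Compiling triggers aesara's graph rewrites; per the post, log(exp(x))
# simplifies to just x, so no log or exp nodes should remain.
f = aesara.function([x], y)
aesara.printing.debugprint(f)
```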
