
Commit 8e98244

Fix reference paths
1 parent 3e61646 commit 8e98244

2 files changed: +26 −37 lines changed

examples/case_studies/wrapping_jax_function.ipynb

+13-20
@@ -82,17 +82,17 @@
 "\n",
 "PyMC uses the [Aesara](https://aesara.readthedocs.io/en/latest/) library to create and manipulate probabilistic graphs. Aesara is backend-agnostic, meaning it can make use of functions written in different languages or frameworks, including pure Python, NumPy, C, Cython, Numba, and [JAX](https://jax.readthedocs.io/en/latest/index.html). \n",
 "\n",
-"All that is needed is to encapsulate such function in a Aesara {class}`~aesara.graph.basic.Op`, which enforces a specific API regarding how inputs and outputs of pure \"operations\" should be handled. It also implements methods for optional extra functionality like symbolic shape inference and automatic differentiation. This is well covered in the Aesara [`Op`s documentation](https://aesara.readthedocs.io/en/latest/extending/op.html) and in our {ref}`blackbox_external_likelihood_numpy` pymc-example.\n",
+"All that is needed is to encapsulate such function in a Aesara {class}`~aesara.graph.op.Op`, which enforces a specific API regarding how inputs and outputs of pure \"operations\" should be handled. It also implements methods for optional extra functionality like symbolic shape inference and automatic differentiation. This is well covered in the Aesara [`Op`s documentation](https://aesara.readthedocs.io/en/latest/extending/op.html) and in our {ref}`blackbox_external_likelihood_numpy` pymc-example.\n",
 "\n",
 "More recently, Aesara became capable of compiling directly to some of these languages/frameworks, meaning that we can convert a complete Aesara graph into a JAX or NUMBA jitted function, whereas traditionally they could only be converted to Python or C.\n",
 "\n",
 "This has some interesting uses, such as sampling models defined in PyMC with pure JAX samplers, like those implemented in [NumPyro](https://num.pyro.ai/en/latest/index.html) or [BlackJax](https://github.com/blackjax-devs/blackjax). \n",
 "\n",
-"This notebook illustrates how we can implement a new Aesara {class}`~aesara.graph.basic.Op` that wraps a JAX function. \n",
+"This notebook illustrates how we can implement a new Aesara {class}`~aesara.graph.op.Op` that wraps a JAX function. \n",
 "\n",
 "### Outline\n",
 "\n",
-"1. We start in a similar path as that taken in the {ref}`blackbox_external_likelihood_numpy`, which wraps a NumPy function in a Aesara {class}`~aesara.graph.basic.Op`, this time wrapping a JAX jitted function instead. \n",
+"1. We start in a similar path as that taken in the {ref}`blackbox_external_likelihood_numpy`, which wraps a NumPy function in a Aesara {class}`~aesara.graph.op.Op`, this time wrapping a JAX jitted function instead. \n",
 "2. We then enable Aesara to \"unwrap\" the just wrapped JAX function, so that the whole graph can be compiled to JAX. We make use of this to sample our PyMC model via the JAX NumPyro NUTS sampler."
 ]
 },
@@ -672,14 +672,14 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Now we are ready to wrap our JAX jitted function in a Aesara {class}`~aesara.graph.basic.Op`, that we can then use in our PyMC models. We recommend you check Aesara's official [`Op` documentation](https://aesara.readthedocs.io/en/latest/extending/op.html) if you want to understand it in more detail.\n",
+"Now we are ready to wrap our JAX jitted function in a Aesara {class}`~aesara.graph.op.Op`, that we can then use in our PyMC models. We recommend you check Aesara's official [`Op` documentation](https://aesara.readthedocs.io/en/latest/extending/op.html) if you want to understand it in more detail.\n",
 "\n",
-"In brief, we will inherit from {class}`~aesara.graph.basic.Op` and define the following methods:\n",
+"In brief, we will inherit from {class}`~aesara.graph.op.Op` and define the following methods:\n",
 "1. `make_node`: Creates an {class}`~aesara.graph.basic.Apply` node that holds together the symbolic inputs and outputs of our operation\n",
 "2. `perform`: Python code that returns the evaluation of our operation, given concrete input values\n",
 "3. `grad`: Returns a Aesara symbolic graph that represents the gradient expression of an output cost wrt to its inputs\n",
 "\n",
-"For the `grad` we will create a second {class}`~aesara.graph.basic.Op` that wraps our jitted grad version from above"
+"For the `grad` we will create a second {class}`~aesara.graph.op.Op` that wraps our jitted grad version from above"
 ]
 },
 {
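The division of responsibilities between the three methods listed in this hunk can be sketched with a plain-Python stand-in. This is not real Aesara code (a real `Op` operates on symbolic variables and returns an `Apply` node); the class and numbers below are purely illustrative:

```python
# Schematic stand-in for the Op contract discussed above.
# A real Aesara Op works on symbolic tensors; plain floats are
# used here so the sketch is self-contained and runnable.

class SquareOp:
    def make_node(self, x):
        # In Aesara this returns an Apply node tying symbolic
        # inputs to symbolic outputs; here we only record them.
        return {"inputs": [x], "outputs": [None]}

    def perform(self, inputs):
        # Concrete numeric evaluation, given input values.
        (x,) = inputs
        return [x ** 2]

    def grad(self, inputs, output_grads):
        # Gradient of the output cost wrt the input:
        # d(x^2)/dx = 2x, scaled by the upstream gradient g.
        (x,) = inputs
        (g,) = output_grads
        return [2 * x * g]

op = SquareOp()
print(op.perform([3.0]))      # [9.0]
print(op.grad([3.0], [1.0]))  # [6.0]
```

In the notebook itself, `perform` calls the JAX jitted logp function, and `grad` delegates to a second Op wrapping the jitted gradient.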
@@ -801,7 +801,7 @@
 }
 },
 "source": [
-"It's also helpful to confirm that the gradient of our {class}`~aesara.graph.basic.Op` can be requested via the Aesara `grad` interface"
+"It's also helpful to confirm that the gradient of our {class}`~aesara.graph.op.Op` can be requested via the Aesara `grad` interface"
 ]
 },
 {
@@ -1246,7 +1246,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"As mentioned in the beginning, Aesara can compile an entire graph to JAX. To do this, it needs to know how each {class}`~aesara.graph.basic.Op` in the graph can be converted to a JAX function. This can be done by {term}`dispatch <dispatching>` with {func}`aesara.link.jax.dispatch.jax_funcify`. Most of the default Aesara {class}`~aesara.graph.basic.Op`s already have such a dispatch function, but we will need to add a new one for our custom `HMMLogpOp`, as Aesara has never seen that before.\n",
+"As mentioned in the beginning, Aesara can compile an entire graph to JAX. To do this, it needs to know how each {class}`~aesara.graph.op.Op` in the graph can be converted to a JAX function. This can be done by {term}`dispatch <dispatching>` with {func}`aesara.link.jax.dispatch.jax_funcify`. Most of the default Aesara {class}`~aesara.graph.op.Op`s already have such a dispatch function, but we will need to add a new one for our custom `HMMLogpOp`, as Aesara has never seen that before.\n",
 "\n",
 "For that we need a function which returns (another) JAX function, that performs the same computation as in our `perform` method. Fortunately, we started exactly with such function, so this amounts to 3 short lines of code."
 ]
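The dispatch mechanism this hunk refers to behaves like Python's `functools.singledispatch`: a registry maps each Op class to a converter that returns the equivalent JAX function. A rough stdlib analogy (the registry name `jax_funcify_demo` and the placeholder computation are made up for illustration; only `HMMLogpOp` comes from the notebook):

```python
from functools import singledispatch

# Minimal stand-ins; in Aesara the registry is
# aesara.link.jax.dispatch.jax_funcify, keyed by Op class.
class Op: ...
class HMMLogpOp(Op): ...  # the notebook's custom Op, schematically

@singledispatch
def jax_funcify_demo(op):
    # Ops without a registered converter cannot be compiled to JAX.
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")

@jax_funcify_demo.register
def _(op: HMMLogpOp):
    # Return (another) function performing the same computation as
    # the Op's perform method (a placeholder sum here).
    def hmm_logp(*values):
        return sum(values)
    return hmm_logp

fn = jax_funcify_demo(HMMLogpOp())
print(fn(1.0, 2.0))  # 3.0
```

Registering the converter is what lets the whole graph, custom Op included, be translated and jitted as one JAX function.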
@@ -1270,7 +1270,7 @@
 "We do not return the jitted function, so that the entire Aesara graph can be jitted together after being converted to JAX.\n",
 ":::\n",
 "\n",
-"For a better understanding of {class}`~aesara.graph.basic.Op` JAX conversions, we recommend reading Aesara's [Adding JAX and Numba support for Ops guide](https://aesara.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html?highlight=JAX)\n",
+"For a better understanding of {class}`~aesara.graph.op.Op` JAX conversions, we recommend reading Aesara's [Adding JAX and Numba support for Ops guide](https://aesara.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html?highlight=JAX)\n",
 "\n",
 "We can test that our conversion function is working properly by compiling a {func}`aesara.function` with `mode=\"JAX\"`:"
 ]
@@ -1307,7 +1307,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"We can also compile a JAX function that computes the log probability of each variable in our PyMC model, similar to {meth}`~pymc.Model.point_logps`. We will use the helper method {met}`~pymc.model.Model.compile_fn`."
+"We can also compile a JAX function that computes the log probability of each variable in our PyMC model, similar to {meth}`~pymc.Model.point_logps`. We will use the helper method {meth}`~pymc.Model.compile_fn`."
 ]
 },
 {
@@ -1335,13 +1335,6 @@
 "model_logp_jax_fn(initial_point)"
 ]
 },
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"We see a JAX warning that `float64` casting operations present in the original Aesara graph will be ignored. This is expected, as aesara is less strict about using a single float precision. This should be fine in our case, but if you need to understand what parts of the graph are trying to introduce `float64` variables, you can change the Aesara [warn_float64 flag](https://aesara.readthedocs.io/en/latest/library/config.html#config.warn_float64)"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -1496,7 +1489,7 @@
 "\n",
 "Point 2 means your graphs are likely to perform better if written in Aesara. In general you don't have to worry about using specialized functions like `log1p` or `logsumexp`, as Aesara will be able to detect the equivalent naive expressions and replace them by their specialized counterparts. Importantly, you still benefit from these optimizations when your graph is later compiled to JAX.\n",
 "\n",
-"The catch is that Aesara cannot reason about JAX functions, and by association {class}`~aesara.graph.basic.Op`s that wrap them. This means that the larger the portion of the graph is \"hidden\" inside a JAX function, the less a user will benefit from Aesara's rewrite and debugging abilities.\n",
+"The catch is that Aesara cannot reason about JAX functions, and by association {class}`~aesara.graph.op.Op`s that wrap them. This means that the larger the portion of the graph is \"hidden\" inside a JAX function, the less a user will benefit from Aesara's rewrite and debugging abilities.\n",
 "\n",
 "Point 3 is more important for library developers. It is the main reason why PyMC developers opted to use Aesara (and before that, its predecessor Theano) as its backend. Many of the user-facing utilities provided by PyMC rely on the ability to easily parse and manipulate Aesara graphs."
 ]
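The `log1p` rewrite mentioned in this hunk is easy to demonstrate with plain Python floats (Aesara applies the analogous replacement symbolically on the graph):

```python
import math

x = 1e-18

# Naive expression: 1 + x rounds to exactly 1.0 in float64,
# so the tiny contribution of x is lost before the log.
naive = math.log(1 + x)

# Specialized counterpart: log1p evaluates log(1 + x) without
# forming 1 + x, preserving the small-x behavior.
stable = math.log1p(x)

print(naive)   # 0.0
print(stable)  # 1e-18
```

This is why a graph rewritten by Aesara can be numerically better than the naive expression the user wrote, even after it is later compiled to JAX.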
@@ -1512,11 +1505,11 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"We had to create two {class}`~aesara.graph.basic.Op`s, one for the function we cared about and a separate one for its gradients. However, JAX provides a `value_and_grad` utility that can return both the value of a function and its gradients. We can do something similar and get away with a single {class}`~aesara.graph.basic.Op` if we are clever about it.\n",
+"We had to create two {class}`~aesara.graph.op.Op`s, one for the function we cared about and a separate one for its gradients. However, JAX provides a `value_and_grad` utility that can return both the value of a function and its gradients. We can do something similar and get away with a single {class}`~aesara.graph.op.Op` if we are clever about it.\n",
 "\n",
 "By doing this we can (potentially) save memory and reuse computation that is shared between the function and its gradients. This may be relevant when working with very large JAX functions.\n",
 "\n",
-"Note that this is only useful if you are interested in taking gradients with respect to your {class}`~aesara.graph.basic.Op` using Aesara. If your endgoal is to compile your graph to JAX, and only then take the gradients (as NumPyro does), then it's better to use the first approach. You don't even need to implement the `grad` method and associated {class}`~aesara.graph.basic.Op` in that case."
+"Note that this is only useful if you are interested in taking gradients with respect to your {class}`~aesara.graph.op.Op` using Aesara. If your endgoal is to compile your graph to JAX, and only then take the gradients (as NumPyro does), then it's better to use the first approach. You don't even need to implement the `grad` method and associated {class}`~aesara.graph.op.Op` in that case."
 ]
 },
 {
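The computation-sharing idea behind `value_and_grad` in this hunk can be illustrated without JAX: a single function returns both the value and its gradient, reusing a shared intermediate. The toy sigmoid below is a hypothetical example, not the notebook's HMM logp:

```python
import math

def sigmoid_value_and_grad(x):
    # The forward value s is reused by the gradient s * (1 - s),
    # mirroring how a single Op (or jax.value_and_grad) can share
    # work between a function and its derivative.
    s = 1.0 / (1.0 + math.exp(-x))
    return s, s * (1.0 - s)

value, grad = sigmoid_value_and_grad(0.0)
print(value, grad)  # 0.5 0.25
```

Computing the two quantities separately would evaluate `exp(-x)` twice; for large functions that duplicated work (and memory) is what the single-Op approach avoids.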
