Skip to content

Commit 0ba2744

Browse files
committed
Address more review comments
1 parent 3bb955d commit 0ba2744

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

examples/case_studies/wrapping_jax_function.ipynb

+4-4
Original file line numberDiff line numberDiff line change
@@ -1223,7 +1223,7 @@
12231223
"We do not return the jitted function, so that the entire Aesara graph can be jitted together after being converted to JAX.\n",
12241224
":::\n",
12251225
"\n",
1226-
"For a better understanding of {class}`~aesara.graph.op.Op` JAX conversions, we recommend reading Aesara's [Adding JAX and Numba support for Ops guide](https://aesara.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html?highlight=JAX)\n",
1226+
"For a better understanding of {class}`~aesara.graph.op.Op` JAX conversions, we recommend reading Aesara's {doc}`Adding JAX and Numba support for Ops guide <aesara:extending/creating_a_numba_jax_op>`.\n",
12271227
"\n",
12281228
"We can test that our conversion function is working properly by compiling a {func}`aesara.function` with `mode=\"JAX\"`:"
12291229
]
@@ -1308,7 +1308,7 @@
13081308
"cell_type": "markdown",
13091309
"metadata": {},
13101310
"source": [
1311-
"Now that we know our model logp can be entirely compiled to JAX, we can use the handy {mod}`pymc.sampling_jax` module to sample our model using the pure JAX sampler implemented in NumPyro."
1311+
"Now that we know our model logp can be entirely compiled to JAX, we can use the handy {func}`pymc.sampling_jax.sample_numpyro_nuts` to sample our model using the pure JAX sampler implemented in NumPyro."
13121312
]
13131313
},
13141314
{
@@ -1436,7 +1436,7 @@
14361436
"\n",
14371437
"There are, however, some of advantages to working with Aesara:\n",
14381438
"\n",
1439-
"1. Aesara graphs are considerably easier to [inspect and debug](https://aesara.readthedocs.io/en/latest/tutorial/debug_faq.html) than JAX functions\n",
1439+
"1. Aesara graphs are considerably easier to {ref}`inspect and debug <aesara:debug_faq>` than JAX functions\n",
14401440
"2. Aesara has clever [optimization and stabilization routines](https://aesara.readthedocs.io/en/latest/optimizations.html) that are not possible or implemented in JAX\n",
14411441
"3. Aesara graphs can be easily [manipulated after creation](https://aesara.readthedocs.io/en/latest/extending/graph_rewriting.html#graph-rewriting)\n",
14421442
"\n",
@@ -1641,7 +1641,7 @@
16411641
"pygments_lexer": "ipython3",
16421642
"version": "3.10.2"
16431643
},
1644-
"myst_substitutions": {
1644+
"substitutions": {
16451645
"extra_dependencies": "jax numpyro"
16461646
},
16471647
"toc": {

myst_nbs/case_studies/wrapping_jax_function.myst.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -583,7 +583,7 @@ def hmm_logp_dispatch(op, **kwargs):
583583
We do not return the jitted function, so that the entire Aesara graph can be jitted together after being converted to JAX.
584584
:::
585585

586-
For a better understanding of {class}`~aesara.graph.op.Op` JAX conversions, we recommend reading Aesara's [Adding JAX and Numba support for Ops guide](https://aesara.readthedocs.io/en/latest/extending/creating_a_numba_jax_op.html?highlight=JAX)
586+
For a better understanding of {class}`~aesara.graph.op.Op` JAX conversions, we recommend reading Aesara's {doc}`Adding JAX and Numba support for Ops guide <aesara:extending/creating_a_numba_jax_op>`.
587587

588588
We can test that our conversion function is working properly by compiling a {func}`aesara.function` with `mode="JAX"`:
589589

@@ -616,7 +616,7 @@ We include a {ref}`short discussion <aesara_vs_jax>` at the end of this document
616616

617617
+++
618618

619-
Now that we know our model logp can be entirely compiled to JAX, we can use the handy {mod}`pymc.sampling_jax` module to sample our model using the pure JAX sampler implemented in NumPyro.
619+
Now that we know our model logp can be entirely compiled to JAX, we can use the handy {func}`pymc.sampling_jax.sample_numpyro_nuts` to sample our model using the pure JAX sampler implemented in NumPyro.
620620

621621
```{code-cell} ipython3
622622
with model:
@@ -662,7 +662,7 @@ Like JAX, Aesara has the goal of mimicking the NumPy and Scipy APIs, so that wri
662662

663663
There are, however, some of advantages to working with Aesara:
664664

665-
1. Aesara graphs are considerably easier to [inspect and debug](https://aesara.readthedocs.io/en/latest/tutorial/debug_faq.html) than JAX functions
665+
1. Aesara graphs are considerably easier to {ref}`inspect and debug <aesara:debug_faq>` than JAX functions
666666
2. Aesara has clever [optimization and stabilization routines](https://aesara.readthedocs.io/en/latest/optimizations.html) that are not possible or implemented in JAX
667667
3. Aesara graphs can be easily [manipulated after creation](https://aesara.readthedocs.io/en/latest/extending/graph_rewriting.html#graph-rewriting)
668668

0 commit comments

Comments
 (0)