Skip to content

Commit 04f05ff

Browse files
Fix errors in notebook due to version update (#630)
* Fix errors in notebook due to version update * Run pre-commit * Fix python version
1 parent 6f2eb44 commit 04f05ff

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

examples/howto/wrapping_jax_function.ipynb

+10-10
Large diffs are not rendered by default.

examples/howto/wrapping_jax_function.myst.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ x_grad_wrt_emission_signal.eval()
480480
We are now ready to make inferences about our HMM model with PyMC. We will define priors for each model parameter and use {class}`~pymc.Potential` to add the joint log-likelihood term to our model.
481481

482482
```{code-cell} ipython3
483-
with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
483+
with pm.Model() as model:
484484
emission_signal = pm.Normal("emission_signal", 0, 1)
485485
emission_noise = pm.HalfNormal("emission_noise", 1)
486486
@@ -515,7 +515,7 @@ pm.model_to_graphviz(model)
515515
Before we start sampling, we check the logp of each variable at the model initial point. Bugs tend to manifest themselves in the form of `nan` or `-inf` for the initial probabilities.
516516

517517
```{code-cell} ipython3
518-
initial_point = model.compute_initial_point()
518+
initial_point = model.initial_point()
519519
initial_point
520520
```
521521

@@ -604,7 +604,7 @@ jax_fn()
604604
We can also compile a JAX function that computes the log probability of each variable in our PyMC model, similar to {meth}`~pymc.Model.point_logps`. We will use the helper method {meth}`~pymc.Model.compile_fn`.
605605

606606
```{code-cell} ipython3
607-
model_logp_jax_fn = model.compile_fn(model.logpt(sum=False), mode="JAX")
607+
model_logp_jax_fn = model.compile_fn(model.logp(sum=False), mode="JAX")
608608
model_logp_jax_fn(initial_point)
609609
```
610610

@@ -622,7 +622,7 @@ Now that we know our model logp can be entirely compiled to JAX, we can use the
622622

623623
```{code-cell} ipython3
624624
with model:
625-
idata_numpyro = pm.sampling_jax.sample_numpyro_nuts(chains=2, progress_bar=False)
625+
idata_numpyro = pm.sampling_jax.sample_numpyro_nuts(chains=2, progressbar=False)
626626
```
627627

628628
```{code-cell} ipython3

0 commit comments

Comments
 (0)