
Commit 71f5b3d

Simplify model and be more verbose in Op creation
1 parent d08b760 commit 71f5b3d

File tree

2 files changed (+432, -430 lines)


examples/case_studies/wrapping_jax_function.ipynb (+309, -333)

Large diffs are not rendered by default.

myst_nbs/case_studies/wrapping_jax_function.myst.md (+123, -97)
@@ -72,39 +72,56 @@ This notebook illustrates how we can implement a new Aesara {class}`~aesara.grap

+++

-For illustration purposes, we will simulate data following a simple [Hidden Markov Model](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMM), with 4 possible latent states $S \in \{0, 1, 2, 3\}$ and normal emission likelihood.
+For illustration purposes, we will simulate data following a simple [Hidden Markov Model](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMM), with 3 possible latent states $S \in \{0, 1, 2\}$ and normal emission likelihood.

-$$Y \sim \text{Normal}(S \cdot \text{signal}, (S + 1) \cdot \text{noise})$$
+$$Y \sim \text{Normal}((S + 1) \cdot \text{signal}, \text{noise})$$

-Our HMM will have a fixed Binomial probability of decaying from a higher state $S_t$ to a lower state $S_{t+1}$ in every step,
+Our HMM will have a fixed Categorical probability $P$ of switching across states, which depends only on the last state
-$$S_{t+1} \sim \text{Binomial}(S_t, \text{1-p_decay})$$
+$$S_{t+1} \sim \text{Categorical}(P_{S_t})$$

-This implies a zero probability of going from a lower state $S_{t}$ to a higher state $S_{t+1}$.
+To complete our model, we assume a fixed probability $P_{t0}$ for each possible initial state $S_{t0}$,

-To complete our model, we assume a fixed probability for each possible initial state $S_{t0}$,
-
-$$S_{t0} \sim \text{Categorical}(P_{\{0, 1, 2, 3\}})$$
+$$S_{t0} \sim \text{Categorical}(P_{t0})$$


### Simulating data

Let's generate data according to this model! The first step is to set some values for the parameters in our model

```{code-cell} ipython3
# Emission signal and noise parameters
-emission_signal_true = 0.75
-emission_noise_true = 0.05
+emission_signal_true = 1.15
+emission_noise_true = 0.15
+
+p_initial_state_true = np.array([0.9, 0.09, 0.01])
+
+# Probability of switching from state_t to state_t+1
+p_transition_true = np.array(
+    [
+        #  0,    1,    2
+        [0.9, 0.09, 0.01],  # 0
+        [0.1, 0.8, 0.1],  # 1
+        [0.2, 0.1, 0.7],  # 2
+    ]
+)

-# Probability of starting in initial states [0, 1, 2, 3]
-p_initial_state_true = np.array([0.01, 0.04, 0.25, 0.7])
+# Confirm that we have defined valid probabilities
assert np.isclose(np.sum(p_initial_state_true), 1)
+assert np.allclose(np.sum(p_transition_true, axis=-1), 1)
+```

-p_decay_true = 0.125
+```{code-cell} ipython3
+# Let's compute the log of the probability transition matrix for later use
+with np.errstate(divide="ignore"):
+    logp_initial_state_true = np.log(p_initial_state_true)
+    logp_transition_true = np.log(p_transition_true)
+
+logp_initial_state_true, logp_transition_true
```

```{code-cell} ipython3
-# We will observe 100 HMM processes, each with a total of 50 steps
-n_obs = 100
+# We will observe 70 HMM processes, each with a total of 50 steps
+n_obs = 70
n_steps = 50
```

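A note on the `np.errstate(divide="ignore")` guard in the added cell above: `np.log` emits a RuntimeWarning when it hits a zero probability, while `-inf` is exactly the value wanted for the log-probability of an impossible transition. A minimal sketch, using a made-up transition row rather than one from this notebook:

```python
import numpy as np

# log(0) triggers a RuntimeWarning but correctly evaluates to -inf;
# the errstate context silences the warning without changing the result.
with np.errstate(divide="ignore"):
    logp_row = np.log(np.array([0.5, 0.5, 0.0]))

print(logp_row)  # [-0.6931... -0.6931... -inf]
```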
@@ -117,29 +134,31 @@ rng = np.random.default_rng(rng_seed)
We write a helper function to generate a single HMM process and create our simulated data

```{code-cell} ipython3
-def simulate_hmm(p_initial_state, p_decay, emission_signal, emission_noise, n_steps, rng):
+def simulate_hmm(p_initial_state, p_transition, emission_signal, emission_noise, n_steps, rng):
    """Generate hidden state and emission from our HMM model."""
-    n_possible_states = len(p_initial_state)
-    initial_state = rng.choice(range(n_possible_states), p=p_initial_state)

-    hidden_state = [initial_state]
-    for step in range(n_steps):
-        hidden_state.append(rng.binomial(n=hidden_state[-1], p=1 - p_decay))
+    possible_states = np.array([0, 1, 2])

-    hidden_state = np.array(hidden_state)
+    hidden_states = []
+    initial_state = rng.choice(possible_states, p=p_initial_state)
+    hidden_states.append(initial_state)
+    for step in range(n_steps):
+        new_hidden_state = rng.choice(possible_states, p=p_transition[hidden_states[-1]])
+        hidden_states.append(new_hidden_state)
+    hidden_states = np.array(hidden_states)

-    emission = rng.normal(
-        hidden_state * emission_signal,
-        (hidden_state + 1) * emission_noise,
+    emissions = rng.normal(
+        (hidden_states + 1) * emission_signal,
+        emission_noise,
    )

-    return hidden_state, emission
+    return hidden_states, emissions
```

```{code-cell} ipython3
single_hmm_hidden_state, single_hmm_emission = simulate_hmm(
    p_initial_state_true,
-    p_decay_true,
+    p_transition_true,
    emission_signal_true,
    emission_noise_true,
    n_steps,
@@ -156,7 +175,7 @@ emission_observed = []
for i in range(n_obs):
    hidden_state, emission = simulate_hmm(
        p_initial_state_true,
-        p_decay_true,
+        p_transition_true,
        emission_signal_true,
        emission_noise_true,
        n_steps,
@@ -171,9 +190,11 @@ emission_observed = np.array(emission_observed)

```{code-cell} ipython3
fig, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
-for hidden_state, emission in zip(hidden_state_true, emission_observed):
-    ax[0].plot(hidden_state, color="C0", alpha=0.1)
-    ax[1].plot(emission, color="C0", alpha=0.1)
+# Plot the first four HMM processes
+for i in range(4):
+    ax[0].plot(hidden_state_true[i] + i * 0.02, color=f"C{i}", lw=2, alpha=0.4)
+    ax[1].plot(emission_observed[i], color=f"C{i}", lw=2, alpha=0.4)
+ax[0].set_yticks([0, 1, 2])
ax[0].set_ylabel("hidden state")
ax[1].set_ylabel("observed emission")
ax[1].set_xlabel("step")
@@ -192,49 +213,9 @@ We will write a JAX function to compute the likelihood of our HMM model, margina

We will take advantage of JAX [scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) to obtain an efficient and differentiable log-likelihood, and the handy [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) to automatically vectorize this log-likelihood across multiple observed processes.

-Before that, let us create some helper variables derived from our true parameters, that we can use to test our implementation.
-
-```{code-cell} ipython3
-n_hidden_states = len(p_initial_state_true)
-n_hidden_states
-```
-
-```{code-cell} ipython3
-logp_initial_state_true = np.log(p_initial_state_true)
-logp_initial_state_true
-```
-
-```{code-cell} ipython3
-# Compute the probability transition matrix, of going from S_t to S_t+1
-# p[0, 0], is the probability of going from S_t=0 to S_t+1=0
-# p[1, 0], is the probability of going from S_t=1 to S_t+1=0
-# p[0, 1], is the probabilty of going from S_t=0 to S_t+1=1 (which is impossible)
-# p[3, 3], is the probability of going from S_t=3, to S_t+1=3
-possible_states = np.arange(n_hidden_states, dtype="int16")
-p_transition_true = (
-    pm.logp(
-        pm.Binomial.dist(n=possible_states, p=1 - p_decay_true),
-        possible_states[:, None],
-    )
-    .T.exp()
-    .eval()
-)
-p_transition_true
-```
-
-```{code-cell} ipython3
-# Confirm that we have a valid transition probability matrix
-assert np.allclose(np.sum(p_transition_true, axis=-1), 1)
-```
-
-```{code-cell} ipython3
-logp_transition_true = np.log(p_transition_true)
-logp_transition_true
-```
-
-### Writing the JAX function
++++

-This is our core JAX function which computes the marginal log-likelihood of a single HMM process
+Our core JAX function computes the marginal log-likelihood of a single HMM process

```{code-cell} ipython3
def hmm_logp(
@@ -246,14 +227,13 @@ def hmm_logp(
):
    """Compute the marginal log-likelihood of a single HMM process."""

-    # Caution: Using global variable for simplicity!
-    hidden_states = np.arange(n_hidden_states)
+    hidden_states = np.array([0, 1, 2])

    # Compute log-likelihood of observed emissions for each (step x possible hidden state)
    logp_emission = jsp.stats.norm.logpdf(
        emission_observed[:, None],
-        hidden_states * emission_signal,
-        (hidden_states + 1) * emission_noise,
+        (hidden_states + 1) * emission_signal,
+        emission_noise,
    )

    # We use the forward algorithm to compute log_alpha(x_t) = logp(x_t, y_1:t)
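For readers who want to see the recursion that comment refers to, here is a minimal, self-contained sketch of the log-space forward algorithm with `jax.lax.scan`; the transition matrix and emission log-probabilities below are placeholders, not the notebook's actual variables:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Placeholder inputs: a (3 x 3) log transition matrix, (n_steps x n_states)
# emission logps, and log initial-state probabilities.
logp_transition = jnp.log(jnp.array([[0.9, 0.09, 0.01], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]]))
logp_emission = jnp.zeros((50, 3))
logp_initial_state = jnp.log(jnp.array([0.9, 0.09, 0.01]))


def forward_step(log_alpha_prev, logp_emission_t):
    # log_alpha_t(s') = logp_emission[t, s']
    #                   + logsumexp_s(log_alpha_{t-1}(s) + logp_transition[s, s'])
    log_alpha = logp_emission_t + logsumexp(log_alpha_prev[:, None] + logp_transition, axis=0)
    return log_alpha, None  # carry log_alpha forward; no per-step outputs needed


log_alpha_0 = logp_initial_state + logp_emission[0]
log_alpha_T, _ = jax.lax.scan(forward_step, log_alpha_0, logp_emission[1:])
marginal_logp = logsumexp(log_alpha_T)  # logp(y_1:T), marginalized over the final state
```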
@@ -362,9 +342,22 @@ For the `grad` we will create a second {class}`~aesara.graph.op.Op` that wraps o

```{code-cell} ipython3
class HMMLogpOp(Op):
-    def make_node(self, *inputs):
+    def make_node(
+        self,
+        emission_observed,
+        emission_signal,
+        emission_noise,
+        logp_initial_state,
+        logp_transition,
+    ):
        # Convert our inputs to symbolic variables
-        inputs = [at.as_tensor_variable(inp) for inp in inputs]
+        inputs = [
+            at.as_tensor_variable(emission_observed),
+            at.as_tensor_variable(emission_signal),
+            at.as_tensor_variable(emission_noise),
+            at.as_tensor_variable(logp_initial_state),
+            at.as_tensor_variable(logp_transition),
+        ]
        # Define the type of the output returned by the wrapped JAX function
        outputs = [at.dscalar()]
        return Apply(self, inputs, outputs)
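A quick way to see `make_node` in action (an illustrative sketch; the class continues with `perform` and `grad` in the next hunk, and the Op instances are created further below): calling the instantiated Op builds the `Apply` node and returns its single symbolic output.

```python
# Hypothetical usage sketch; variable names mirror the notebook's.
hmm_logp_op = HMMLogpOp()
hmm_logp_var = hmm_logp_op(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
print(hmm_logp_var.type)  # a float64 scalar, as declared by at.dscalar()
print(hmm_logp_var.eval())  # compiles and runs `perform`, returning the log-likelihood
```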
@@ -380,13 +373,42 @@ class HMMLogpOp(Op):
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
-        gradients = hmm_logp_grad_op(*inputs)
-        return [output_gradients[0] * gradient for gradient in gradients]
+        (
+            grad_wrt_emission_observed,
+            grad_wrt_emission_signal,
+            grad_wrt_emission_noise,
+            grad_wrt_logp_initial_state,
+            grad_wrt_logp_transition,
+        ) = hmm_logp_grad_op(*inputs)
+        # If there are inputs for which the gradients will never be needed or cannot
+        # be computed, `aesara.gradient.grad_not_implemented` should be used as the
+        # output gradient for that input.
+        output_gradient = output_gradients[0]
+        return [
+            output_gradient * grad_wrt_emission_observed,
+            output_gradient * grad_wrt_emission_signal,
+            output_gradient * grad_wrt_emission_noise,
+            output_gradient * grad_wrt_logp_initial_state,
+            output_gradient * grad_wrt_logp_transition,
+        ]


class HMMLogpGradOp(Op):
-    def make_node(self, *inputs):
-        inputs = [at.as_tensor_variable(inp) for inp in inputs]
+    def make_node(
+        self,
+        emission_observed,
+        emission_signal,
+        emission_noise,
+        logp_initial_state,
+        logp_transition,
+    ):
+        inputs = [
+            at.as_tensor_variable(emission_observed),
+            at.as_tensor_variable(emission_signal),
+            at.as_tensor_variable(emission_noise),
+            at.as_tensor_variable(logp_initial_state),
+            at.as_tensor_variable(logp_transition),
+        ]
        # This `Op` will return one gradient per input. For simplicity, we assume
        # each output is of the same type as the input. In practice, you should use
        # the exact dtype to avoid overhead when saving the results of the computation
@@ -395,11 +417,18 @@ class HMMLogpGradOp(Op):
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
-        # If there are inputs for which the gradients will never be needed or cannot
-        # be computed, `aesara.gradient.grad_not_implemented` should be used
-        results = jitted_vec_hmm_logp_grad(*inputs)
-        for i, result in enumerate(results):
-            outputs[i][0] = np.asarray(result, dtype=node.outputs[i].dtype)
+        (
+            grad_wrt_emission_observed_result,
+            grad_wrt_emission_signal_result,
+            grad_wrt_emission_noise_result,
+            grad_wrt_logp_initial_state_result,
+            grad_wrt_logp_transition_result,
+        ) = jitted_vec_hmm_logp_grad(*inputs)
+        outputs[0][0] = np.asarray(grad_wrt_emission_observed_result, dtype=node.outputs[0].dtype)
+        outputs[1][0] = np.asarray(grad_wrt_emission_signal_result, dtype=node.outputs[1].dtype)
+        outputs[2][0] = np.asarray(grad_wrt_emission_noise_result, dtype=node.outputs[2].dtype)
+        outputs[3][0] = np.asarray(grad_wrt_logp_initial_state_result, dtype=node.outputs[3].dtype)
+        outputs[4][0] = np.asarray(grad_wrt_logp_transition_result, dtype=node.outputs[4].dtype)


# Initialize our `Op`s
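Once both `Op`s are initialized, Aesara can differentiate through the wrapped JAX function like any built-in `Op`. A hedged sketch of what that looks like (variable names assumed from the rest of the notebook):

```python
import aesara
import aesara.tensor as at

# Taking the gradient of the Op's scalar output triggers `HMMLogpOp.grad`,
# which in turn calls `hmm_logp_grad_op` under the hood.
emission_signal = at.dscalar("emission_signal")
logp = hmm_logp_op(
    emission_observed,
    emission_signal,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
dlogp_dsignal = aesara.grad(logp, wrt=emission_signal)
print(dlogp_dsignal.eval({emission_signal: emission_signal_true}))
```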
@@ -459,14 +488,11 @@ with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)

-    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(n_hidden_states))
+    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = at.log(p_initial_state)

-    p_decay = pm.Beta("p_decay", 1, 1)
-    logp_transition = pm.logp(
-        pm.Binomial.dist(n=possible_states, p=1 - p_decay),
-        possible_states[:, None],
-    ).T
+    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
+    logp_transition = at.log(p_transition)

    loglike = pm.Potential(
        "hmm_loglike",
@@ -490,7 +516,7 @@ pycharm:
pm.model_to_graphviz(model)
```

-Before we start sampling, we check the logp of each variable at the model initial point
+Before we start sampling, we check the logp of each variable at the model initial point. Bugs tend to manifest themselves in the form of `nan` or `-inf` for the initial probabilities.

```{code-cell} ipython3
initial_point = model.compute_initial_point()
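To actually print those per-variable logps, something along these lines should work (assuming the PyMC v4 API of this era; the call is not shown in this hunk):

```python
# Assumed API: Model.point_logps evaluates the logp of each variable at a
# given point, making nan/-inf culprits easy to spot.
model.point_logps(initial_point)
```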
@@ -535,7 +561,7 @@ true_values = [
    emission_signal_true,
    emission_noise_true,
    *p_initial_state_true,
-    p_decay_true,
+    *p_transition_true.ravel(),
]

az.plot_posterior(idata, ref_val=true_values);
@@ -676,7 +702,7 @@ jitted_hmm_logp_value_and_grad = jax.jit(jax.value_and_grad(vec_hmm_logp, argnum

```{code-cell} ipython3
class HmmLogpValueGradOp(Op):
-    # By default only return the first output
+    # By default only show the first output, and "hide" the other ones
    default_output = 0

    def make_node(self, *inputs):
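The effect of `default_output = 0` is worth spelling out. A hedged sketch of the behavior (variable names assumed; this reflects general Aesara `Op` semantics rather than code from this commit):

```python
# Calling the Op returns only outputs[0] (the value) because of
# `default_output = 0`, but every output still lives on the Apply node.
value_grad_op = HmmLogpValueGradOp()
value = value_grad_op(
    emission_observed,
    emission_signal_true,
    emission_noise_true,
    logp_initial_state_true,
    logp_transition_true,
)
all_outputs = value.owner.outputs  # [value, grad_wrt_input_0, grad_wrt_input_1, ...]
```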