@@ -72,39 +72,56 @@ This notebook illustrates how we can implement a new Aesara {class}`~aesara.grap
+++

- For illustration purposes, we will simulate data following a simple [Hidden Markov Model](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMM), with 4 possible latent states $S \in \{0, 1, 2, 3\}$ and normal emission likelihood.
+ For illustration purposes, we will simulate data following a simple [Hidden Markov Model](https://en.wikipedia.org/wiki/Hidden_Markov_model) (HMM), with 3 possible latent states $S \in \{0, 1, 2\}$ and normal emission likelihood.

- $$ Y \sim \text{Normal}(S \cdot \text{signal}, (S + 1) \cdot \text{noise}) $$
+ $$ Y \sim \text{Normal}((S + 1) \cdot \text{signal}, \text{noise}) $$

- Our HMM will have a fixed Binomial probability of decaying from a higher state $S_t$ to a lower state $S_{t+1}$ in every step,
+ Our HMM will have a fixed Categorical probability $P$ of switching across states, which depends only on the last state,

- $$ S_{t+1} \sim \text{Binomial}(S_t, \text{1-p_decay}) $$
+ $$ S_{t+1} \sim \text{Categorical}(P_{S_t}) $$

- This implies a zero probability of going from a lower state $S_{t}$ to a higher state $S_{t+1}$.
+ To complete our model, we assume a fixed probability $P_{t0}$ for each possible initial state $S_{t0}$,

- To complete our model, we assume a fixed probability for each possible initial state $S_{t0}$,
-
- $$ S_{t0} \sim \text{Categorical}(P_{\{0, 1, 2, 3\}}) $$
+ $$ S_{t0} \sim \text{Categorical}(P_{t0}) $$
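Concretely, $P$ is just a matrix whose row $S_t$ holds the distribution of the next state. A minimal NumPy sketch of one transition and emission draw, using the true parameter values defined below (illustrative only, not part of the notebook):

```python
import numpy as np

rng = np.random.default_rng(0)
p_transition = np.array([[0.9, 0.09, 0.01], [0.1, 0.8, 0.1], [0.2, 0.1, 0.7]])
signal, noise = 1.15, 0.15

s_t = 0
s_next = rng.choice(3, p=p_transition[s_t])  # S_{t+1} ~ Categorical(P_{S_t})
y = rng.normal((s_next + 1) * signal, noise)  # Y ~ Normal((S + 1) * signal, noise)
```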
### Simulating data
Let's generate data according to this model! The first step is to set some values for the parameters in our model
```{code-cell} ipython3
# Emission signal and noise parameters
- emission_signal_true = 0.75
- emission_noise_true = 0.05
+ emission_signal_true = 1.15
+ emission_noise_true = 0.15
+
+ p_initial_state_true = np.array([0.9, 0.09, 0.01])
+
+ # Probability of switching from state_t to state_t+1
+ p_transition_true = np.array(
+     [
+         # 0, 1, 2
+         [0.9, 0.09, 0.01],  # 0
+         [0.1, 0.8, 0.1],  # 1
+         [0.2, 0.1, 0.7],  # 2
+     ]
+ )

- # Probability of starting in initial states [0, 1, 2, 3]
- p_initial_state_true = np.array([0.01, 0.04, 0.25, 0.7])
+ # Confirm that we have defined valid probabilities
assert np.isclose(np.sum(p_initial_state_true), 1)
+ assert np.allclose(np.sum(p_transition_true, axis=-1), 1)
+ ```
- p_decay_true = 0.125

+ ```{code-cell} ipython3
+ # Let's compute the log of the transition probability matrix for later use
+ with np.errstate(divide="ignore"):
+     logp_initial_state_true = np.log(p_initial_state_true)
+     logp_transition_true = np.log(p_transition_true)
+
+ logp_initial_state_true, logp_transition_true
```
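Keeping the probabilities in log space is what lets the forward algorithm below turn products of many small probabilities into sums, avoiding numerical underflow; structural zeros simply become `-inf`, which `logsumexp` handles gracefully. A quick illustration of the idea (a sketch using SciPy, not the notebook's JAX code):

```python
import numpy as np
from scipy.special import logsumexp

with np.errstate(divide="ignore"):
    logps = np.log(np.array([0.9, 0.1, 0.0]))  # log(0) -> -inf, by design

# Summing probabilities in log space: log(exp(a) + exp(b) + ...)
total = logsumexp(logps)  # == log(0.9 + 0.1 + 0.0) == 0.0
```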
```{code-cell} ipython3
- # We will observe 100 HMM processes, each with a total of 50 steps
- n_obs = 100
+ # We will observe 70 HMM processes, each with a total of 50 steps
+ n_obs = 70
n_steps = 50
```
@@ -117,29 +134,31 @@ rng = np.random.default_rng(rng_seed)
We write a helper function to generate a single HMM process and create our simulated data
```{code-cell} ipython3
- def simulate_hmm(p_initial_state, p_decay, emission_signal, emission_noise, n_steps, rng):
+ def simulate_hmm(p_initial_state, p_transition, emission_signal, emission_noise, n_steps, rng):
    """Generate hidden state and emission from our HMM model."""
-     n_possible_states = len(p_initial_state)
-     initial_state = rng.choice(range(n_possible_states), p=p_initial_state)
-
-     hidden_state = [initial_state]
-     for step in range(n_steps):
-         hidden_state.append(rng.binomial(n=hidden_state[-1], p=1 - p_decay))
+     possible_states = np.array([0, 1, 2])
-     hidden_state = np.array(hidden_state)
+     hidden_states = []
+     initial_state = rng.choice(possible_states, p=p_initial_state)
+     hidden_states.append(initial_state)
+     for step in range(n_steps):
+         new_hidden_state = rng.choice(possible_states, p=p_transition[hidden_states[-1]])
+         hidden_states.append(new_hidden_state)
+     hidden_states = np.array(hidden_states)

-     emission = rng.normal(
-         hidden_state * emission_signal,
-         (hidden_state + 1) * emission_noise,
+     emissions = rng.normal(
+         (hidden_states + 1) * emission_signal,
+         emission_noise,
    )

-     return hidden_state, emission
+     return hidden_states, emissions
```
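Because each row of `p_transition_true` sums to one and all entries are positive, the chain has a stationary distribution we can use as a rough sanity check on the simulator: the normalized left eigenvector of the transition matrix with eigenvalue 1. A small sketch (assuming the parameter cell above has been run):

```python
import numpy as np

# Left eigenvector of P with eigenvalue 1 gives the stationary distribution
eigvals, eigvecs = np.linalg.eig(p_transition_true.T)
stationary = np.real(eigvecs[:, np.argmax(np.real(eigvals))])
stationary /= stationary.sum()
stationary  # long-run fraction of time the chain spends in each state
```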
```{code-cell} ipython3
single_hmm_hidden_state, single_hmm_emission = simulate_hmm(
    p_initial_state_true,
-     p_decay_true,
+     p_transition_true,
    emission_signal_true,
    emission_noise_true,
    n_steps,
@@ -156,7 +175,7 @@ emission_observed = []
for i in range(n_obs):
    hidden_state, emission = simulate_hmm(
        p_initial_state_true,
-         p_decay_true,
+         p_transition_true,
        emission_signal_true,
        emission_noise_true,
        n_steps,
@@ -171,9 +190,11 @@ emission_observed = np.array(emission_observed)
```{code-cell} ipython3
fig, ax = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
- for hidden_state, emission in zip(hidden_state_true, emission_observed):
-     ax[0].plot(hidden_state, color="C0", alpha=0.1)
-     ax[1].plot(emission, color="C0", alpha=0.1)
+ # Plot the first four HMM processes
+ for i in range(4):
+     ax[0].plot(hidden_state_true[i] + i * 0.02, color=f"C{i}", lw=2, alpha=0.4)
+     ax[1].plot(emission_observed[i], color=f"C{i}", lw=2, alpha=0.4)
+ ax[0].set_yticks([0, 1, 2])
ax[0].set_ylabel("hidden state")
ax[1].set_ylabel("observed emission")
ax[1].set_xlabel("step")
@@ -192,49 +213,9 @@ We will write a JAX function to compute the likelihood of our HMM model, margina
We will take advantage of JAX [scan](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) to obtain an efficient and differentiable log-likelihood, and the handy [vmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) to automatically vectorize this log-likelihood across multiple observed processes.
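If these primitives are unfamiliar, here is a tiny, self-contained toy example (unrelated to the HMM itself): `scan` threads a carry through a sequence, and `vmap` maps the whole computation over a batch axis.

```python
import jax
import jax.numpy as jnp


def step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next carry, per-step output)


# scan: a differentiable loop returning the final carry and stacked outputs
total, running = jax.lax.scan(step, 0.0, jnp.arange(4.0))  # total == 6.0

# vmap: vectorize the same scan across a batch of three sequences
batch_totals = jax.vmap(lambda xs: jax.lax.scan(step, 0.0, xs)[0])(jnp.ones((3, 4)))
```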
- Before that, let us create some helper variables derived from our true parameters, that we can use to test our implementation.
-
- ```{code-cell} ipython3
- n_hidden_states = len(p_initial_state_true)
- n_hidden_states
- ```
-
- ```{code-cell} ipython3
- logp_initial_state_true = np.log(p_initial_state_true)
- logp_initial_state_true
- ```
-
- ```{code-cell} ipython3
- # Compute the probability transition matrix, of going from S_t to S_t+1
- # p[0, 0], is the probability of going from S_t=0 to S_t+1=0
- # p[1, 0], is the probability of going from S_t=1 to S_t+1=0
- # p[0, 1], is the probabilty of going from S_t=0 to S_t+1=1 (which is impossible)
- # p[3, 3], is the probability of going from S_t=3, to S_t+1=3
- possible_states = np.arange(n_hidden_states, dtype="int16")
- p_transition_true = (
-     pm.logp(
-         pm.Binomial.dist(n=possible_states, p=1 - p_decay_true),
-         possible_states[:, None],
-     )
-     .T.exp()
-     .eval()
- )
- p_transition_true
- ```
-
- ```{code-cell} ipython3
- # Confirm that we have a valid transition probability matrix
- assert np.allclose(np.sum(p_transition_true, axis=-1), 1)
- ```
-
- ```{code-cell} ipython3
- logp_transition_true = np.log(p_transition_true)
- logp_transition_true
- ```
-
- ### Writing the JAX function
+ +++

- This is our core JAX function which computes the marginal log-likelihood of a single HMM process
+ Our core JAX function computes the marginal log-likelihood of a single HMM process
```{code-cell} ipython3
def hmm_logp(
@@ -246,14 +227,13 @@ def hmm_logp(
):
"""Compute the marginal log-likelihood of a single HMM process."""
-     # Caution: Using global variable for simplicity!
-     hidden_states = np.arange(n_hidden_states)
+     hidden_states = np.array([0, 1, 2])
    # Compute log-likelihood of observed emissions for each (step x possible hidden state)
    logp_emission = jsp.stats.norm.logpdf(
        emission_observed[:, None],
-         hidden_states * emission_signal,
-         (hidden_states + 1) * emission_noise,
+         (hidden_states + 1) * emission_signal,
+         emission_noise,
    )
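    # Note how broadcasting `emission_observed[:, None]` against the three
    # hidden states above yields a (step, state) matrix of emission log-densities.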
    # We use the forward algorithm to compute log_alpha(x_t) = logp(x_t, y_1:t)
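    # Roughly, the recursion computed below is (a sketch, not the exact code):
    #   log_alpha[0] = logp_initial_state + logp_emission[0]
    #   log_alpha[t] = logsumexp(log_alpha[t - 1][:, None] + logp_transition, axis=0) + logp_emission[t]
    # and the marginal log-likelihood is logsumexp(log_alpha[-1]).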
@@ -362,9 +342,22 @@ For the `grad` we will create a second {class}`~aesara.graph.op.Op` that wraps o
```{code-cell} ipython3
class HMMLogpOp(Op):
-     def make_node(self, *inputs):
+     def make_node(
+         self,
+         emission_observed,
+         emission_signal,
+         emission_noise,
+         logp_initial_state,
+         logp_transition,
+     ):
        # Convert our inputs to symbolic variables
-         inputs = [at.as_tensor_variable(inp) for inp in inputs]
+         inputs = [
+             at.as_tensor_variable(emission_observed),
+             at.as_tensor_variable(emission_signal),
+             at.as_tensor_variable(emission_noise),
+             at.as_tensor_variable(logp_initial_state),
+             at.as_tensor_variable(logp_transition),
+         ]
        # Define the type of the output returned by the wrapped JAX function
        outputs = [at.dscalar()]
        return Apply(self, inputs, outputs)
@@ -380,13 +373,42 @@ class HMMLogpOp(Op):
        outputs[0][0] = np.asarray(result, dtype=node.outputs[0].dtype)
    def grad(self, inputs, output_gradients):
-         gradients = hmm_logp_grad_op(*inputs)
-         return [output_gradients[0] * gradient for gradient in gradients]
+         (
+             grad_wrt_emission_observed,
+             grad_wrt_emission_signal,
+             grad_wrt_emission_noise,
+             grad_wrt_logp_initial_state,
+             grad_wrt_logp_transition,
+         ) = hmm_logp_grad_op(*inputs)
+         # If there are inputs for which the gradients will never be needed or cannot
+         # be computed, `aesara.gradient.grad_not_implemented` should be used as the
+         # output gradient for that input.
+         output_gradient = output_gradients[0]
+         return [
+             output_gradient * grad_wrt_emission_observed,
+             output_gradient * grad_wrt_emission_signal,
+             output_gradient * grad_wrt_emission_noise,
+             output_gradient * grad_wrt_logp_initial_state,
+             output_gradient * grad_wrt_logp_transition,
+         ]
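        # (Multiplying each input gradient by `output_gradient` above is just the
        # chain rule: d(cost)/d(input) = d(cost)/d(output) * d(output)/d(input).)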
class HMMLogpGradOp(Op):
-     def make_node(self, *inputs):
-         inputs = [at.as_tensor_variable(inp) for inp in inputs]
+     def make_node(
+         self,
+         emission_observed,
+         emission_signal,
+         emission_noise,
+         logp_initial_state,
+         logp_transition,
+     ):
+         inputs = [
+             at.as_tensor_variable(emission_observed),
+             at.as_tensor_variable(emission_signal),
+             at.as_tensor_variable(emission_noise),
+             at.as_tensor_variable(logp_initial_state),
+             at.as_tensor_variable(logp_transition),
+         ]
        # This `Op` will return one gradient per input. For simplicity, we assume
        # each output is of the same type as the input. In practice, you should use
        # the exact dtype to avoid overhead when saving the results of the computation
@@ -395,11 +417,18 @@ class HMMLogpGradOp(Op):
        return Apply(self, inputs, outputs)
    def perform(self, node, inputs, outputs):
-         # If there are inputs for which the gradients will never be needed or cannot
-         # be computed, `aesara.gradient.grad_not_implemented` should be used
-         results = jitted_vec_hmm_logp_grad(*inputs)
-         for i, result in enumerate(results):
-             outputs[i][0] = np.asarray(result, dtype=node.outputs[i].dtype)
+         (
+             grad_wrt_emission_observed_result,
+             grad_wrt_emission_signal_result,
+             grad_wrt_emission_noise_result,
+             grad_wrt_logp_initial_state_result,
+             grad_wrt_logp_transition_result,
+         ) = jitted_vec_hmm_logp_grad(*inputs)
+         outputs[0][0] = np.asarray(grad_wrt_emission_observed_result, dtype=node.outputs[0].dtype)
+         outputs[1][0] = np.asarray(grad_wrt_emission_signal_result, dtype=node.outputs[1].dtype)
+         outputs[2][0] = np.asarray(grad_wrt_emission_noise_result, dtype=node.outputs[2].dtype)
+         outputs[3][0] = np.asarray(grad_wrt_logp_initial_state_result, dtype=node.outputs[3].dtype)
+         outputs[4][0] = np.asarray(grad_wrt_logp_transition_result, dtype=node.outputs[4].dtype)
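        # Aesara's `perform` writes results in place into pre-allocated storage
        # cells, which is why each array is assigned to `outputs[i][0]`.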
# Initialize our `Op`s
@@ -459,14 +488,11 @@ with pm.Model(rng_seeder=int(rng.integers(2**30))) as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)
-     p_initial_state = pm.Dirichlet("p_initial_state", np.ones(n_hidden_states))
+     p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = at.log(p_initial_state)
-     p_decay = pm.Beta("p_decay", 1, 1)
-     logp_transition = pm.logp(
-         pm.Binomial.dist(n=possible_states, p=1 - p_decay),
-         possible_states[:, None],
-     ).T
+     p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
+     logp_transition = at.log(p_transition)
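    # `size=3` draws three independent Dirichlet vectors, giving a (3, 3)
    # matrix whose rows are valid transition probability vectors.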
    loglike = pm.Potential(
"hmm_loglike",
@@ -490,7 +516,7 @@ pycharm:
pm.model_to_graphviz(model)
```
- Before we start sampling, we check the logp of each variable at the model initial point
+ Before we start sampling, we check the logp of each variable at the model initial point. Bugs tend to manifest themselves in the form of `nan` or `-inf` for the initial probabilities.
```{code-cell} ipython3
initial_point = model.compute_initial_point()
@@ -535,7 +561,7 @@ true_values = [
    emission_signal_true,
    emission_noise_true,
    *p_initial_state_true,
-     p_decay_true,
+     *p_transition_true.ravel(),
]
az.plot_posterior(idata, ref_val=true_values);
@@ -676,7 +702,7 @@ jitted_hmm_logp_value_and_grad = jax.jit(jax.value_and_grad(vec_hmm_logp, argnum
```{code-cell} ipython3
class HmmLogpValueGradOp(Op):
-     # By default only return the first output
+     # By default only show the first output, and "hide" the other ones
    default_output = 0
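    # With `default_output = 0`, calling this Op returns only the first output
    # (the value); the remaining gradient outputs stay hidden from the caller.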
    def make_node(self, *inputs):