Support HMM via marginalization of DiscreteMarkovChain #257
Conversation
Is it using the Viterbi algorithm?
Currently just the forward algorithm* to compute the logp.

*It's not pure forward because we are computing and storing p(data | state) for all data-state pairs outside the scan over state transition probabilities. We should be O(N^2 * T) on compute, but we're not maximally efficient on memory.

If I understand well, Viterbi just gives the most probable sequence of hidden states in a maximum-likelihood setting? We should be able to back that out of the posterior pretty easily. You'll need to school me if I'm oversimplifying.
Yes, Viterbi gives the posterior mode. But you are marginalizing the states to compute the likelihood here, right?
Yes, but to be precise: to compute the logp of any dependent variables, which may be observed/unobserved or a mix.
Seems like our "clever" approach is not correct. We need to combine the emission probabilities as we compute the state probabilities iteratively. I thought we could factor them out, but it doesn't seem to be the case.
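For reference, the standard forward recursion (textbook form, not quoted from this PR):

$$\alpha_t(j) = p(y_t \mid s_t = j)\sum_i \alpha_{t-1}(i)\,P_{ij}, \qquad p(y_{1:T}) = \sum_j \alpha_T(j)$$

The emission term $p(y_t \mid s_t = j)$ multiplies $\alpha$ inside every step, so it can be precomputed for all data-state pairs but not pulled outside the recursion over $t$.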
I added the example from this YouTube video as a test case, so we can get to a solution. I'm in the process of refactoring the logp function to compute alpha correctly, but it's typically a nested loop. Here's numpy code:
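(A sketch of the nested-loop forward pass being described; names like `logP` and `logp_emission` are illustrative, not taken from the PR.)

```python
import numpy as np
from scipy.special import logsumexp

def forward_logp(logp_init, logP, logp_emission):
    """Forward algorithm in log space.

    logp_init: (N,) log prior over the initial state
    logP: (N, N) log transition matrix, logP[i, j] = log p(s_t = j | s_{t-1} = i)
    logp_emission: (T, N) log p(y_t | s_t = j), precomputed for all data-state pairs
    """
    T, N = logp_emission.shape
    log_alpha = logp_init + logp_emission[0]
    for t in range(1, T):
        new_log_alpha = np.empty(N)
        for j in range(N):
            # The emission logp is added inside the recursion at every step;
            # it cannot be factored out of the scan over t.
            new_log_alpha[j] = logp_emission[t, j] + logsumexp(log_alpha + logP[:, j])
        log_alpha = new_log_alpha
    return logsumexp(log_alpha)
```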
I'm trying to think how we can vectorize the inner loop, open to suggestions. Nvm, figured it out, it looks like:
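(Same illustrative names as the sketch above:)

```python
# Vectorized inner loop: broadcast log_alpha against the log transition
# matrix and reduce over the previous-state axis in a single logsumexp.
for t in range(1, T):
    log_alpha = logp_emission[t] + logsumexp(log_alpha[:, None] + logP, axis=0)
```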
Force-pushed from edfa63c to 3398fff
Force-pushed from 958ada4 to f97dd6d
I'm happy with where it's landed. I wish we had solved the two major problems (categorical emission and lags), but I'd rather have it merged and getting used than keep it in limbo while we wait for free time to make it perfect. I'll open an issue about those points after it's merged.
    if rv_to_marginalize.owner.inputs[0].type.ndim > 2:
        raise NotImplementedError(
            "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
        )
Can a Markov chain have a non-matrix transition probability?
Should be valid for batch dims
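(A hypothetical illustration of the batched case the guard rejects:)

```python
import numpy as np

# Five independent 2-state chains, each with its own transition matrix:
# P has shape (5, 2, 2), so P.ndim > 2 and it is no longer a plain matrix.
rng = np.random.default_rng(0)
P = rng.dirichlet(np.ones(2), size=(5, 2))
assert P.shape == (5, 2, 2) and np.allclose(P.sum(-1), 1.0)
```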
    # To compute the prior probabilities of each state, we evaluate the logp of the domain (all possible states) under
    # the initial distribution. This is robust to everything the user can throw at it.
    batch_logp_init_dist = pt.vectorize(lambda x: logp(init_dist_, x), "()->()")(
No way to avoid this lambda here with vectorize_graph? I recall this used to be a little function.
The way to avoid it is to make it a little function, but this seems like a fine use for a lambda?
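(For reference, the named-function form, using the same `logp` and `init_dist_` as in the snippet above; the function and variable names here are hypothetical:)

```python
# Equivalent to the lambda in the diff; applied to the same domain tensor.
def init_dist_elemwise_logp(x):  # hypothetical name
    return logp(init_dist_, x)

batch_logp_init_dist_fn = pt.vectorize(init_dist_elemwise_logp, "()->()")
```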
Categorical is one of the goals I have with #300. I think it's already working there, but I need to rebase and check once we merge this.
Lags are a nice follow-up. The current distribution doesn't have a clear API for lags and batch dims, which further stopped me from addressing it here. We just need to agree on this and then it should be straightforward to support both. The design question is: how do you specify a Markov chain with 2 lags and an extra batch dimension? Say something with shape (5, 100) with two lags but different transition matrices for each of the five batched chains.
Yeah, good questions. You're right, it's not clear. I guess the distribution has to store the … We could let the user declare the lagged matrices as a tensor (since it's a bit more natural, IMO at least), then internally flatten it down and build the index table, then rebuild the tensors after sampling. But this is all for another PR, I 100% agree.
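(A minimal sketch of the flatten-and-index idea, purely illustrative: a 2-lag chain over N states re-expressed as a first-order chain over composite states (s_{t-1}, s_t).)

```python
import numpy as np

N = 3
rng = np.random.default_rng(0)
# P_lagged[i, j, k] = p(s_t = k | s_{t-2} = i, s_{t-1} = j)
P_lagged = rng.dirichlet(np.ones(N), size=(N, N))  # shape (N, N, N)

# Index table: composite state (i, j) -> flat index i * N + j.
P_flat = np.zeros((N * N, N * N))
for i in range(N):
    for j in range(N):
        for k in range(N):
            # From composite state (i, j) the chain can only move to (j, k).
            P_flat[i * N + j, j * N + k] = P_lagged[i, j, k]

assert np.allclose(P_flat.sum(-1), 1.0)  # rows are valid distributions
```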
I think issues should be opened for some of the marginalizations that are not implemented now as well. But otherwise, looks good to merge.
    raise NotImplementedError(
-        f"RV with distribution {rv_to_marginalize.owner.op} cannot be marginalized. "
-        f"Supported distribution include {supported_dists}"
+        f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
    )
I thought the old error message was more helpful
It was, but it's not gonna scale.
Maybe link to the docs where it lists all the supported distributions then?
The notes state that functionality is restricted: only finite discrete RVs are supported, which is kind of true. Although we don't yet support Truncated/Censored versions of infinite discrete RVs, which thus become finite: #95

We also don't support Multinomial, which in theory is finite... So I think with the disclaimer that functionality is restricted, and this error message indicating the type of RV that could not be marginalized, it's fair game?
Co-authored-by: Jesse Grabowski <[email protected]>
Force-pushed from f97dd6d to 3ebdfb5
The following example defines a 2-state HMM, with a 0.9 transition probability of staying in the same state, and a Normal emission centered around -1 for state 0 and 1 for state 1.
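(A sketch of what that example presumably looks like, assuming the MarginalModel API from pymc-experimental at the time; the import paths, initial distribution, emission sigma, and number of steps are assumptions, not quoted from the PR:)

```python
import pymc as pm
from pymc_experimental import MarginalModel
from pymc_experimental.distributions import DiscreteMarkovChain

with MarginalModel() as m:
    # 2-state chain: 0.9 probability of staying in the current state.
    P = [[0.9, 0.1], [0.1, 0.9]]
    init_dist = pm.Categorical.dist(p=[0.5, 0.5])  # illustrative initial distribution
    chain = DiscreteMarkovChain("chain", P=P, init_dist=init_dist, steps=9)
    # Emission centered at -1 for state 0 and +1 for state 1.
    emission = pm.Normal("emission", mu=chain * 2 - 1, sigma=0.5)

# Marginalize out the discrete chain so the model can be sampled with NUTS alone.
m.marginalize([chain])
```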
Not implemented
Higher-order lags and batch P matrices are not supported due to complexity (and me not grokking the exact API)
Closes #167