-
-
Notifications
You must be signed in to change notification settings - Fork 61
Expose multivariate normal method
argument in post-estimation tasks
#484
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…s when doing post-estimation tasks
@@ -1109,6 +1109,7 @@ def _sample_conditional( | |||
group: str, | |||
random_seed: RandomState | None = None, | |||
data: pt.TensorLike | None = None, | |||
method: str = "svd", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do you want to type hint the options and/or mention them on the docstrings? Potentially mention the reason for choosing some over the other?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just pushed this without having read this comment, so yes :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR exposes the internal multivariate normal method
argument (default "svd"
) in post-estimation sampling, forecasting, and related state‐space APIs, allowing users to choose faster but potentially less robust algorithms like "cholesky"
or "eig"
.
- Added
method: str = "svd"
parameters to distribution constructors (__new__
,dist
,rv_op
) infilters/distributions.py
and to public sampling, forecasting, and IRF methods incore/statespace.py
. - Propagated the
method
argument through all internalpm.MvNormal.dist(..., method=...)
calls. - Updated docstrings in
core/statespace.py
to describe the newmethod
parameter and its allowed values.
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
File | Description |
---|---|
pymc_extras/statespace/filters/distributions.py | Expose method argument in distribution constructors and propagate to MvNormal |
pymc_extras/statespace/core/statespace.py | Add method parameter to sampling, forecasting, and IRF methods with docs |
Comments suppressed due to low confidence (1)
pymc_extras/statespace/filters/distributions.py:213
- Add tests for multiple
method
values (e.g., 'cholesky', 'eig') in thepm.MvNormal
calls inside the scan to ensure the new argument is correctly applied.
mu=0, cov=Q, rng=rng, method=method
Post-estimation sampling tasks, especially
sample_conditional_posterior
, can be really slow for long time series. One reason for this is that I had set themethod
argument of all the multivariate normal distributions internally to besvd
, which is the most robust option. There are cases, however, where one knows the model is well behaved, so you can drop back to something faster (likecholesky
) without any problem. This PR allows this choice.The default is still
svd
. Ideally we'd check a couple covariance matrices from the providedidata
and try to give the user a smart default, but I don't want to do that much work here.