Skip to content

Commit 9589f28

Browse files
Expose multivariate normal method argument in post-estimation tasks (#484)
* Expose `method` argument to MvNormals used in statespace distributions when doing post-estimation tasks
* Use keyword arguments when calling post-estimation functions internally
* Fix typo, clearer argument name
* Improve type hint
* Add `method` to `build_statespace_graph` and include some general advice
* Fix incorrect signature in call to `_sample_unconditional`
1 parent 7755e1b commit 9589f28

File tree

2 files changed

+131
-16
lines changed

2 files changed

+131
-16
lines changed

pymc_extras/statespace/core/statespace.py

Lines changed: 116 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import logging
22

33
from collections.abc import Callable, Sequence
4-
from typing import Any
4+
from typing import Any, Literal
55

66
import numpy as np
77
import pandas as pd
@@ -822,6 +822,7 @@ def build_statespace_graph(
822822
mode: str | None = None,
823823
missing_fill_value: float | None = None,
824824
cov_jitter: float | None = JITTER_DEFAULT,
825+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
825826
save_kalman_filter_outputs_in_idata: bool = False,
826827
) -> None:
827828
"""
@@ -865,6 +866,14 @@ def build_statespace_graph(
865866
866867
- The Univariate Filter is more robust than other filters, and can tolerate a lower jitter value
867868
869+
mvn_method: str, default "svd"
870+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
871+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
872+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
873+
874+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
875+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
876+
868877
save_kalman_filter_outputs_in_idata: bool, optional, default=False
869878
If True, Kalman Filter outputs will be saved in the model as deterministics. Useful for debugging, but
870879
should not be necessary for the majority of users.
@@ -915,6 +924,7 @@ def build_statespace_graph(
915924
logp=logp,
916925
observed=data,
917926
dims=obs_dims,
927+
method=mvn_method,
918928
)
919929

920930
self._fit_coords = pm_mod.coords.copy()
@@ -1109,6 +1119,7 @@ def _sample_conditional(
11091119
group: str,
11101120
random_seed: RandomState | None = None,
11111121
data: pt.TensorLike | None = None,
1122+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
11121123
**kwargs,
11131124
):
11141125
"""
@@ -1130,6 +1141,14 @@ def _sample_conditional(
11301141
Observed data on which to condition the model. If not provided, the function will use the data that was
11311142
provided when the model was built.
11321143
1144+
mvn_method: str, default "svd"
1145+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
1146+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
1147+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
1148+
1149+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
1150+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
1151+
11331152
kwargs:
11341153
Additional keyword arguments are passed to pymc.sample_posterior_predictive
11351154
@@ -1181,6 +1200,7 @@ def _sample_conditional(
11811200
covs=cov,
11821201
logp=dummy_ll,
11831202
dims=state_dims,
1203+
method=mvn_method,
11841204
)
11851205

11861206
obs_mu = (Z @ mu[..., None]).squeeze(-1)
@@ -1192,6 +1212,7 @@ def _sample_conditional(
11921212
covs=obs_cov,
11931213
logp=dummy_ll,
11941214
dims=obs_dims,
1215+
method=mvn_method,
11951216
)
11961217

11971218
# TODO: Remove this after pm.Flat initial values are fixed
@@ -1222,6 +1243,7 @@ def _sample_unconditional(
12221243
steps: int | None = None,
12231244
use_data_time_dim: bool = False,
12241245
random_seed: RandomState | None = None,
1246+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
12251247
**kwargs,
12261248
):
12271249
"""
@@ -1251,6 +1273,14 @@ def _sample_unconditional(
12511273
random_seed : int, RandomState or Generator, optional
12521274
Seed for the random number generator.
12531275
1276+
mvn_method: str, default "svd"
1277+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
1278+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
1279+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
1280+
1281+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
1282+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
1283+
12541284
kwargs:
12551285
Additional keyword arguments are passed to pymc.sample_posterior_predictive
12561286
@@ -1309,6 +1339,7 @@ def _sample_unconditional(
13091339
steps=steps,
13101340
dims=dims,
13111341
mode=self._fit_mode,
1342+
method=mvn_method,
13121343
sequence_names=self.kalman_filter.seq_names,
13131344
k_endog=self.k_endog,
13141345
)
@@ -1331,7 +1362,11 @@ def _sample_unconditional(
13311362
return idata_unconditional.posterior_predictive
13321363

13331364
def sample_conditional_prior(
1334-
self, idata: InferenceData, random_seed: RandomState | None = None, **kwargs
1365+
self,
1366+
idata: InferenceData,
1367+
random_seed: RandomState | None = None,
1368+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
1369+
**kwargs,
13351370
) -> InferenceData:
13361371
"""
13371372
Sample from the conditional prior; that is, given parameter draws from the prior distribution,
@@ -1347,6 +1382,14 @@ def sample_conditional_prior(
13471382
random_seed : int, RandomState or Generator, optional
13481383
Seed for the random number generator.
13491384
1385+
mvn_method: str, default "svd"
1386+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
1387+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
1388+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
1389+
1390+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
1391+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
1392+
13501393
kwargs:
13511394
Additional keyword arguments are passed to pymc.sample_posterior_predictive
13521395
@@ -1358,10 +1401,16 @@ def sample_conditional_prior(
13581401
"predicted_prior", and "smoothed_prior".
13591402
"""
13601403

1361-
return self._sample_conditional(idata, "prior", random_seed, **kwargs)
1404+
return self._sample_conditional(
1405+
idata=idata, group="prior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1406+
)
13621407

13631408
def sample_conditional_posterior(
1364-
self, idata: InferenceData, random_seed: RandomState | None = None, **kwargs
1409+
self,
1410+
idata: InferenceData,
1411+
random_seed: RandomState | None = None,
1412+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
1413+
**kwargs,
13651414
):
13661415
"""
13671416
Sample from the conditional posterior; that is, given parameter draws from the posterior distribution,
@@ -1376,6 +1425,14 @@ def sample_conditional_posterior(
13761425
random_seed : int, RandomState or Generator, optional
13771426
Seed for the random number generator.
13781427
1428+
mvn_method: str, default "svd"
1429+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
1430+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
1431+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
1432+
1433+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
1434+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
1435+
13791436
kwargs:
13801437
Additional keyword arguments are passed to pymc.sample_posterior_predictive
13811438
@@ -1387,14 +1444,17 @@ def sample_conditional_posterior(
13871444
"predicted_posterior", and "smoothed_posterior".
13881445
"""
13891446

1390-
return self._sample_conditional(idata, "posterior", random_seed, **kwargs)
1447+
return self._sample_conditional(
1448+
idata=idata, group="posterior", random_seed=random_seed, mvn_method=mvn_method, **kwargs
1449+
)
13911450

13921451
def sample_unconditional_prior(
13931452
self,
13941453
idata: InferenceData,
13951454
steps: int | None = None,
13961455
use_data_time_dim: bool = False,
13971456
random_seed: RandomState | None = None,
1457+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
13981458
**kwargs,
13991459
) -> InferenceData:
14001460
"""
@@ -1423,6 +1483,14 @@ def sample_unconditional_prior(
14231483
random_seed : int, RandomState or Generator, optional
14241484
Seed for the random number generator.
14251485
1486+
mvn_method: str, default "svd"
1487+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
1488+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
1489+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
1490+
1491+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
1492+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
1493+
14261494
kwargs:
14271495
Additional keyword arguments are passed to pymc.sample_posterior_predictive
14281496
@@ -1439,7 +1507,13 @@ def sample_unconditional_prior(
14391507
"""
14401508

14411509
return self._sample_unconditional(
1442-
idata, "prior", steps, use_data_time_dim, random_seed, **kwargs
1510+
idata=idata,
1511+
group="prior",
1512+
steps=steps,
1513+
use_data_time_dim=use_data_time_dim,
1514+
random_seed=random_seed,
1515+
mvn_method=mvn_method,
1516+
**kwargs,
14431517
)
14441518

14451519
def sample_unconditional_posterior(
@@ -1448,6 +1522,7 @@ def sample_unconditional_posterior(
14481522
steps: int | None = None,
14491523
use_data_time_dim: bool = False,
14501524
random_seed: RandomState | None = None,
1525+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
14511526
**kwargs,
14521527
) -> InferenceData:
14531528
"""
@@ -1477,6 +1552,14 @@ def sample_unconditional_posterior(
14771552
random_seed : int, RandomState or Generator, optional
14781553
Seed for the random number generator.
14791554
1555+
mvn_method: str, default "svd"
1556+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
1557+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
1558+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
1559+
1560+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
1561+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
1562+
14801563
Returns
14811564
-------
14821565
InferenceData
@@ -1490,7 +1573,13 @@ def sample_unconditional_posterior(
14901573
"""
14911574

14921575
return self._sample_unconditional(
1493-
idata, "posterior", steps, use_data_time_dim, random_seed, **kwargs
1576+
idata=idata,
1577+
group="posterior",
1578+
steps=steps,
1579+
use_data_time_dim=use_data_time_dim,
1580+
random_seed=random_seed,
1581+
mvn_method=mvn_method,
1582+
**kwargs,
14941583
)
14951584

14961585
def sample_statespace_matrices(
@@ -1933,6 +2022,7 @@ def forecast(
19332022
filter_output="smoothed",
19342023
random_seed: RandomState | None = None,
19352024
verbose: bool = True,
2025+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
19362026
**kwargs,
19372027
) -> InferenceData:
19382028
"""
@@ -1989,6 +2079,14 @@ def forecast(
19892079
verbose: bool, default=True
19902080
Whether to print diagnostic information about forecasting.
19912081
2082+
mvn_method: str, default "svd"
2083+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
2084+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
2085+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
2086+
2087+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
2088+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
2089+
19922090
kwargs:
19932091
Additional keyword arguments are passed to pymc.sample_posterior_predictive
19942092
@@ -2098,6 +2196,7 @@ def forecast(
20982196
sequence_names=self.kalman_filter.seq_names,
20992197
k_endog=self.k_endog,
21002198
append_x0=False,
2199+
method=mvn_method,
21012200
)
21022201

21032202
forecast_model.rvs_to_initial_values = {
@@ -2126,6 +2225,7 @@ def impulse_response_function(
21262225
shock_trajectory: np.ndarray | None = None,
21272226
orthogonalize_shocks: bool = False,
21282227
random_seed: RandomState | None = None,
2228+
mvn_method: Literal["cholesky", "eigh", "svd"] = "svd",
21292229
**kwargs,
21302230
):
21312231
"""
@@ -2177,6 +2277,14 @@ def impulse_response_function(
21772277
random_seed : int, RandomState or Generator, optional
21782278
Seed for the random number generator.
21792279
2280+
mvn_method: str, default "svd"
2281+
Method used to invert the covariance matrix when calculating the pdf of a multivariate normal
2282+
(or when generating samples). One of "cholesky", "eigh", or "svd". "cholesky" is fastest, but least robust
2283+
to ill-conditioned matrices, while "svd" is slow but extremely robust.
2284+
2285+
In general, if your model has measurement error, "cholesky" will be safe to use. Otherwise, "svd" is
2286+
recommended. "eigh" can also be tried if sampling with "svd" is very slow, but it is not as robust as "svd".
2287+
21802288
kwargs:
21812289
Additional keyword arguments are passed to pymc.sample_posterior_predictive
21822290
@@ -2236,7 +2344,7 @@ def impulse_response_function(
22362344
shock_trajectory = pt.zeros((n_steps, self.k_posdef))
22372345
if Q is not None:
22382346
init_shock = pm.MvNormal(
2239-
"initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method="svd"
2347+
"initial_shock", mu=0, cov=Q, dims=[SHOCK_DIM], method=mvn_method
22402348
)
22412349
else:
22422350
init_shock = pm.Deterministic(

0 commit comments

Comments
 (0)