27
27
)
28
28
from pymc_experimental .statespace .filters .distributions import (
29
29
LinearGaussianStateSpace ,
30
+ MvNormalSVD ,
30
31
SequenceMvNormal ,
31
32
)
32
33
from pymc_experimental .statespace .filters .utilities import stabilize
@@ -864,9 +865,8 @@ def build_statespace_graph(
864
865
cov_jitter = cov_jitter ,
865
866
)
866
867
867
- outputs = filter_outputs
868
- logp = outputs .pop (- 1 )
869
- states , covs = outputs [:3 ], outputs [3 :]
868
+ logp = filter_outputs .pop (- 1 )
869
+ states , covs = filter_outputs [:3 ], filter_outputs [3 :]
870
870
filtered_states , predicted_states , observed_states = states
871
871
filtered_covariances , predicted_covariances , observed_covariances = covs
872
872
if save_kalman_filter_outputs_in_idata :
@@ -2010,7 +2010,7 @@ def forecast(
2010
2010
2011
2011
with pm .Model (coords = temp_coords ) as forecast_model :
2012
2012
(_ , _ , * matrices ), grouped_outputs = self ._kalman_filter_outputs_from_dummy_graph (
2013
- data_dims = ["data_time" , OBS_STATE_DIM ]
2013
+ data_dims = ["data_time" , OBS_STATE_DIM ],
2014
2014
)
2015
2015
2016
2016
group_idx = FILTER_OUTPUT_TYPES .index (filter_output )
@@ -2026,7 +2026,7 @@ def forecast(
2026
2026
if scenario is not None :
2027
2027
sub_dict = {
2028
2028
forecast_model [data_name ]: pt .as_tensor_variable (
2029
- scenario .get (data_name ), name = "data_var"
2029
+ scenario .get (data_name ), name = data_name
2030
2030
)
2031
2031
for data_name in self .data_names
2032
2032
}
@@ -2173,16 +2173,16 @@ def impulse_response_function(
2173
2173
if use_posterior_cov :
2174
2174
Q = post_Q
2175
2175
if orthogonalize_shocks :
2176
- Q = pt .linalg .cholesky (Q )
2176
+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
2177
2177
elif shock_cov is not None :
2178
2178
Q = pt .as_tensor_variable (shock_cov )
2179
2179
if orthogonalize_shocks :
2180
- Q = pt .linalg .cholesky (Q )
2180
+ Q = pt .linalg .cholesky (Q ) / pt . diag ( Q )
2181
2181
2182
2182
if shock_trajectory is None :
2183
2183
shock_trajectory = pt .zeros ((n_steps , self .k_posdef ))
2184
2184
if Q is not None :
2185
- init_shock = pm . MvNormal ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
2185
+ init_shock = MvNormalSVD ("initial_shock" , mu = 0 , cov = Q , dims = [SHOCK_DIM ])
2186
2186
else :
2187
2187
init_shock = pm .Deterministic (
2188
2188
"initial_shock" ,
0 commit comments