10
10
from pytensor .graph .basic import Node
11
11
12
12
floatX = pytensor .config .floatX
13
+ COV_ZERO_TOL = 0
13
14
14
15
lgss_shape_message = (
15
16
"The LinearGaussianStateSpace distribution needs shape information to be constructed. "
@@ -157,8 +158,11 @@ def step_fn(*args):
157
158
middle_rng , a_innovation = pm .MvNormal .dist (mu = 0 , cov = Q , rng = rng ).owner .outputs
158
159
next_rng , y_innovation = pm .MvNormal .dist (mu = 0 , cov = H , rng = middle_rng ).owner .outputs
159
160
160
- a_next = c + T @ a + R @ a_innovation
161
- y_next = d + Z @ a_next + y_innovation
161
+ a_mu = c + T @ a
162
+ a_next = pt .switch (pt .all (pt .le (Q , COV_ZERO_TOL )), a_mu , a_mu + R @ a_innovation )
163
+
164
+ y_mu = d + Z @ a_next
165
+ y_next = pt .switch (pt .all (pt .le (H , COV_ZERO_TOL )), y_mu , y_mu + y_innovation )
162
166
163
167
next_state = pt .concatenate ([a_next , y_next ], axis = 0 )
164
168
@@ -168,7 +172,11 @@ def step_fn(*args):
168
172
Z_init = Z_ if Z_ in non_sequences else Z_ [0 ]
169
173
H_init = H_ if H_ in non_sequences else H_ [0 ]
170
174
171
- init_y_ = pm .MvNormal .dist (Z_init @ init_x_ , H_init , rng = rng )
175
+ init_y_ = pt .switch (
176
+ pt .all (pt .le (H_init , COV_ZERO_TOL )),
177
+ Z_init @ init_x_ ,
178
+ pm .MvNormal .dist (Z_init @ init_x_ , H_init , rng = rng ),
179
+ )
172
180
init_dist_ = pt .concatenate ([init_x_ , init_y_ ], axis = 0 )
173
181
174
182
statespace , updates = pytensor .scan (
@@ -216,6 +224,7 @@ def __new__(
216
224
steps = None ,
217
225
mode = None ,
218
226
sequence_names = None ,
227
+ k_endog = None ,
219
228
** kwargs ,
220
229
):
221
230
dims = kwargs .pop ("dims" , None )
@@ -239,11 +248,29 @@ def __new__(
239
248
sequence_names = sequence_names ,
240
249
** kwargs ,
241
250
)
242
-
243
251
k_states = T .type .shape [0 ]
244
252
245
- latent_states = latent_obs_combined [..., :k_states ]
246
- obs_states = latent_obs_combined [..., k_states :]
253
+ if k_endog is None and k_states is None :
254
+ raise ValueError ("Could not infer number of observed states, explicitly pass k_endog." )
255
+ if k_endog is not None and k_states is not None :
256
+ total_shape = latent_obs_combined .type .shape [- 1 ]
257
+ inferred_endog = total_shape - k_states
258
+ if inferred_endog != k_endog :
259
+ raise ValueError (
260
+ f"Inferred k_endog does not agree with provided value ({ inferred_endog } != { k_endog } ). "
261
+ f"It is not necessary to provide k_endog when the value can be inferred."
262
+ )
263
+ latent_slice = slice (None , - k_endog )
264
+ obs_slice = slice (- k_endog , None )
265
+ elif k_endog is None :
266
+ latent_slice = slice (None , k_states )
267
+ obs_slice = slice (k_states , None )
268
+ else :
269
+ latent_slice = slice (None , - k_endog )
270
+ obs_slice = slice (- k_endog , None )
271
+
272
+ latent_states = latent_obs_combined [..., latent_slice ]
273
+ obs_states = latent_obs_combined [..., obs_slice ]
247
274
248
275
latent_states = pm .Deterministic (f"{ name } _latent" , latent_states , dims = latent_dims )
249
276
obs_states = pm .Deterministic (f"{ name } _observed" , obs_states , dims = obs_dims )
0 commit comments