@@ -388,58 +388,6 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
388
388
assert np .allclose (y_logp_val , y_logp_ref_val )
389
389
390
390
391
- @pytest .mark .xfail (reason = "see #148" )
392
- @pytensor .config .change_flags (compute_test_value = "raise" )
393
- @pytest .mark .xfail (reason = "see #148" )
394
- def test_initial_values ():
395
- srng = pt .random .RandomStream (seed = 2320 )
396
-
397
- p_S_0 = np .array ([0.9 , 0.1 ])
398
- S_0_rv = srng .categorical (p_S_0 , name = "S_0" )
399
- S_0_rv .tag .test_value = 0
400
-
401
- Gamma_at = pt .matrix ("Gamma" )
402
- Gamma_at .tag .test_value = np .array ([[0 , 1 ], [1 , 0 ]])
403
-
404
- s_0_vv = S_0_rv .clone ()
405
- s_0_vv .name = "s_0"
406
-
407
- def step_fn (S_tm1 , Gamma ):
408
- S_t = srng .categorical (Gamma [S_tm1 ], name = "S_t" )
409
- return S_t
410
-
411
- S_1T_rv , _ = pytensor .scan (
412
- fn = step_fn ,
413
- outputs_info = [{"initial" : S_0_rv , "taps" : [- 1 ]}],
414
- non_sequences = [Gamma_at ],
415
- strict = True ,
416
- n_steps = 10 ,
417
- name = "S_0T" ,
418
- )
419
-
420
- S_1T_rv .name = "S_1T"
421
- s_1T_vv = S_1T_rv .clone ()
422
- s_1T_vv .name = "s_1T"
423
-
424
- logp_parts = conditional_logp ({S_1T_rv : s_1T_vv , S_0_rv : s_0_vv })
425
-
426
- s_0_val = 0
427
- s_1T_val = np .array ([1 , 0 , 1 , 0 , 1 , 1 , 0 , 1 , 0 , 1 ])
428
- Gamma_val = np .array ([[0.1 , 0.9 ], [0.9 , 0.1 ]])
429
-
430
- exp_res = np .log (p_S_0 [s_0_val ])
431
- s_prev = s_0_val
432
- for s in s_1T_val :
433
- exp_res += np .log (Gamma_val [s_prev , s ])
434
- s_prev = s
435
-
436
- S_0T_logp = sum (v .sum () for v in logp_parts .values ())
437
- S_0T_logp_fn = pytensor .function ([s_0_vv , s_1T_vv , Gamma_at ], S_0T_logp )
438
- res = S_0T_logp_fn (s_0_val , s_1T_val , Gamma_val )
439
-
440
- assert res == pytest .approx (exp_res )
441
-
442
-
443
391
@pytest .mark .parametrize ("remove_asserts" , (True , False ))
444
392
def test_mode_is_kept (remove_asserts ):
445
393
mode = Mode ().including ("local_remove_all_assert" ) if remove_asserts else None
0 commit comments