Skip to content

Commit db0b218

Browse files
committed
Remove stale xfail scan test
1 parent 585962d commit db0b218

File tree

1 file changed

+0
-52
lines changed

1 file changed

+0
-52
lines changed

tests/logprob/test_scan.py

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -388,58 +388,6 @@ def scan_fn(mus_t, sigma_t, Y_t_val, S_t_val, Gamma_t):
388388
assert np.allclose(y_logp_val, y_logp_ref_val)
389389

390390

391-
@pytest.mark.xfail(reason="see #148")
392-
@pytensor.config.change_flags(compute_test_value="raise")
393-
@pytest.mark.xfail(reason="see #148")
394-
def test_initial_values():
395-
srng = pt.random.RandomStream(seed=2320)
396-
397-
p_S_0 = np.array([0.9, 0.1])
398-
S_0_rv = srng.categorical(p_S_0, name="S_0")
399-
S_0_rv.tag.test_value = 0
400-
401-
Gamma_at = pt.matrix("Gamma")
402-
Gamma_at.tag.test_value = np.array([[0, 1], [1, 0]])
403-
404-
s_0_vv = S_0_rv.clone()
405-
s_0_vv.name = "s_0"
406-
407-
def step_fn(S_tm1, Gamma):
408-
S_t = srng.categorical(Gamma[S_tm1], name="S_t")
409-
return S_t
410-
411-
S_1T_rv, _ = pytensor.scan(
412-
fn=step_fn,
413-
outputs_info=[{"initial": S_0_rv, "taps": [-1]}],
414-
non_sequences=[Gamma_at],
415-
strict=True,
416-
n_steps=10,
417-
name="S_0T",
418-
)
419-
420-
S_1T_rv.name = "S_1T"
421-
s_1T_vv = S_1T_rv.clone()
422-
s_1T_vv.name = "s_1T"
423-
424-
logp_parts = conditional_logp({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})
425-
426-
s_0_val = 0
427-
s_1T_val = np.array([1, 0, 1, 0, 1, 1, 0, 1, 0, 1])
428-
Gamma_val = np.array([[0.1, 0.9], [0.9, 0.1]])
429-
430-
exp_res = np.log(p_S_0[s_0_val])
431-
s_prev = s_0_val
432-
for s in s_1T_val:
433-
exp_res += np.log(Gamma_val[s_prev, s])
434-
s_prev = s
435-
436-
S_0T_logp = sum(v.sum() for v in logp_parts.values())
437-
S_0T_logp_fn = pytensor.function([s_0_vv, s_1T_vv, Gamma_at], S_0T_logp)
438-
res = S_0T_logp_fn(s_0_val, s_1T_val, Gamma_val)
439-
440-
assert res == pytest.approx(exp_res)
441-
442-
443391
@pytest.mark.parametrize("remove_asserts", (True, False))
444392
def test_mode_is_kept(remove_asserts):
445393
mode = Mode().including("local_remove_all_assert") if remove_asserts else None

0 commit comments

Comments
 (0)