Skip to content

Commit 5688555

Browse files
committed
Fix VI xfail test
1 parent 3d19864 commit 5688555

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pymc/tests/variational/test_inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,17 +307,19 @@ def test_remove_scan_op():
307307
buff.close()
308308

309309

310-
@pytest.mark.xfail(reason="Broke from static shape handling with Aesara 2.8.8")
311310
def test_var_replacement():
312311
X_mean = pm.floatX(np.linspace(0, 10, 10))
313312
y = pm.floatX(np.random.normal(X_mean * 4, 0.05))
313+
inp_size = aesara.shared(np.array(10, dtype="int64"), name="inp_size")
314314
with pm.Model():
315-
inp = pm.Normal("X", X_mean, size=X_mean.shape)
315+
inp = pm.Normal("X", X_mean, size=(inp_size,))
316316
coef = pm.Normal("b", 4.0)
317317
mean = inp * coef
318-
pm.Normal("y", mean, 0.1, observed=y)
318+
pm.Normal("y", mean, 0.1, shape=inp.shape, observed=y)
319319
advi = pm.fit(100)
320320
assert advi.sample_node(mean).eval().shape == (10,)
321+
322+
inp_size.set_value(11)
321323
x_new = pm.floatX(np.linspace(0, 10, 11))
322324
assert advi.sample_node(mean, more_replacements={inp: x_new}).eval().shape == (11,)
323325

0 commit comments

Comments
 (0)