Skip to content

Commit 54b88c9

Browse files
committed
reparam all needed variables when doing fit
1 parent cb4ef56 commit 54b88c9

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

pymc_experimental/tests/model/transforms/test_autoreparam.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,13 @@ def test_random_draw(model_c: pm.Model, model_nc, var):
6060
np.testing.assert_allclose(v_1.std(), nc.std())
6161

6262

63-
@pytest.mark.parametrize("var", ["g", "e"])
64-
def test_reparam_fit(model_c, var):
65-
model_v, vip = vip_reparametrize(model_c, [var])
63+
def test_reparam_fit(model_c):
64+
vars = ["g", "e"]
65+
model_v, vip = vip_reparametrize(model_c, ["g", "e"])
6666
with model_v:
6767
vip.fit(random_seed=42)
68-
np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
68+
for var in vars:
69+
np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
6970

7071

7172
def test_multilevel():

0 commit comments

Comments
 (0)