Skip to content

Commit f4a9bc2

Browse files
committed
test: add tests for autoreparam
1 parent 02d42d6 commit f4a9bc2

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

pymc_experimental/tests/model/transforms/test_autoreparam.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ def model_c():
1111
m = pm.Normal("m")
1212
s = pm.LogNormal("s")
1313
pm.Normal("g", m, s, shape=5)
14+
pm.Exponential("e", scale=s, shape=7)
1415
return mod
1516

1617

@@ -20,31 +21,34 @@ def model_nc():
2021
m = pm.Normal("m")
2122
s = pm.LogNormal("s")
2223
pm.Deterministic("g", pm.Normal("z", shape=5) * s + m)
24+
pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s)
2325
return mod
2426

2527

26-
def test_reparametrize_created(model_c: pm.Model):
27-
model_reparam, vip = vip_reparametrize(model_c, ["g"])
28-
assert "g" in vip.get_lambda()
29-
assert "g::lam_logit__" in model_reparam.named_vars
30-
assert "g::tau_" in model_reparam.named_vars
28+
@pytest.mark.parameterize("var", ["g", "e"])
29+
def test_reparametrize_created(model_c: pm.Model, var):
30+
model_reparam, vip = vip_reparametrize(model_c, [var])
31+
assert f"{var}" in vip.get_lambda()
32+
assert f"{var}::lam_logit__" in model_reparam.named_vars
33+
assert f"{var}::tau_" in model_reparam.named_vars
3134
vip.set_all_lambda(1)
32-
assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any()
35+
assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any()
3336

3437

35-
def test_random_draw(model_c: pm.Model, model_nc):
38+
@pytest.mark.parameterize("var", ["g", "e"])
39+
def test_random_draw(model_c: pm.Model, model_nc, var):
3640
model_c = pm.do(model_c, {"m": 3, "s": 2})
3741
model_nc = pm.do(model_nc, {"m": 3, "s": 2})
38-
model_v, vip = vip_reparametrize(model_c, ["g"])
39-
assert "g" in [v.name for v in model_v.deterministics]
40-
c = pm.draw(model_c["g"], random_seed=42, draws=1000)
41-
nc = pm.draw(model_nc["g"], random_seed=42, draws=1000)
42+
model_v, vip = vip_reparametrize(model_c, [var])
43+
assert var in [v.name for v in model_v.deterministics]
44+
c = pm.draw(model_c[var], random_seed=42, draws=1000)
45+
nc = pm.draw(model_nc[var], random_seed=42, draws=1000)
4246
vip.set_all_lambda(1)
43-
v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000)
47+
v_1 = pm.draw(model_v[var], random_seed=42, draws=1000)
4448
vip.set_all_lambda(0)
45-
v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000)
49+
v_0 = pm.draw(model_v[var], random_seed=42, draws=1000)
4650
vip.set_all_lambda(0.5)
47-
v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000)
51+
v_05 = pm.draw(model_v[var], random_seed=42, draws=1000)
4852
np.testing.assert_allclose(c.mean(), nc.mean())
4953
np.testing.assert_allclose(c.mean(), v_0.mean())
5054
np.testing.assert_allclose(v_05.mean(), v_1.mean())
@@ -56,11 +60,12 @@ def test_random_draw(model_c: pm.Model, model_nc):
5660
np.testing.assert_allclose(v_1.std(), nc.std())
5761

5862

59-
def test_reparam_fit(model_c):
60-
model_v, vip = vip_reparametrize(model_c, ["g"])
63+
@pytest.mark.parameterize("var", ["g", "e"])
64+
def test_reparam_fit(model_c, var):
65+
model_v, vip = vip_reparametrize(model_c, [var])
6166
with model_v:
6267
vip.fit(random_seed=42)
63-
np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01)
68+
np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01)
6469

6570

6671
def test_multilevel():

0 commit comments

Comments
 (0)