diff --git a/pymc_experimental/model/transforms/autoreparam.py b/pymc_experimental/model/transforms/autoreparam.py index cc1f78289..bb3996459 100644 --- a/pymc_experimental/model/transforms/autoreparam.py +++ b/pymc_experimental/model/transforms/autoreparam.py @@ -246,6 +246,44 @@ def _( return vip_rep +@_vip_reparam_node.register +def _( + op: pm.Exponential, + node: Apply, + name: str, + dims: List[Variable], + transform: Optional[Transform], + lam: pt.TensorVariable, +) -> ModelDeterministic: + rng, size, scale = node.inputs + scale_centered = scale**lam + scale_noncentered = scale ** (1 - lam) + vip_rv_ = pm.Exponential.dist( + scale=scale_centered, + size=size, + rng=rng, + ) + vip_rv_value_ = vip_rv_.clone() + vip_rv_.name = f"{name}::tau_" + if transform is not None: + vip_rv_value_.name = f"{vip_rv_.name}_{transform.name}__" + else: + vip_rv_value_.name = vip_rv_.name + vip_rv = model_free_rv( + vip_rv_, + vip_rv_value_, + transform, + *dims, + ) + + vip_rep_ = scale_noncentered * vip_rv + + vip_rep_.name = name + + vip_rep = model_deterministic(vip_rep_, *dims) + return vip_rep + + def vip_reparametrize( model: pm.Model, var_names: Sequence[str], diff --git a/tests/model/transforms/test_autoreparam.py b/tests/model/transforms/test_autoreparam.py index b2ea245ae..1d2173066 100644 --- a/tests/model/transforms/test_autoreparam.py +++ b/tests/model/transforms/test_autoreparam.py @@ -11,6 +11,7 @@ def model_c(): m = pm.Normal("m") s = pm.LogNormal("s") pm.Normal("g", m, s, shape=5) + pm.Exponential("e", scale=s, shape=7) return mod @@ -20,31 +21,34 @@ def model_nc(): m = pm.Normal("m") s = pm.LogNormal("s") pm.Deterministic("g", pm.Normal("z", shape=5) * s + m) + pm.Deterministic("e", pm.Exponential("z_e", 1, shape=7) * s) return mod -def test_reparametrize_created(model_c: pm.Model): - model_reparam, vip = vip_reparametrize(model_c, ["g"]) - assert "g" in vip.get_lambda() - assert "g::lam_logit__" in model_reparam.named_vars - assert "g::tau_" in model_reparam.named_vars +@pytest.mark.parametrize("var", ["g", "e"]) +def test_reparametrize_created(model_c: pm.Model, var): + model_reparam, vip = vip_reparametrize(model_c, [var]) + assert f"{var}" in vip.get_lambda() + assert f"{var}::lam_logit__" in model_reparam.named_vars + assert f"{var}::tau_" in model_reparam.named_vars vip.set_all_lambda(1) - assert ~np.isfinite(model_reparam["g::lam_logit__"].get_value()).any() + assert ~np.isfinite(model_reparam[f"{var}::lam_logit__"].get_value()).any() -def test_random_draw(model_c: pm.Model, model_nc): +@pytest.mark.parametrize("var", ["g", "e"]) +def test_random_draw(model_c: pm.Model, model_nc, var): model_c = pm.do(model_c, {"m": 3, "s": 2}) model_nc = pm.do(model_nc, {"m": 3, "s": 2}) - model_v, vip = vip_reparametrize(model_c, ["g"]) - assert "g" in [v.name for v in model_v.deterministics] - c = pm.draw(model_c["g"], random_seed=42, draws=1000) - nc = pm.draw(model_nc["g"], random_seed=42, draws=1000) + model_v, vip = vip_reparametrize(model_c, [var]) + assert var in [v.name for v in model_v.deterministics] + c = pm.draw(model_c[var], random_seed=42, draws=1000) + nc = pm.draw(model_nc[var], random_seed=42, draws=1000) vip.set_all_lambda(1) - v_1 = pm.draw(model_v["g"], random_seed=42, draws=1000) + v_1 = pm.draw(model_v[var], random_seed=42, draws=1000) vip.set_all_lambda(0) - v_0 = pm.draw(model_v["g"], random_seed=42, draws=1000) + v_0 = pm.draw(model_v[var], random_seed=42, draws=1000) vip.set_all_lambda(0.5) - v_05 = pm.draw(model_v["g"], random_seed=42, draws=1000) + v_05 = pm.draw(model_v[var], random_seed=42, draws=1000) np.testing.assert_allclose(c.mean(), nc.mean()) np.testing.assert_allclose(c.mean(), v_0.mean()) np.testing.assert_allclose(v_05.mean(), v_1.mean()) @@ -57,10 +61,12 @@ def test_random_draw(model_c: pm.Model, model_nc): def test_reparam_fit(model_c): - model_v, vip = vip_reparametrize(model_c, ["g"]) + vars = ["g", "e"] + model_v, vip = vip_reparametrize(model_c, ["g", "e"]) with model_v: - vip.fit(random_seed=42) - np.testing.assert_allclose(vip.get_lambda()["g"], 0, atol=0.01) + vip.fit(50000, random_seed=42) + for var in vars: + np.testing.assert_allclose(vip.get_lambda()[var], 0, atol=0.01) def test_multilevel():