Skip to content

Commit 52f902e

Browse files
committed
fix bug with missing variable
1 parent c83b38e commit 52f902e

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,23 @@ def _R2D2M2CP_beta(
7171
if not centered:
7272
with pm.Model(name):
7373
if psi_mask is not None and psi_mask.any():
74+
# limit case where some probs are not 1 or 0
75+
# set_subtensor is required
7476
r_idx = psi_mask.nonzero()
7577
with pm.Model("raw"):
7678
raw = pm.Normal("masked", shape=len(r_idx[0]))
7779
raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw)
7880
raw = pm.Deterministic("raw", raw, dims=dims)
7981
elif psi_mask is not None:
82+
# all variables are deterministic
8083
raw = pt.zeros_like(mu_param)
8184
else:
8285
raw = pm.Normal("raw", dims=dims)
8386
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
8487
else:
8588
if psi_mask is not None and psi_mask.any():
89+
# limit case where some probs are not 1 or 0
90+
# set_subtensor is required
8691
r_idx = psi_mask.nonzero()
8792
with pm.Model(name):
8893
mean = (mu_param / input_sigma)[r_idx]
@@ -96,7 +101,8 @@ def _R2D2M2CP_beta(
96101
beta = pt.set_subtensor(mean, masked)
97102
beta = pm.Deterministic(name, beta, dims=dims)
98103
elif psi_mask is not None:
99-
beta = mean
104+
# all variables are deterministic
105+
beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims)
100106
else:
101107
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
102108
return beta

pymc_experimental/tests/distributions/test_multivariate.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def r2(self):
5050
def r2_std(self, request):
5151
return request.param
5252

53-
@pytest.fixture(params=["true", "false", "limit-1", "limit-0"])
53+
@pytest.fixture(params=["true", "false", "limit-1", "limit-0", "limit-all"])
5454
def positive_probs(self, input_std, request):
5555
if request.param == "true":
5656
return np.full_like(input_std, 0.5)
@@ -64,6 +64,8 @@ def positive_probs(self, input_std, request):
6464
ret = np.full_like(input_std, 0.5)
6565
ret[..., 0] = 0
6666
return ret
67+
elif request.param == "limit-all":
68+
return np.full_like(input_std, 0)
6769

6870
@pytest.fixture(params=[True, False], ids=["probs-std", "no-probs-std"])
6971
def positive_probs_std(self, positive_probs, request):
@@ -122,14 +124,15 @@ def test_init(
122124
assert eps.eval().shape == output_std.shape
123125
assert beta.eval().shape == input_std.shape
124126
# r2 rv is only created if r2 std is not None
127+
assert "beta" in model.named_vars
125128
assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars)
126129
# phi is only created if variable importance is not None and there is more than one var
127130
assert ("beta::phi" in model.named_vars) == (
128131
"variables_importance" in phi_args or "importance_concentration" in phi_args
129132
), set(model.named_vars)
130-
assert ("beta::psi" in model.named_vars) == (positive_probs_std is not None), set(
131-
model.named_vars
132-
)
133+
assert ("beta::psi" in model.named_vars) == (
134+
positive_probs_std is not None and positive_probs_std.any()
135+
), set(model.named_vars)
133136
assert np.isfinite(sum(model.point_logps().values())), model.point_logps()
134137

135138
def test_failing_importance(self, dims, input_shape, output_std, input_std):

0 commit comments

Comments
 (0)