Skip to content

Commit c83b38e

Browse files
committed
fix corner cases
1 parent f83656f commit c83b38e

File tree

2 files changed

+23
-3
lines changed

2 files changed

+23
-3
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,19 @@ def _R2D2M2CP_beta(
7070
mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
7171
if not centered:
7272
with pm.Model(name):
73-
if psi_mask is not None:
73+
if psi_mask is not None and psi_mask.any():
7474
r_idx = psi_mask.nonzero()
7575
with pm.Model("raw"):
7676
raw = pm.Normal("masked", shape=len(r_idx[0]))
7777
raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw)
7878
raw = pm.Deterministic("raw", raw, dims=dims)
79+
elif psi_mask is not None:
80+
raw = pt.zeros_like(mu_param)
7981
else:
8082
raw = pm.Normal("raw", dims=dims)
8183
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
8284
else:
83-
if psi_mask is not None:
85+
if psi_mask is not None and psi_mask.any():
8486
r_idx = psi_mask.nonzero()
8587
with pm.Model(name):
8688
mean = (mu_param / input_sigma)[r_idx]
@@ -93,6 +95,8 @@ def _R2D2M2CP_beta(
9395
)
9496
beta = pt.set_subtensor(mean, masked)
9597
beta = pm.Deterministic(name, beta, dims=dims)
98+
elif psi_mask is not None:
99+
beta = mean
96100
else:
97101
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
98102
return beta
@@ -121,7 +125,9 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
121125
mask = ~np.bitwise_or(positive_probs == 1, positive_probs == 0)
122126
if np.bitwise_and(~mask, positive_probs_std != 0).any():
123127
raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")
124-
if (~mask).any():
128+
if (~mask).any() and mask.any():
129+
# limit case where some probs are not 1 or 0
130+
# setsubtensor is required
125131
r_idx = mask.nonzero()
126132
with pm.Model("psi"):
127133
psi = pm.Beta(
@@ -132,6 +138,9 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
132138
)
133139
psi = pt.set_subtensor(pt.as_tensor(positive_probs)[r_idx], psi)
134140
psi = pm.Deterministic("psi", psi, dims=dims)
141+
elif (~mask).all():
142+
# limit case where all the probs are limit case
143+
psi = pt.as_tensor(positive_probs)
135144
else:
136145
psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims)
137146
mask = None

pymc_experimental/tests/distributions/test_multivariate.py

+11
Original file line numberDiff line numberDiff line change
@@ -236,3 +236,14 @@ def test_limit_case_creates_masked_vars(self, model: pm.Model, centered: bool):
236236
assert "beta1::masked" in model.named_vars, model.named_vars
237237
assert "beta1::psi::masked" in model.named_vars
238238
assert "beta0::psi::masked" in model.named_vars
239+
240+
def test_zero_length_rvs_not_created(self, model: pm.Model):
241+
model.add_coord("a", range(2))
242+
# deterministic case which should not have any new variables
243+
b = pmx.distributions.R2D2M2CP("b1", 1, [1, 1], r2=0.5, positive_probs=[1, 1], dims="a")
244+
assert not model.free_RVs, model.free_RVs
245+
246+
b = pmx.distributions.R2D2M2CP(
247+
"b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
248+
)
249+
assert not model.free_RVs, model.free_RVs

0 commit comments

Comments
 (0)