Skip to content

Commit 2c2c375

Browse files
committed
add requirement for dims to be immutable for the prior as it is required for the limit case masking
1 parent a030dda commit 2c2c375

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,12 @@ def R2D2M2CP(
394394
*broadcast_dims, dim = dims
395395
input_sigma = pt.as_tensor(input_sigma)
396396
output_sigma = pt.as_tensor(output_sigma)
397-
with pm.Model(name):
397+
with pm.Model(name) as model:
398+
if not all(
399+
isinstance(model.dim_lengths[d], pt.TensorConstant)
400+
for d in dims
401+
):
402+
raise ValueError(f"{dims!r} should be constant length imutable dims")
398403
if r2_std is not None:
399404
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
400405
phi = _phi(

pymc_experimental/tests/distributions/test_multivariate.py

+24
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,27 @@ def test_zero_length_rvs_not_created(self, model: pm.Model):
255255
"b2", 1, [1, 1], r2=0.5, positive_probs=[1, 1], positive_probs_std=[0, 0], dims="a"
256256
)
257257
assert not model.free_RVs, model.free_RVs
258+
259+
def test_immutable_dims(self, model: pm.Model):
260+
model.add_coord("a", range(2), mutable=True)
261+
model.add_coord("b", range(2), mutable=False)
262+
with pytest.raises(ValueError, match="should be constant length imutable dims"):
263+
pmx.distributions.R2D2M2CP(
264+
"beta0",
265+
1,
266+
[1, 1],
267+
dims="a",
268+
r2=0.8,
269+
positive_probs=[0.5, 1],
270+
positive_probs_std=[0.3, 0],
271+
)
272+
with pytest.raises(ValueError, match="should be constant length imutable dims"):
273+
pmx.distributions.R2D2M2CP(
274+
"beta0",
275+
1,
276+
[1, 1],
277+
dims=("a", "b"),
278+
r2=0.8,
279+
positive_probs=[0.5, 1],
280+
positive_probs_std=[0.3, 0],
281+
)

0 commit comments

Comments
 (0)