Skip to content

Commit d5ff424

Browse files
committed
Raise when number of dims does not match var.ndim
1 parent 9792dff commit d5ff424

File tree

4 files changed

+31
-17
lines changed

4 files changed

+31
-17
lines changed

pymc/distributions/discrete.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,13 +1185,6 @@ def logp(value, p):
11851185
)
11861186

11871187

1188-
class _OrderedLogistic(Categorical):
1189-
r"""
1190-
Underlying class for ordered logistic distributions.
1191-
See docs for the OrderedLogistic wrapper class for more details on how to use it in models.
1192-
"""
1193-
1194-
11951188
class OrderedLogistic:
11961189
R"""Ordered Logistic distribution.
11971190
@@ -1263,7 +1256,7 @@ class OrderedLogistic:
12631256
def __new__(cls, name, eta, cutpoints, compute_p=True, **kwargs):
12641257
p = cls.compute_p(eta, cutpoints)
12651258
if compute_p:
1266-
p = pm.Deterministic(f"{name}_probs", p, dims=kwargs.get("dims"))
1259+
p = pm.Deterministic(f"{name}_probs", p)
12671260
out_rv = Categorical(name, p=p, **kwargs)
12681261
return out_rv
12691262

@@ -1367,7 +1360,7 @@ class OrderedProbit:
13671360
def __new__(cls, name, eta, cutpoints, sigma=1, compute_p=True, **kwargs):
13681361
p = cls.compute_p(eta, cutpoints, sigma)
13691362
if compute_p:
1370-
p = pm.Deterministic(f"{name}_probs", p, dims=kwargs.get("dims"))
1363+
p = pm.Deterministic(f"{name}_probs", p)
13711364
out_rv = Categorical(name, p=p, **kwargs)
13721365
return out_rv
13731366

pymc/model/core.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,6 +1532,11 @@ def add_named_variable(self, var, dims: tuple[str | None, ...] | None = None):
15321532
raise ValueError(f"Dimension {dim} is not specified in `coords`.")
15331533
if any(var.name == dim for dim in dims if dim is not None):
15341534
raise ValueError(f"Variable `{var.name}` has the same name as its dimension label.")
1535+
# This check implicitly states that only vars with .ndim attribute can have dims
1536+
if var.ndim != len(dims):
1537+
raise ValueError(
1538+
f"{var} has {var.ndim} dims but {len(dims)} dim labels were provided."
1539+
)
15351540
self.named_vars_to_dims[var.name] = dims
15361541

15371542
self.named_vars[var.name] = var

tests/distributions/test_discrete.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -897,19 +897,25 @@ def test_shape_inputs(self, eta, cutpoints, expected):
897897
assert p_shape == expected
898898

899899
def test_compute_p(self):
900-
with pm.Model() as m:
901-
pm.OrderedLogistic("ol_p", cutpoints=np.array([-2, 0, 2]), eta=0)
902-
pm.OrderedLogistic("ol_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False)
900+
with pm.Model(coords={"test_dim": [0]}) as m:
901+
pm.OrderedLogistic("ol_p", cutpoints=np.array([-2, 0, 2]), eta=0, dims="test_dim")
902+
pm.OrderedLogistic(
903+
"ol_no_p", cutpoints=np.array([-2, 0, 2]), eta=0, compute_p=False, dims="test_dim"
904+
)
903905
assert len(m.deterministics) == 1
904906

905907
x = pm.OrderedLogistic.dist(cutpoints=np.array([-2, 0, 2]), eta=0)
906908
assert isinstance(x, TensorVariable)
907909

908910
# Test it works with auto-imputation
909-
with pm.Model() as m:
911+
with pm.Model(coords={"test_dim": [0, 1, 2]}) as m:
910912
with pytest.warns(ImputationWarning):
911913
pm.OrderedLogistic(
912-
"ol", cutpoints=np.array([-2, 0, 2]), eta=0, observed=[0, np.nan, 1]
914+
"ol",
915+
cutpoints=np.array([[-2, 0, 2]]),
916+
eta=0,
917+
observed=[0, np.nan, 1],
918+
dims=["test_dim"],
913919
)
914920
assert len(m.deterministics) == 2 # One from the auto-imputation, the other from compute_p
915921

tests/model/test_core.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -890,7 +890,17 @@ def test_add_named_variable_checks_dim_name(self):
890890
rv2.name = "yumyum"
891891
pmodel.add_named_variable(rv2, dims=("nomnom", None))
892892

893-
def test_dims_type_check(self):
893+
def test_add_named_variable_checks_number_of_dims(self):
894+
match = "dim labels were provided"
895+
with pm.Model(coords={"bad": range(6)}) as m:
896+
with pytest.raises(ValueError, match=match):
897+
m.add_named_variable(pt.random.normal(size=(6, 6, 6), name="a"), dims=("bad",))
898+
899+
# "bad" is an iterable with 3 elements, but we treat strings as a single dim, so it's still invalid
900+
with pytest.raises(ValueError, match=match):
901+
m.add_named_variable(pt.random.normal(size=(6, 6, 6), name="b"), dims="bad")
902+
903+
def test_rv_dims_type_check(self):
894904
with pm.Model(coords={"a": range(5)}) as m:
895905
with pytest.raises(TypeError, match="Dims must be string"):
896906
x = pm.Normal("x", shape=(10, 5), dims=(None, "a"))
@@ -1070,7 +1080,7 @@ def test_determinsitic_with_dims():
10701080
Test to check the passing of dims to the potential
10711081
"""
10721082
with pm.Model(coords={"observed": range(10)}) as model:
1073-
x = pm.Normal("x", 0, 1)
1083+
x = pm.Normal("x", 0, 1, shape=(10,))
10741084
y = pm.Deterministic("y", x**2, dims=("observed",))
10751085
assert model.named_vars_to_dims == {"y": ("observed",)}
10761086

@@ -1080,7 +1090,7 @@ def test_potential_with_dims():
10801090
Test to check the passing of dims to the potential
10811091
"""
10821092
with pm.Model(coords={"observed": range(10)}) as model:
1083-
x = pm.Normal("x", 0, 1)
1093+
x = pm.Normal("x", 0, 1, shape=(10,))
10841094
y = pm.Potential("y", x**2, dims=("observed",))
10851095
assert model.named_vars_to_dims == {"y": ("observed",)}
10861096

0 commit comments

Comments
 (0)