Skip to content

Commit cf6d4ce

Browse files
authored
Allow passing dims to Potential and Deterministic (#6576)
* Adding dims parameter in Potential * Test case to check the dims in potential
1 parent 047141c commit cf6d4ce

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

pymc/model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,7 +2041,7 @@ def Deterministic(name, var, model=None, dims=None):
20412041
return var
20422042

20432043

2044-
def Potential(name, var, model=None):
2044+
def Potential(name, var, model=None, dims=None):
20452045
"""
20462046
Add an arbitrary factor potential to the model likelihood
20472047
@@ -2135,7 +2135,7 @@ def Potential(name, var, model=None):
21352135
model = modelcontext(model)
21362136
var.name = model.name_for(name)
21372137
model.potentials.append(var)
2138-
model.add_named_variable(var)
2138+
model.add_named_variable(var, dims)
21392139

21402140
from pymc.printing import str_for_potential_or_deterministic
21412141

tests/test_model.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1062,6 +1062,26 @@ def test_deterministic():
10621062
assert model["y"] == y
10631063

10641064

1065+
def test_determinsitic_with_dims():
1066+
"""
1067+
Test to check the passing of dims to the potential
1068+
"""
1069+
with pm.Model(coords={"observed": range(10)}) as model:
1070+
x = pm.Normal("x", 0, 1)
1071+
y = pm.Deterministic("y", x**2, dims=("observed",))
1072+
assert model.named_vars_to_dims == {"y": ("observed",)}
1073+
1074+
1075+
def test_potential_with_dims():
1076+
"""
1077+
Test to check the passing of dims to the potential
1078+
"""
1079+
with pm.Model(coords={"observed": range(10)}) as model:
1080+
x = pm.Normal("x", 0, 1)
1081+
y = pm.Potential("y", x**2, dims=("observed",))
1082+
assert model.named_vars_to_dims == {"y": ("observed",)}
1083+
1084+
10651085
def test_empty_model_representation():
10661086
assert pm.Model().str_repr() == ""
10671087

0 commit comments

Comments
 (0)