Skip to content

Commit a4ace35

Browse files
dfmricardoV94
authored andcommitted
Fixing keyword args in top-level pm.compile_fn
`Model.compile_fn` requires all parameters except `outs` to be keyword arguments so the top level `compile_fn` didn't work as written.
1 parent e58c0e6 commit a4ace35

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

pymc/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1886,7 +1886,7 @@ def compile_fn(outs, mode=None, point_fn=True, model=None, **kwargs):
18861886
Compiled Aesara function as point function.
18871887
"""
18881888
model = modelcontext(model)
1889-
return model.compile_fn(outs, mode, point_fn=point_fn, **kwargs)
1889+
return model.compile_fn(outs, mode=mode, point_fn=point_fn, **kwargs)
18901890

18911891

18921892
def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:

pymc/tests/test_model.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -996,3 +996,21 @@ def test_deterministic():
996996

997997
def test_empty_model_representation():
998998
assert pm.Model().str_repr() == ""
999+
1000+
1001+
def test_compile_fn():
1002+
with pm.Model() as m:
1003+
x = pm.Normal("x", 0, 1, size=2)
1004+
y = pm.LogNormal("y", 0, 1, size=2)
1005+
1006+
test_vals = np.array([0.0, -1.0])
1007+
state = {"x": test_vals, "y": test_vals}
1008+
1009+
with m:
1010+
func = pm.compile_fn(x + y, inputs=[x, y])
1011+
result_compute = func(state)
1012+
1013+
func = m.compile_fn(x + y, inputs=[x, y])
1014+
result_expect = func(state)
1015+
1016+
np.testing.assert_allclose(result_compute, result_expect)

0 commit comments

Comments
 (0)