Skip to content

Commit 365b117

Browse files
committed
Add ability to set mode in check_start_vals
1 parent c92a9a9 commit 365b117

File tree

2 files changed

+22
-4
lines changed

2 files changed

+22
-4
lines changed

pymc/model/core.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1747,7 +1747,7 @@ def eval_rv_shapes(self) -> dict[str, tuple[int, ...]]:
17471747
)
17481748
return {name: tuple(shape) for name, shape in zip(names, f())}
17491749

1750-
def check_start_vals(self, start):
1750+
def check_start_vals(self, start, **kwargs):
17511751
r"""Check that the starting values for MCMC do not cause the relevant log probability
17521752
to evaluate to something invalid (e.g. Inf or NaN)
17531753
@@ -1758,6 +1758,8 @@ def check_start_vals(self, start):
17581758
Defaults to ``trace.point(-1))`` if there is a trace provided and
17591759
``model.initial_point`` if not (defaults to empty dict). Initialization
17601760
methods for NUTS (see ``init`` keyword) can overwrite the default.
1761+
Other keyword arguments :
1762+
Any other keyword argument is sent to :py:meth:`~pymc.model.core.Model.point_logps`.
17611763
17621764
Raises
17631765
------
@@ -1787,7 +1789,7 @@ def check_start_vals(self, start):
17871789
f"Valid keys are: {valid_keys}, but {extra_keys} was supplied"
17881790
)
17891791

1790-
initial_eval = self.point_logps(point=elem)
1792+
initial_eval = self.point_logps(point=elem, **kwargs)
17911793

17921794
if not all(np.isfinite(v) for v in initial_eval.values()):
17931795
raise SamplingError(
@@ -1797,7 +1799,7 @@ def check_start_vals(self, start):
17971799
"You can call `model.debug()` for more details."
17981800
)
17991801

1800-
def point_logps(self, point=None, round_vals=2):
1802+
def point_logps(self, point=None, round_vals=2, **kwargs):
18011803
"""Computes the log probability of `point` for all random variables in the model.
18021804
18031805
Parameters
@@ -1807,6 +1809,8 @@ def point_logps(self, point=None, round_vals=2):
18071809
is used.
18081810
round_vals : int, default 2
18091811
Number of decimals to round log-probabilities.
1812+
Other keyword arguments :
1813+
Any other keyword argument are sent provided to :py:meth:`~pymc.model.core.Model.compile_fn`
18101814
18111815
Returns
18121816
-------
@@ -1822,7 +1826,7 @@ def point_logps(self, point=None, round_vals=2):
18221826
factor.name: np.round(np.asarray(factor_logp), round_vals)
18231827
for factor, factor_logp in zip(
18241828
factors,
1825-
self.compile_fn(factor_logps_fn)(point),
1829+
self.compile_fn(factor_logps_fn, **kwargs)(point),
18261830
)
18271831
}
18281832

tests/model/test_core.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -756,6 +756,20 @@ def test_invalid_variable_name(self):
756756
with pytest.raises(KeyError):
757757
model.check_start_vals(start)
758758

759+
@pytest.mark.parametrize("mode", [None, "JAX", "NUMBA"])
760+
def test_mode(self, mode):
761+
with pm.Model() as model:
762+
a = pm.Uniform("a", lower=0.0, upper=1.0)
763+
b = pm.Uniform("b", lower=2.0, upper=3.0)
764+
start = {
765+
"a_interval__": model.rvs_to_transforms[a].forward(0.3, *a.owner.inputs).eval(),
766+
"b_interval__": model.rvs_to_transforms[b].forward(2.1, *b.owner.inputs).eval(),
767+
}
768+
with patch("pymc.model.core.compile_pymc") as patched_compile_pymc:
769+
model.check_start_vals(start, mode=mode)
770+
patched_compile_pymc.assert_called_once()
771+
assert patched_compile_pymc.call_args.kwargs["mode"] == mode
772+
759773

760774
def test_set_initval():
761775
# Make sure the dependencies between variables are maintained when

0 commit comments

Comments
 (0)