Skip to content

Commit 320df8d

Browse files
committed
Add flag to CheckParameterValue to inform whether it can be replaced by -inf
1 parent 4cdc7ba commit 320df8d

File tree

4 files changed

+48
-19
lines changed

4 files changed

+48
-19
lines changed

pymc/distributions/dist_math.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,21 @@
5050
}
5151

5252

53-
def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str = ""):
54-
"""
55-
Wrap a log probability graph in a CheckParameterValue that asserts several
56-
conditions are True. When conditions are not met a ParameterValueError assertion is
57-
raised, with an optional custom message defined by `msg`
53+
def check_parameters(
54+
expr: Variable,
55+
*conditions: Iterable[Variable],
56+
msg: str = "",
57+
can_be_replaced_by_ninf: bool = True,
58+
):
59+
"""Wrap an expression in a CheckParameterValue that asserts several conditions are met.
60+
61+
When conditions are not met a ParameterValueError assertion is raised,
62+
with an optional custom message defined by `msg`.
5863
59-
Note that check_parameter should not be used to enforce the logic of the logp
64+
When the flag `can_be_replaced_by_ninf` is True (default), PyMC is allowed to replace the
65+
assertion by a switch(condition, expr, -inf). This is used for logp graphs!
66+
67+
Note that check_parameter should not be used to enforce the logic of the
6068
expression under the normal parameter support as it can be disabled by the user via
6169
check_bounds = False in pm.Model()
6270
"""
@@ -65,7 +73,8 @@ def check_parameters(logp: Variable, *conditions: Iterable[Variable], msg: str =
6573
cond if (cond is not True and cond is not False) else np.array(cond) for cond in conditions
6674
]
6775
all_true_scalar = at.all([at.all(cond) for cond in conditions_])
68-
return CheckParameterValue(msg)(logp, all_true_scalar)
76+
77+
return CheckParameterValue(msg, can_be_replaced_by_ninf)(expr, all_true_scalar)
6978

7079

7180
def logpow(x, m):

pymc/logprob/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,8 +210,11 @@ class CheckParameterValue(CheckAndRaise):
210210
Raises `ParameterValueError` if the check is not True.
211211
"""
212212

213-
def __init__(self, msg=""):
213+
__props__ = ("msg", "exc_type", "can_be_replaced_by_ninf")
214+
215+
def __init__(self, msg: str = "", can_be_replaced_by_ninf: bool = False):
214216
super().__init__(ParameterValueError, msg)
217+
self.can_be_replaced_by_ninf = can_be_replaced_by_ninf
215218

216219
def __str__(self):
217220
return f"Check{{{self.msg}}}"

pymc/pytensorf.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -913,19 +913,21 @@ def local_remove_check_parameter(fgraph, node):
913913

914914
@node_rewriter(tracks=[CheckParameterValue])
915915
def local_check_parameter_to_ninf_switch(fgraph, node):
916-
if isinstance(node.op, CheckParameterValue):
917-
logp_expr, *logp_conds = node.inputs
918-
if len(logp_conds) > 1:
919-
logp_cond = at.all(logp_conds)
920-
else:
921-
(logp_cond,) = logp_conds
922-
out = at.switch(logp_cond, logp_expr, -np.inf)
923-
out.name = node.op.msg
916+
if not node.op.can_be_replaced_by_ninf:
917+
return None
918+
919+
logp_expr, *logp_conds = node.inputs
920+
if len(logp_conds) > 1:
921+
logp_cond = at.all(logp_conds)
922+
else:
923+
(logp_cond,) = logp_conds
924+
out = at.switch(logp_cond, logp_expr, -np.inf)
925+
out.name = node.op.msg
924926

925-
if out.dtype != node.outputs[0].dtype:
926-
out = at.cast(out, node.outputs[0].dtype)
927+
if out.dtype != node.outputs[0].dtype:
928+
out = at.cast(out, node.outputs[0].dtype)
927929

928-
return [out]
930+
return [out]
929931

930932

931933
pytensor.compile.optdb["canonicalize"].register(

tests/test_pytensorf.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,21 @@ def test_check_bounds_flag(self):
326326
with m:
327327
assert np.all(compile_pymc([], bound)() == -np.inf)
328328

329+
def test_check_parameters_can_be_replaced_by_ninf(self):
330+
expr = at.vector("expr", shape=(3,))
331+
cond = at.ge(expr, 0)
332+
333+
final_expr = check_parameters(expr, cond, can_be_replaced_by_ninf=True)
334+
fn = compile_pymc([expr], final_expr)
335+
np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3])
336+
np.testing.assert_array_equal(fn(expr=[-1, 2, 3]), [-np.inf, -np.inf, -np.inf])
337+
338+
final_expr = check_parameters(expr, cond, msg="test", can_be_replaced_by_ninf=False)
339+
fn = compile_pymc([expr], final_expr)
340+
np.testing.assert_array_equal(fn(expr=[1, 2, 3]), [1, 2, 3])
341+
with pytest.raises(ParameterValueError, match="test"):
342+
fn([-1, 2, 3])
343+
329344
def test_compile_pymc_sets_rng_updates(self):
330345
rng = pytensor.shared(np.random.default_rng(0))
331346
x = pm.Normal.dist(rng=rng)

0 commit comments

Comments
 (0)