Skip to content

Commit 46f8e2f

Browse files
committed
Allow logcdf and icdf inference
1 parent ae9fcac commit 46f8e2f

File tree

3 files changed

+53
-5
lines changed

3 files changed

+53
-5
lines changed

pymc/logprob/basic.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,13 +80,25 @@ def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
8080
def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
8181
"""Create a graph for the log-CDF of a Random Variable."""
8282
value = pt.as_tensor_variable(value, dtype=rv.dtype)
83-
return _logcdf_helper(rv, value, **kwargs)
83+
try:
84+
return _logcdf_helper(rv, value, **kwargs)
85+
except NotImplementedError:
86+
# Try to rewrite rv
87+
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
88+
[ir_rv] = fgraph.outputs
89+
return _logcdf_helper(ir_rv, value, **kwargs)
8490

8591

8692
def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable:
8793
"""Create a graph for the inverse CDF of a Random Variable."""
88-
value = pt.as_tensor_variable(value)
89-
return _icdf_helper(rv, value, **kwargs)
94+
value = pt.as_tensor_variable(value, dtype=rv.dtype)
95+
try:
96+
return _icdf_helper(rv, value, **kwargs)
97+
except NotImplementedError:
98+
# Try to rewrite rv
99+
fgraph, rv_values, _ = construct_ir_fgraph({rv: value})
100+
[ir_rv] = fgraph.outputs
101+
return _icdf_helper(ir_rv, value, **kwargs)
90102

91103

92104
def factorized_joint_logprob(

pymc/logprob/transforms.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,10 @@
7777
MeasurableElemwise,
7878
MeasurableVariable,
7979
_get_measurable_outputs,
80+
_icdf,
81+
_icdf_helper,
82+
_logcdf,
83+
_logcdf_helper,
8084
_logprob,
8185
_logprob_helper,
8286
)
@@ -390,6 +394,38 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
390394
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
391395

392396

397+
@_logcdf.register(MeasurableTransform)
398+
def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwargs):
399+
"""Compute the log-CDF graph for a `MeasurabeTransform`."""
400+
other_inputs = list(inputs)
401+
measurable_input = other_inputs.pop(op.measurable_input_idx)
402+
403+
backward_value = op.transform_elemwise.backward(value, *other_inputs)
404+
405+
# Some transformations, like squaring may produce multiple backward values
406+
if isinstance(backward_value, tuple):
407+
raise NotImplementedError
408+
409+
input_logcdf = _logcdf_helper(measurable_input, backward_value)
410+
411+
# The jacobian is used to ensure a value in the supported domain was provided
412+
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
413+
414+
return pt.switch(pt.isnan(jacobian), -np.inf, input_logcdf)
415+
416+
417+
@_icdf.register(MeasurableTransform)
418+
def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs):
419+
"""Compute the inverse CDF graph for a `MeasurabeTransform`."""
420+
other_inputs = list(inputs)
421+
measurable_input = other_inputs.pop(op.measurable_input_idx)
422+
423+
input_icdf = _icdf_helper(measurable_input, value)
424+
icdf = op.transform_elemwise.forward(input_icdf, *other_inputs)
425+
426+
return icdf
427+
428+
393429
@node_rewriter([reciprocal])
394430
def measurable_reciprocal_to_power(fgraph, node):
395431
"""Convert reciprocal of `MeasurableVariable`s to power."""

tests/logprob/test_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,8 +432,8 @@ def test_probability_direct_dispatch(func, scipy_func):
432432
"func, scipy_func, test_value",
433433
[
434434
(logp, "logpdf", 5.0),
435-
pytest.param(logcdf, "logcdf", 5.0, marks=pytest.mark.xfail(raises=NotImplementedError)),
436-
pytest.param(icdf, "ppf", 0.7, marks=pytest.mark.xfail(raises=NotImplementedError)),
435+
(logcdf, "logcdf", 5.0),
436+
(icdf, "ppf", 0.7),
437437
],
438438
)
439439
def test_probability_inference(func, scipy_func, test_value):

0 commit comments

Comments
 (0)