Skip to content

Commit 9b712bf

Browse files
authored
Derive logprob of less and greater than comparisons (#6662)
1 parent f2bb88b commit 9b712bf

File tree

7 files changed

+237
-14
lines changed

7 files changed

+237
-14
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ jobs:
104104
tests/distributions/test_truncated.py
105105
tests/logprob/test_abstract.py
106106
tests/logprob/test_basic.py
107+
tests/logprob/test_binary.py
107108
tests/logprob/test_censoring.py
108109
tests/logprob/test_composite_logprob.py
109110
tests/logprob/test_cumsum.py

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838

3939
# isort: off
4040
# Add rewrites to the DBs
41+
import pymc.logprob.binary
4142
import pymc.logprob.censoring
4243
import pymc.logprob.cumsum
4344
import pymc.logprob.checks

pymc/logprob/binary.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import List, Optional
15+
16+
import numpy as np
17+
import pytensor.tensor as pt
18+
19+
from pytensor.graph.basic import Node
20+
from pytensor.graph.fg import FunctionGraph
21+
from pytensor.graph.rewriting.basic import node_rewriter
22+
from pytensor.scalar.basic import GT, LT
23+
from pytensor.tensor.math import gt, lt
24+
25+
from pymc.logprob.abstract import (
26+
MeasurableElemwise,
27+
MeasurableVariable,
28+
_logcdf_helper,
29+
_logprob,
30+
_logprob_helper,
31+
)
32+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
33+
from pymc.logprob.utils import check_potential_measurability, ignore_logprob
34+
35+
36+
class MeasurableComparison(MeasurableElemwise):
    """Measurable stand-in for an elementwise ``>`` / ``<`` comparison involving a random variable.

    Instances are built from the scalar op of the comparison they replace and
    are the dispatch target for the comparison log-probability implementation.
    """

    # Scalar comparison ops this placeholder may wrap.
    valid_scalar_types = (GT, LT)
40+
41+
42+
@node_rewriter(tracks=[gt, lt])
def find_measurable_comparisons(
    fgraph: FunctionGraph, node: Node
) -> Optional[List[MeasurableComparison]]:
    """Swap a ``gt``/``lt`` node for a ``MeasurableComparison`` when its first input is measurable.

    The rewrite applies only when the first input is the output of a
    ``MeasurableVariable`` op that has no value variable assigned, and when
    ``check_potential_measurability`` accepts the second input.  In every other
    case ``None`` is returned and the node is left untouched.

    Returns
    -------
    A one-element list holding the replacement output, or ``None``.
    """
    rv_map_feature = getattr(fgraph, "preserve_rv_mappings", None)
    if rv_map_feature is None or isinstance(node.op, MeasurableComparison):
        return None  # pragma: no cover

    lhs, rhs = node.inputs
    (original_out,) = node.outputs

    # The left-hand side must come from a measurable op ...
    if not (lhs.owner and isinstance(lhs.owner.op, MeasurableVariable)):
        return None
    # ... and must not already be a valued variable.
    if lhs in rv_map_feature.rv_values:
        return None

    # The right-hand side must pass the potential-measurability check.
    if not check_potential_measurability((rhs,), rv_map_feature):
        return None

    # Make the compared variable unmeasurable before wrapping it.
    hidden_lhs = ignore_logprob(lhs)

    measurable_op = MeasurableComparison(node.op.scalar_op)
    replacement = measurable_op.make_node(hidden_lhs, rhs).default_output()
    replacement.name = original_out.name
    return [replacement]
74+
75+
76+
# Add the comparison rewrite to the measurable IR rewrite database under the
# "basic" and "comparison" tags.
measurable_ir_rewrites_db.register(
    "find_measurable_comparisons", find_measurable_comparisons, "basic", "comparison"
)
82+
83+
84+
@_logprob.register(MeasurableComparison)
def comparison_logprob(op, values, base_rv, operand, **kwargs):
    """Log-probability of the boolean outcome of ``base_rv > operand`` or ``base_rv < operand``.

    Parameters
    ----------
    op
        The ``MeasurableComparison`` op; ``op.scalar_op`` selects the comparison
        (``GT`` or ``LT``).
    values
        Sequence holding exactly one value variable: the observed outcome of
        the comparison (truthy / falsy).
    base_rv
        The variable on the left-hand side of the comparison.
    operand
        The quantity ``base_rv`` is compared against.
    **kwargs
        Forwarded unchanged to ``_logcdf_helper`` and ``_logprob_helper``.

    Returns
    -------
    A graph that evaluates to the log-probability of ``value``.

    Raises
    ------
    TypeError
        If ``op.scalar_op`` is neither ``GT`` nor ``LT``.
    """
    (value,) = values

    base_rv_op = base_rv.owner.op

    # log P(base_rv <= operand) and its complement log P(base_rv > operand).
    logcdf = _logcdf_helper(base_rv, operand, **kwargs)
    logccdf = pt.log1mexp(logcdf)

    condn_exp = pt.eq(value, np.array(True))

    if isinstance(op.scalar_op, GT):
        # True  -> P(X > c) ; False -> P(X <= c)
        logprob = pt.switch(condn_exp, logccdf, logcdf)
    elif isinstance(op.scalar_op, LT):
        # Signed *and* unsigned integer dtypes are discrete: the point mass at
        # ``operand`` must be moved from the "true" to the "false" outcome.
        if base_rv.dtype.startswith(("int", "uint")):
            logpmf = _logprob_helper(base_rv, operand, **kwargs)
            # True  -> P(X < c) = P(X <= c - 1)
            logcdf_lt_true = _logcdf_helper(base_rv, operand - 1, **kwargs)
            # False -> P(X >= c) = P(X > c) + P(X = c)
            logprob = pt.switch(condn_exp, logcdf_lt_true, pt.logaddexp(logccdf, logpmf))
        else:
            # Continuous case: P(X = c) is zero, so "<" and "<=" coincide.
            logprob = pt.switch(condn_exp, logcdf, logccdf)
    else:
        raise TypeError(f"Unsupported scalar_op {op.scalar_op}")

    # Not every op defines ``name``; treat a missing attribute like an empty one.
    if getattr(base_rv_op, "name", None):
        logprob.name = f"{base_rv_op}_logprob"
        logcdf.name = f"{base_rv_op}_logcdf"

    return logprob

pymc/logprob/transforms.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@
8585
_logprob_helper,
8686
)
8787
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
88-
from pymc.logprob.utils import ignore_logprob, walk_model
88+
from pymc.logprob.utils import check_potential_measurability, ignore_logprob
8989

9090

9191
class TransformedVariable(Op):
@@ -573,19 +573,8 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Node) -> Optional[Li
573573
# Check that other inputs are not potentially measurable, in which case this rewrite
574574
# would be invalid
575575
other_inputs = tuple(inp for inp in node.inputs if inp is not measurable_input)
576-
if any(
577-
ancestor_node
578-
for ancestor_node in walk_model(
579-
other_inputs,
580-
walk_past_rvs=False,
581-
stop_at_vars=set(rv_map_feature.rv_values),
582-
)
583-
if (
584-
ancestor_node.owner
585-
and isinstance(ancestor_node.owner.op, MeasurableVariable)
586-
and ancestor_node not in rv_map_feature.rv_values
587-
)
588-
):
576+
577+
if not check_potential_measurability(other_inputs, rv_map_feature):
589578
return None
590579

591580
# Make base_measure outputs unmeasurable

pymc/logprob/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,24 @@ def indices_from_subtensor(idx_list, indices):
210210
)
211211

212212

213+
def check_potential_measurability(
    inputs: Tuple[TensorVariable, ...], rv_map_feature
) -> bool:
    """Report whether ``inputs`` are free of unvalued measurable ancestors.

    The ancestors of ``inputs`` are visited with ``walk_model`` (not walking
    past random variables and stopping at the variables that already have a
    value in ``rv_map_feature.rv_values``).

    NOTE: despite the name, ``True`` means that *no* visited variable is the
    output of a ``MeasurableVariable`` op lacking a value variable; callers
    bail out of their rewrite when the result is falsy.

    Parameters
    ----------
    inputs
        Variables whose ancestors are inspected.
    rv_map_feature
        Object exposing an ``rv_values`` mapping keyed by valued variables.

    Returns
    -------
    bool
        ``False`` if at least one visited variable is produced by a
        ``MeasurableVariable`` op and is not in ``rv_map_feature.rv_values``;
        ``True`` otherwise.
    """
    valued_vars = set(rv_map_feature.rv_values)
    # Test the predicate itself rather than the truthiness of the graph variable.
    return not any(
        ancestor.owner
        and isinstance(ancestor.owner.op, MeasurableVariable)
        and ancestor not in valued_vars
        for ancestor in walk_model(
            inputs,
            walk_past_rvs=False,
            stop_at_vars=valued_vars,
        )
    )
229+
230+
213231
class ParameterValueError(ValueError):
214232
"""Exception for invalid parameters values in logprob graphs"""
215233

scripts/run_mypy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
pymc/distributions/timeseries.py
3030
pymc/distributions/truncated.py
3131
pymc/initial_point.py
32+
pymc/logprob/binary.py
3233
pymc/logprob/censoring.py
3334
pymc/logprob/basic.py
3435
pymc/logprob/mixture.py

tests/logprob/test_binary.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor
16+
import pytensor.tensor as pt
17+
import pytest
18+
import scipy.stats as st
19+
20+
from pytensor import function
21+
22+
from pymc import logp
23+
from pymc.logprob import factorized_joint_logprob
24+
from pymc.testing import assert_no_rvs
25+
26+
27+
@pytest.mark.parametrize(
28+
"comparison_op, exp_logp_true, exp_logp_false",
29+
[
30+
(pt.lt, st.norm(0, 1).logcdf, st.norm(0, 1).logsf),
31+
(pt.gt, st.norm(0, 1).logsf, st.norm(0, 1).logcdf),
32+
],
33+
)
34+
def test_continuous_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
35+
x_rv = pt.random.normal(0, 1)
36+
comp_x_rv = comparison_op(x_rv, 0.5)
37+
38+
comp_x_vv = comp_x_rv.clone()
39+
40+
logprob = logp(comp_x_rv, comp_x_vv)
41+
assert_no_rvs(logprob)
42+
43+
logp_fn = pytensor.function([comp_x_vv], logprob)
44+
45+
assert np.isclose(logp_fn(0), exp_logp_false(0.5))
46+
assert np.isclose(logp_fn(1), exp_logp_true(0.5))
47+
48+
49+
@pytest.mark.parametrize(
50+
"comparison_op, exp_logp_true, exp_logp_false",
51+
[
52+
(
53+
pt.lt,
54+
lambda x: st.poisson(2).logcdf(x - 1),
55+
lambda x: np.logaddexp(st.poisson(2).logsf(x), st.poisson(2).logpmf(x)),
56+
),
57+
(
58+
pt.gt,
59+
st.poisson(2).logsf,
60+
st.poisson(2).logcdf,
61+
),
62+
],
63+
)
64+
def test_discrete_rv_comparison(comparison_op, exp_logp_true, exp_logp_false):
65+
x_rv = pt.random.poisson(2)
66+
cens_x_rv = comparison_op(x_rv, 3)
67+
68+
cens_x_vv = cens_x_rv.clone()
69+
70+
logprob = logp(cens_x_rv, cens_x_vv)
71+
assert_no_rvs(logprob)
72+
73+
logp_fn = pytensor.function([cens_x_vv], logprob)
74+
75+
assert np.isclose(logp_fn(1), exp_logp_true(3))
76+
assert np.isclose(logp_fn(0), exp_logp_false(3))
77+
78+
79+
def test_potentially_measurable_operand():
80+
x_rv = pt.random.normal(2)
81+
z_rv = pt.random.normal(x_rv)
82+
y_rv = pt.lt(x_rv, z_rv)
83+
84+
y_vv = y_rv.clone()
85+
z_vv = z_rv.clone()
86+
87+
logprob = factorized_joint_logprob({z_rv: z_vv, y_rv: y_vv})[y_vv]
88+
assert_no_rvs(logprob)
89+
90+
fn = function([z_vv, y_vv], logprob)
91+
z_vv_test = 0.5
92+
y_vv_test = True
93+
np.testing.assert_array_almost_equal(
94+
fn(z_vv_test, y_vv_test),
95+
st.norm(2, 1).logcdf(z_vv_test),
96+
)
97+
98+
with pytest.raises(
99+
NotImplementedError,
100+
match="Logprob method not implemented",
101+
):
102+
logp(y_rv, y_vv).eval({y_vv: y_vv_test})

0 commit comments

Comments
 (0)