Skip to content

Commit ae9fcac

Browse files
Implemented logprob for SpecifyShape and CheckandRaise (#6538)
1 parent f7861b5 commit ae9fcac

File tree

3 files changed

+257
-0
lines changed

3 files changed

+257
-0
lines changed

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
# Add rewrites to the DBs
4141
import pymc.logprob.censoring
4242
import pymc.logprob.cumsum
43+
import pymc.logprob.checks
4344
import pymc.logprob.mixture
4445
import pymc.logprob.scan
4546
import pymc.logprob.tensor

pymc/logprob/checks.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
37+
from typing import List, Optional
38+
39+
import pytensor.tensor as pt
40+
41+
from pytensor.graph.rewriting.basic import node_rewriter
42+
from pytensor.raise_op import CheckAndRaise, ExceptionType
43+
from pytensor.tensor.shape import SpecifyShape
44+
45+
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
46+
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
47+
from pymc.logprob.utils import ignore_logprob
48+
49+
50+
class MeasurableSpecifyShape(SpecifyShape):
51+
"""A placeholder used to specify a log-likelihood for a specify-shape sub-graph."""
52+
53+
54+
MeasurableVariable.register(MeasurableSpecifyShape)
55+
56+
57+
@_logprob.register(MeasurableSpecifyShape)
58+
def logprob_specify_shape(op, values, inner_rv, *shapes, **kwargs):
59+
(value,) = values
60+
# transfer specify_shape from rv to value
61+
value = pt.specify_shape(value, shapes)
62+
return _logprob_helper(inner_rv, value)
63+
64+
65+
@node_rewriter([SpecifyShape])
66+
def find_measurable_specify_shapes(fgraph, node) -> Optional[List[MeasurableSpecifyShape]]:
67+
r"""Finds `SpecifyShapeOp`\s for which a `logprob` can be computed."""
68+
69+
if isinstance(node.op, MeasurableSpecifyShape):
70+
return None # pragma: no cover
71+
72+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
73+
74+
if rv_map_feature is None:
75+
return None # pragma: no cover
76+
77+
rv = node.outputs[0]
78+
79+
base_rv, *shape = node.inputs
80+
81+
if not (
82+
base_rv.owner
83+
and isinstance(base_rv.owner.op, MeasurableVariable)
84+
and base_rv not in rv_map_feature.rv_values
85+
):
86+
return None # pragma: no cover
87+
88+
new_op = MeasurableSpecifyShape()
89+
# Make base_var unmeasurable
90+
unmeasurable_base_rv = ignore_logprob(base_rv)
91+
new_rv = new_op.make_node(unmeasurable_base_rv, *shape).default_output()
92+
new_rv.name = rv.name
93+
94+
return [new_rv]
95+
96+
97+
measurable_ir_rewrites_db.register(
98+
"find_measurable_specify_shapes",
99+
find_measurable_specify_shapes,
100+
"basic",
101+
"specify_shape",
102+
)
103+
104+
105+
class MeasurableCheckAndRaise(CheckAndRaise):
106+
"""A placeholder used to specify a log-likelihood for an assert sub-graph."""
107+
108+
109+
MeasurableVariable.register(MeasurableCheckAndRaise)
110+
111+
112+
@_logprob.register(MeasurableCheckAndRaise)
113+
def logprob_assert(op, values, inner_rv, *assertion, **kwargs):
114+
(value,) = values
115+
# transfer assertion from rv to value
116+
value = op(assertion, value)
117+
return _logprob_helper(inner_rv, value)
118+
119+
120+
@node_rewriter([CheckAndRaise])
121+
def find_measurable_asserts(fgraph, node) -> Optional[List[MeasurableCheckAndRaise]]:
122+
r"""Finds `AssertOp`\s for which a `logprob` can be computed."""
123+
124+
if isinstance(node.op, MeasurableCheckAndRaise):
125+
return None # pragma: no cover
126+
127+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
128+
129+
if rv_map_feature is None:
130+
return None # pragma: no cover
131+
132+
rv = node.outputs[0]
133+
134+
base_rv, *conds = node.inputs
135+
136+
if not (
137+
base_rv.owner
138+
and isinstance(base_rv.owner.op, MeasurableVariable)
139+
and base_rv not in rv_map_feature.rv_values
140+
):
141+
return None # pragma: no cover
142+
143+
exception_type = ExceptionType()
144+
new_op = MeasurableCheckAndRaise(exc_type=exception_type)
145+
# Make base_var unmeasurable
146+
unmeasurable_base_rv = ignore_logprob(base_rv)
147+
new_rv = new_op.make_node(unmeasurable_base_rv, *conds).default_output()
148+
new_rv.name = rv.name
149+
150+
return [new_rv]
151+
152+
153+
measurable_ir_rewrites_db.register(
154+
"find_measurable_asserts",
155+
find_measurable_asserts,
156+
"basic",
157+
"assert",
158+
)

tests/logprob/test_checks.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
#
15+
# MIT License
16+
#
17+
# Copyright (c) 2021-2022 aesara-devs
18+
#
19+
# Permission is hereby granted, free of charge, to any person obtaining a copy
20+
# of this software and associated documentation files (the "Software"), to deal
21+
# in the Software without restriction, including without limitation the rights
22+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
23+
# copies of the Software, and to permit persons to whom the Software is
24+
# furnished to do so, subject to the following conditions:
25+
#
26+
# The above copyright notice and this permission notice shall be included in all
27+
# copies or substantial portions of the Software.
28+
#
29+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
30+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
31+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
32+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
33+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
34+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
35+
# SOFTWARE.
36+
37+
import numpy as np
38+
import pytensor
39+
import pytensor.tensor as pt
40+
import pytest
41+
42+
from pytensor.raise_op import Assert
43+
from scipy import stats
44+
45+
from pymc.distributions import Dirichlet
46+
from pymc.logprob.joint_logprob import factorized_joint_logprob
47+
from tests.distributions.test_multivariate import dirichlet_logpdf
48+
49+
50+
def test_specify_shape_logprob():
51+
# 1. Create graph using SpecifyShape
52+
# Use symbolic last dimension, so that SpecifyShape is not useless
53+
last_dim = pt.scalar(name="last_dim", dtype="int64")
54+
x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim))
55+
x_base.name = "x"
56+
x_rv = pt.specify_shape(x_base, shape=(5, 3))
57+
x_rv.name = "x"
58+
59+
# 2. Request logp
60+
x_vv = x_rv.clone()
61+
[x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()
62+
63+
# 3. Test logp
64+
x_logp_fn = pytensor.function([last_dim, x_vv], x_logp)
65+
66+
# 3.1 Test valid logp
67+
x_vv_test = stats.dirichlet(np.ones((3,))).rvs(size=(5,))
68+
np.testing.assert_array_almost_equal(
69+
x_logp_fn(last_dim=3, x=x_vv_test),
70+
dirichlet_logpdf(x_vv_test, np.ones((3,))),
71+
)
72+
73+
# 3.2 Test shape error
74+
x_vv_test_invalid = stats.dirichlet(np.ones((1,))).rvs(size=(5,))
75+
with pytest.raises(TypeError, match=re.escape("not compatible with the data's ((5, 1))")):
76+
x_logp_fn(last_dim=1, x=x_vv_test_invalid)
77+
78+
79+
def test_assert_logprob():
80+
rv = pt.random.normal()
81+
assert_op = Assert("Test assert")
82+
# Example: Add assert that rv must be positive
83+
assert_rv = assert_op(rv > 0, rv)
84+
assert_rv.name = "assert_rv"
85+
86+
assert_vv = assert_rv.clone()
87+
assert_logp = factorized_joint_logprob({assert_rv: assert_vv})[assert_vv]
88+
89+
# Check valid value is correct and doesn't raise
90+
# Since here the value to the rv satisfies the condition, no error is raised.
91+
valid_value = 3.0
92+
with pytest.raises(AssertionError, match="Test assert"):
93+
assert_logp.eval({assert_vv: valid_value})
94+
95+
# Check invalid value
96+
# Since here the value to the rv is negative, an exception is raised as the condition is not met
97+
with pytest.raises(AssertionError, match="Test assert"):
98+
assert_logp.eval({assert_vv: -5.0})

0 commit comments

Comments
 (0)