Skip to content

Commit 9f8ea52

Browse files
committed
Derive probability for broadcast operation
1 parent 413af04 commit 9f8ea52

File tree

5 files changed

+174
-0
lines changed

5 files changed

+174
-0
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ jobs:
112112
tests/logprob/test_mixture.py
113113
tests/logprob/test_rewriting.py
114114
tests/logprob/test_scan.py
115+
tests/logprob/test_shape.py
115116
tests/logprob/test_tensor.py
116117
tests/logprob/test_transforms.py
117118
tests/logprob/test_utils.py

mypy.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ disallow_untyped_defs = False
1010
disallow_untyped_decorators = False
1111
ignore_missing_imports = True
1212
warn_unused_ignores = False
13+
disable_error_code = annotation-unchecked

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import pymc.logprob.checks
5151
import pymc.logprob.mixture
5252
import pymc.logprob.scan
53+
import pymc.logprob.shape
5354
import pymc.logprob.tensor
5455
import pymc.logprob.transforms
5556

pymc/logprob/shape.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Optional
15+
16+
import numpy as np
17+
import pytensor.tensor as pt
18+
19+
from pytensor.graph import node_rewriter
20+
from pytensor.tensor.extra_ops import BroadcastTo
21+
22+
from pymc.logprob.abstract import MeasurableVariable, _logprob, _logprob_helper
23+
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db
24+
25+
26+
class MeasurableBroadcast(BroadcastTo):
27+
pass
28+
29+
30+
MeasurableVariable.register(MeasurableBroadcast)
31+
32+
33+
measurable_broadcast = MeasurableBroadcast()
34+
35+
36+
@_logprob.register(MeasurableBroadcast)
37+
def broadcast_logprob(op, values, rv, *shape, **kwargs):
38+
"""Log-probability expression for (statically-)broadcasted RV
39+
40+
The probability is the same as the base RV, if no broadcasting had happened:
41+
42+
``logp(broadcast_to(normal(size=(3, 1)), (2, 3, 4)), zeros((2, 3, 4))) == logp(normal(size=(3, 1)), zeros((3, 1)))``
43+
44+
And zero if the value couldn't have possibly originated via broadcasting:
45+
46+
``logp(broadcast_to(normal(size=(1,)), (3,)), [1, 2, 3]) == [-np.inf]``
47+
48+
"""
49+
[value] = values
50+
51+
n_new_dims = len(shape) - rv.ndim
52+
assert n_new_dims >= 0
53+
54+
# Enumerate broadcasted dims
55+
expanded_dims = tuple(range(n_new_dims))
56+
broadcast_dims = tuple(
57+
i + n_new_dims
58+
for i, (v_bcast, rv_bcast) in enumerate(
59+
zip(value.broadcastable[n_new_dims:], rv.broadcastable)
60+
)
61+
if (not v_bcast) and rv_bcast
62+
)
63+
64+
# "Unbroadcast" value via indexing.
65+
# All entries in the broadcasted dimensions should be the same, so we simply select the first of each.
66+
indices = []
67+
for i in range(value.ndim):
68+
# Remove expanded dims
69+
if i in expanded_dims:
70+
indices.append(0)
71+
# Keep first entry of broadcasted (but not expanded) dims
72+
elif i in broadcast_dims:
73+
indices.append(slice(0, 1))
74+
else:
75+
indices.append(slice(None))
76+
77+
unbroadcast_value = value[tuple(indices)]
78+
logp = _logprob_helper(rv, unbroadcast_value)
79+
80+
# Check that dependent values were indeed identical, by comparing with a re-broadcasted value
81+
valid_value = pt.broadcast_to(unbroadcast_value, shape)
82+
# Note: This could fail due to float-precision issues.
83+
# If that proves to be a problem we should switch to `pt.allclose`
84+
check = pt.all(pt.eq(value, valid_value))
85+
logp = pt.switch(check, logp, -np.inf)
86+
87+
# Reintroduce expanded_dims in the returned logp
88+
if n_new_dims > 0:
89+
logp = pt.shape_padleft(logp, n_new_dims)
90+
91+
return logp
92+
93+
94+
@node_rewriter([BroadcastTo])
95+
def find_measurable_broadcast(fgraph, node):
96+
r"""Finds `BroadcastTo`\s for which a `logprob` can be computed."""
97+
98+
if isinstance(node.op, MeasurableBroadcast):
99+
return None # pragma: no cover
100+
101+
rv_map_feature: Optional[PreserveRVMappings] = getattr(fgraph, "preserve_rv_mappings", None)
102+
103+
if rv_map_feature is None:
104+
return None # pragma: no cover
105+
106+
base_rv, *shape = node.inputs
107+
108+
if not rv_map_feature.request_measurable([base_rv]):
109+
return None
110+
111+
new_rv = measurable_broadcast.make_node(base_rv, *shape).default_output()
112+
113+
return [new_rv]
114+
115+
116+
measurable_ir_rewrites_db.register(
117+
"find_measurable_broadcast",
118+
find_measurable_broadcast,
119+
"basic",
120+
"shape",
121+
)

tests/logprob/test_shape.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytensor
16+
import pytensor.tensor as pt
17+
import scipy.stats as st
18+
19+
from pymc import logp
20+
21+
22+
def test_measurable_broadcast():
23+
b_shape = pt.vector("b_shape", shape=(3,), dtype=int)
24+
25+
x = pt.random.normal(size=(3, 1))
26+
bcast_x = pt.broadcast_to(x, shape=b_shape)
27+
bcast_x.name = "bcast_x"
28+
29+
bcast_x_value = bcast_x.clone()
30+
logp_bcast_x = logp(bcast_x, bcast_x_value)
31+
logp_fn = pytensor.function([b_shape, bcast_x_value], logp_bcast_x, on_unused_input="ignore")
32+
33+
# assert_allclose also asserts shapes match (if neither is scalar)
34+
np.testing.assert_allclose(
35+
logp_fn([1, 3, 1], np.zeros((1, 3, 1))),
36+
st.norm.logpdf(np.zeros((1, 3, 1))),
37+
)
38+
np.testing.assert_allclose(
39+
logp_fn([1, 3, 5], np.zeros((1, 3, 5))),
40+
st.norm.logpdf(np.zeros((1, 3, 1))),
41+
)
42+
np.testing.assert_allclose(
43+
logp_fn([2, 3, 5], np.broadcast_to(np.arange(3).reshape(1, 3, 1), (2, 3, 5))),
44+
st.norm.logpdf(np.arange(3).reshape(1, 3, 1)),
45+
)
46+
# Invalid broadcast value
47+
np.testing.assert_array_equal(
48+
logp_fn([1, 3, 5], np.arange(3 * 5).reshape(1, 3, 5)),
49+
np.full(shape=(1, 3, 1), fill_value=-np.inf),
50+
)

0 commit comments

Comments
 (0)