Skip to content

Commit 1249c86

Browse files
committed
Derive matmul probability
1 parent c3c1d94 commit 1249c86

File tree

6 files changed

+204
-2
lines changed

6 files changed

+204
-2
lines changed

.github/workflows/tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ jobs:
116116
tests/logprob/test_censoring.py
117117
tests/logprob/test_composite_logprob.py
118118
tests/logprob/test_cumsum.py
119+
tests/logprob/test_linalg.py
119120
tests/logprob/test_mixture.py
120121
tests/logprob/test_order.py
121122
tests/logprob/test_rewriting.py

pymc/logprob/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
import pymc.logprob.censoring
5050
import pymc.logprob.cumsum
5151
import pymc.logprob.checks
52+
import pymc.logprob.linalg
5253
import pymc.logprob.mixture
5354
import pymc.logprob.order
5455
import pymc.logprob.scan

pymc/logprob/abstract.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from pytensor.graph import Apply, Op, Variable
4444
from pytensor.graph.utils import MetaType
4545
from pytensor.tensor import TensorVariable
46+
from pytensor.tensor.blockwise import Blockwise
4647
from pytensor.tensor.elemwise import Elemwise
4748
from pytensor.tensor.random.op import RandomVariable
4849

@@ -168,6 +169,10 @@ def __str__(self):
168169
return f"Measurable{super().__str__()}"
169170

170171

172+
class MeasurableBlockwise(MeasurableOp, Blockwise):
    """Base class for `Blockwise` operations that are also `MeasurableOp`\s."""
174+
175+
171176
class ValuedRV(Op):
172177
r"""Represents the association of a measurable variable and its value.
173178

pymc/logprob/linalg.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pytensor.tensor as pt
15+
16+
from pytensor.graph.rewriting.basic import node_rewriter
17+
from pytensor.tensor.math import _matrix_matrix_matmul
18+
19+
from pymc.logprob.abstract import MeasurableBlockwise, MeasurableOp, _logprob, _logprob_helper
20+
from pymc.logprob.rewriting import measurable_ir_rewrites_db
21+
from pymc.logprob.utils import check_potential_measurability, filter_measurable_variables
22+
23+
24+
class MeasurableMatMul(MeasurableBlockwise):
    """Measurable matrix multiplication.

    The op remembers which of its two operands is the measurable one through
    ``right_measurable``: ``True`` for the right operand, ``False`` for the left.
    """

    # Populated from the ``measurable_right`` constructor argument.
    right_measurable: bool

    def __init__(self, measurable_right: bool, **kwargs):
        # Every other keyword argument is forwarded unchanged to the parent constructor.
        self.right_measurable = measurable_right
        super().__init__(**kwargs)
32+
33+
34+
@_logprob.register(MeasurableMatMul)
def logprob_measurable_matmul(op, values, l, r):  # noqa: E741
    """Log-probability of ``y = l @ r`` when one operand is measurable.

    The non-measurable operand ``A`` is treated as an invertible linear map.
    The value is pulled back through it (``x = A^{-1} y`` or ``x = y A^{-1}``),
    the base log-probability is evaluated there, and the log absolute Jacobian
    determinant of the map is subtracted.

    Parameters
    ----------
    op
        The ``MeasurableMatMul`` op; ``op.right_measurable`` says which operand
        is the measurable one.
    values
        Sequence holding the single value variable for the matmul output.
    l, r
        Left and right operands of the matrix product.

    Returns
    -------
    The log-probability with the two core (matrix) dimensions reduced.
    """
    [y_value] = values
    if op.right_measurable:
        # y = A @ x  ->  x = A^{-1} @ y
        A, x = l, r
        x_value = pt.linalg.solve(A, y_value)
        # A acts separately on every column of x, so the Jacobian of the full
        # map is |det A| raised to the number of columns.
        n_transformed = x_value.shape[-1]
    else:
        # y = x @ A  ->  x^T = A^{-T} @ y^T
        x, A = l, r
        x_value = pt.linalg.solve(A.mT, y_value.mT).mT
        # A acts separately on every row of x, so the Jacobian of the full
        # map is |det A| raised to the number of rows.
        n_transformed = x_value.shape[-2]

    x_logp = _logprob_helper(x, x_value)

    # The operation has a support dimensionality of 2
    # We need to reduce it if it's still present in the base logp
    if x_logp.type.ndim == x_value.type.ndim:
        x_logp = pt.sum(x_logp, axis=(-1, -2))
    elif x_logp.type.ndim == x_value.type.ndim - 1:
        x_logp = pt.sum(x_logp, axis=-1)

    _, log_abs_det_A = pt.linalg.slogdet(A)
    # Cast the (integer) count so the result keeps the dtype of the log-determinant.
    log_abs_jac_det = log_abs_det_A * pt.cast(n_transformed, log_abs_det_A.dtype)

    return x_logp - log_abs_jac_det
56+
57+
58+
@node_rewriter(tracks=[_matrix_matrix_matmul])
def find_measurable_matmul(fgraph, node):
    """Replace a matrix-matrix product by ``MeasurableMatMul`` when it qualifies.

    Returns ``None`` (no replacement) unless exactly one operand is measurable,
    that operand is not broadcast along batch dimensions, the other operand is
    not statically known to be non-square, and the other operand is not
    potentially measurable. Otherwise returns a one-element list with the
    output of the new measurable op.
    """
    # Nodes that are already measurable need no conversion
    if isinstance(node.op, MeasurableOp):
        return None

    (output,) = node.outputs
    left, right = node.inputs

    # Check that exactly one of l and r is measurable
    candidates = filter_measurable_variables([left, right])
    if len(candidates) != 1:
        return None
    measurable_input = candidates[0]

    # Check the measurable input is not broadcasted along the batch dimensions
    input_batch_bcast = measurable_input.type.broadcastable[:-2]
    output_batch_bcast = output.type.broadcastable[:-2]
    if input_batch_bcast != output_batch_bcast:
        return None

    measurable_right = measurable_input is right
    A = left if measurable_right else right

    # Bail out when the static shape already reveals a non-square matrix
    n_rows, n_cols = A.type.shape[-2], A.type.shape[-1]
    if n_cols is not None and n_rows is not None and n_cols != n_rows:
        return None

    # Check the other input is not potentially measurable
    if check_potential_measurability([A]):
        return None

    # Rebuild the op with the same props, tagged with the measurable side
    new_op = MeasurableMatMul(measurable_right=measurable_right, **node.op._props_dict())
    return [new_op(left, right)]
95+
96+
97+
# Register the rewrite under its own function name with the "basic" and "linalg" tags.
measurable_ir_rewrites_db.register(
    find_measurable_matmul.__name__, find_measurable_matmul, "basic", "linalg"
)

pymc/logprob/rewriting.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,11 @@ def remove_DiracDelta(fgraph, node):
152152
logprob_rewrites_db.register(
153153
"local_exp_over_1_plus_exp", out2in(local_exp_over_1_plus_exp), "basic"
154154
)
155-
logprob_rewrites_db.register("pre-canonicalize", optdb.query("+canonicalize"), "basic")
155+
# NOTE(review): "local_eager_useless_unbatched_blockwise" is excluded from the
# canonicalize query, presumably so unbatched Blockwise nodes are not stripped
# before the measurable rewrites can match them -- confirm against PyTensor.
logprob_rewrites_db.register(
    "pre-canonicalize",
    optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"),
    "basic",
)
156160

157161
# These rewrites convert un-measurable variables into their measurable forms,
158162
# but they need to be reapplied, because some of the measurable forms require
@@ -175,7 +179,11 @@ def remove_DiracDelta(fgraph, node):
175179
)
176180

177181

178-
logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
182+
# NOTE(review): "local_eager_useless_unbatched_blockwise" is excluded from the
# canonicalize query, presumably so unbatched Blockwise nodes are not stripped
# while measurable rewrites are still being applied -- confirm against PyTensor.
logprob_rewrites_db.register(
    "post-canonicalize",
    optdb.query("+canonicalize", "-local_eager_useless_unbatched_blockwise"),
    "basic",
)
179187

180188
# Rewrites that remove IR Ops
181189
cleanup_ir_rewrites_db = LocalGroupDB()

tests/logprob/test_linalg.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright 2024 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import numpy as np
15+
import pytest
16+
17+
from pytensor.tensor.type import tensor
18+
19+
from pymc.distributions import MatrixNormal, MvNormal, Normal
20+
from pymc.logprob.basic import logp
21+
22+
23+
@pytest.mark.parametrize("univariate", [True, False])
@pytest.mark.parametrize("batch_shape", [(), (3,)])
def test_matrix_vector_transform(univariate, batch_shape):
    """The derived logp of ``c + B @ x`` must match the closed-form MvNormal logp."""

    def matvec(matrix, vector):
        # Matrix-vector product written as a matmul against a trailing unit dimension
        return (matrix @ vector[..., None]).squeeze(-1)

    rng = np.random.default_rng(755)
    vector_shape = (*batch_shape, 2)
    matrix_shape = (*batch_shape, 2, 2)

    # The order of the random draws below fixes the test data; keep it as is.
    mean = rng.normal(size=vector_shape)
    if univariate:
        scale = np.abs(rng.normal(size=vector_shape))
        cov = np.eye(2) * (scale**2)[..., None]
        x = Normal.dist(mu=mean, sigma=scale)
    else:
        root = rng.normal(size=matrix_shape)
        cov = np.swapaxes(root, -1, -2) @ root
        x = MvNormal.dist(mu=mean, cov=cov)

    c = rng.normal(size=vector_shape)
    B = rng.normal(size=matrix_shape)
    y = c + matvec(B, x)

    # An affine transformed MvNormal is still a MvNormal
    # https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Affine_transformation
    expected_mean = c + matvec(B, mean)
    expected_cov = B @ cov @ np.swapaxes(B, -1, -2)
    ref_dist = MvNormal.dist(mu=expected_mean, cov=expected_cov)

    test_y = rng.normal(size=vector_shape)
    derived_logp = logp(y, test_y).eval()
    reference_logp = logp(ref_dist, test_y).eval()
    np.testing.assert_allclose(derived_logp, reference_logp)
52+
53+
54+
def test_matrix_matrix_transform():
    """The derived logp of ``D @ X @ C`` must match the closed-form MatrixNormal logp."""
    rng = np.random.default_rng(46)
    n, p = 2, 3

    # The order of the random draws below fixes the test data; keep it as is.
    mean = rng.normal(size=(n, p))
    row_root = rng.normal(size=(n, n)) * 0.1
    row_cov = row_root.T @ row_root
    col_root = rng.normal(size=(p, p)) * 0.1
    col_cov = col_root.T @ col_root
    X = MatrixNormal.dist(mu=mean, rowcov=row_cov, colcov=col_cov)

    D = rng.normal(size=(n, n))
    C = rng.normal(size=(p, p))
    Y = D @ X @ C

    # A linearly transformed MatrixNormal is still a MatrixNormal
    # https://en.wikipedia.org/wiki/Matrix_normal_distribution#Transformation
    ref_dist = MatrixNormal.dist(
        mu=D @ mean @ C,
        rowcov=D @ row_cov @ D.T,
        colcov=C.T @ col_cov @ C,
    )

    test_Y = rng.normal(size=(n, p))
    derived_logp = logp(Y, test_Y).eval()
    reference_logp = logp(ref_dist, test_Y).eval()
    np.testing.assert_allclose(derived_logp, reference_logp, rtol=1e-5)
78+
79+
80+
def test_broadcasted_matmul_fails():
    """No logp is derived when the measurable operand is broadcast by the matmul."""
    # x has no batch dimension while A carries a batch of 4, so x is broadcast in A @ x
    x = Normal.dist(size=(3, 2))
    A = tensor("A", shape=(4, 3, 3))
    broadcasted_product = A @ x
    with pytest.raises(NotImplementedError):
        value_variable = broadcasted_product.type()
        logp(broadcasted_product, value_variable)

0 commit comments

Comments
 (0)