
Commit 48e56c3

ricardoV94 committed
Implement specialized MvNormal density based on precision matrix
Co-authored-by: theorashid <[email protected]>
Co-authored-by: elizavetasemenova <[email protected]>
Co-authored-by: aseyboldt <[email protected]>
1 parent b407c01 commit 48e56c3
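
For context (a sketch of the math this commit exploits, not part of the diff itself): when a multivariate normal is parameterized directly by its precision matrix $Q = \Sigma^{-1}$, the log-density can be evaluated without forming or inverting $\Sigma$:

$$
\log p(x \mid \mu, Q) = -\tfrac{k}{2}\log(2\pi) + \tfrac{1}{2}\log\lvert Q \rvert - \tfrac{1}{2}(x - \mu)^{\top} Q (x - \mu)
$$

With a Cholesky factorization $Q = LL^{\top}$, the term $\tfrac{1}{2}\log\lvert Q \rvert$ reduces to $\sum_i \log L_{ii}$, which is what the new `_logdet_from_cholesky` helper computes. The diff below adds a `PrecisionMvNormalRV` whose logp uses this form, plus a graph rewrite that substitutes it whenever an `MvNormal` covariance is given as `inv(tau)`.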

File tree

3 files changed, +127 -14 lines changed

pymc/distributions/multivariate.py

Lines changed: 89 additions & 14 deletions
```diff
@@ -24,6 +24,7 @@
 import pytensor.tensor as pt
 import scipy
 
+from pytensor.graph import node_rewriter
 from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.op import Op
 from pytensor.raise_op import Assert
@@ -39,7 +40,7 @@
 from pytensor.tensor.exceptions import NotScalarConstantError
 from pytensor.tensor.linalg import cholesky, det, eigh, solve_triangular, trace
 from pytensor.tensor.linalg import inv as matrix_inverse
-from pytensor.tensor.random.basic import dirichlet, multinomial, multivariate_normal
+from pytensor.tensor.random.basic import MvNormalRV, dirichlet, multinomial, multivariate_normal
 from pytensor.tensor.random.op import RandomVariable
 from pytensor.tensor.random.utils import (
     broadcast_params,
@@ -77,6 +78,9 @@
 )
 from pymc.distributions.transforms import Interval, ZeroSumTransform, _default_transform
 from pymc.logprob.abstract import _logprob
+from pymc.logprob.rewriting import (
+    specialization_ir_rewrites_db,
+)
 from pymc.math import kron_diag, kron_dot
 from pymc.pytensorf import normalize_rng_param
 from pymc.util import check_dist_not_registered
@@ -157,6 +161,13 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
     return cov
 
 
+def _logdet_from_cholesky(chol: TensorVariable) -> tuple[TensorVariable, TensorVariable]:
+    diag = pt.diagonal(chol, axis1=-2, axis2=-1)
+    logdet = pt.log(diag).sum(axis=-1)
+    posdef = pt.all(diag > 0, axis=-1)
+    return logdet, posdef
+
+
 def quaddist_chol(value, mu, cov):
     """Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
     if value.ndim == 0:
@@ -167,23 +178,23 @@ def quaddist_chol(value, mu, cov):
     else:
         onedim = False
 
-    delta = value - mu
     chol_cov = nan_lower_cholesky(cov)
+    logdet, posdef = _logdet_from_cholesky(chol_cov)
 
-    diag = pt.diagonal(chol_cov, axis1=-2, axis2=-1)
-    # Check if the covariance matrix is positive definite.
-    ok = pt.all(diag > 0, axis=-1)
-    # If not, replace the diagonal. We return -inf later, but
-    # need to prevent solve_lower from throwing an exception.
-    chol_cov = pt.switch(ok[..., None, None], chol_cov, 1)
+    # solve_triangular will raise if there are nans
+    # (which happens if the cholesky fails)
+    chol_cov.dprint(print_type=True, depth=1)
+    posdef.dprint(print_type=True, depth=1)
+    chol_cov = pt.switch(posdef[..., None, None], chol_cov, 1)
+
+    delta = value - mu
     delta_trans = solve_lower(chol_cov, delta, b_ndim=1)
     quaddist = (delta_trans**2).sum(axis=-1)
-    logdet = pt.log(diag).sum(axis=-1)
 
     if onedim:
-        return quaddist[0], logdet, ok
+        return quaddist[0], logdet, posdef
     else:
-        return quaddist, logdet, ok
+        return quaddist, logdet, posdef
 
 
 class MvNormal(Continuous):
@@ -283,16 +294,80 @@ def logp(value, mu, cov):
         -------
         TensorVariable
         """
-        quaddist, logdet, ok = quaddist_chol(value, mu, cov)
+        quaddist, logdet, posdef = quaddist_chol(value, mu, cov)
         k = value.shape[-1].astype("floatX")
         norm = -0.5 * k * np.log(2 * np.pi)
         return check_parameters(
             norm - 0.5 * quaddist - logdet,
-            ok,
-            msg="posdef",
+            posdef,
+            msg="posdef covariance",
         )
 
 
+class PrecisionMvNormalRV(SymbolicRandomVariable):
+    r"""A specialized multivariate normal random variable defined in terms of precision.
+
+    This class is introduced during specialization logprob rewrites, and not meant to be used directly.
+    """
+
+    name = "precision_multivariate_normal"
+    extended_signature = "[rng],[size],(n),(n,n)->(n)"
+    _print_name = ("PrecisionMultivariateNormal", "\\operatorname{PrecisionMultivariateNormal}")
+
+    @classmethod
+    def rv_op(cls, mean, tau, *, rng=None, size=None):
+        rng = normalize_rng_param(rng)
+        size = normalize_size_param(size)
+        cov = pt.linalg.inv(tau)
+        next_rng, draws = multivariate_normal(mean, cov, size=size, rng=rng).owner.outputs
+        return cls(
+            inputs=[rng, size, mean, tau],
+            outputs=[next_rng, draws],
+        )(rng, size, mean, tau)
+
+
+@_logprob.register
+def precision_mv_normal_logp(op: PrecisionMvNormalRV, value, rng, size, mean, tau, **kwargs):
+    [value] = value
+    k = value.shape[-1].astype("floatX")
+
+    delta = value - mean
+    quadratic_form = delta.T @ tau @ delta
+    logdet, posdef = _logdet_from_cholesky(nan_lower_cholesky(tau))
+    logp = -0.5 * (k * pt.log(2 * np.pi) + quadratic_form) + logdet
+
+    return check_parameters(
+        logp,
+        posdef,
+        msg="posdef precision",
+    )
+
+
+@node_rewriter(tracks=[MvNormalRV])
+def mv_normal_to_precision_mv_normal(fgraph, node):
+    """Replaces MvNormal(mu, inv(tau)) -> PrecisionMvNormal(mu, tau)
+
+    This is introduced in logprob rewrites to provide a more efficient logp for a MvNormal
+    that is defined by a precision matrix.
+
+    Note: This won't be introduced when calling `pm.logp` as that will dispatch directly
+    without triggering the logprob rewrites.
+    """
+
+    rng, size, mu, cov = node.inputs
+    if cov.owner and cov.owner.op == matrix_inverse:
+        tau = cov.owner.inputs[0]
+        return PrecisionMvNormalRV.rv_op(mu, tau, size=size, rng=rng).owner.outputs
+    return None
+
+
+specialization_ir_rewrites_db.register(
+    mv_normal_to_precision_mv_normal.__name__,
+    mv_normal_to_precision_mv_normal,
+    "basic",
+)
+
+
 class MvStudentTRV(RandomVariable):
     name = "multivariate_studentt"
     signature = "(),(n),(n,n)->(n)"
```
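
As a quick independent sanity check of the precision-form density used in `precision_mv_normal_logp` above, here is a plain NumPy/SciPy sketch (sizes, seed, and variable names are arbitrary and not part of the commit):

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(123)
n = 4

# Random positive-definite precision matrix Q and its implied covariance.
A = rng.normal(size=(n, n))
Q = A @ A.T + n * np.eye(n)
Sigma = np.linalg.inv(Q)

mu = rng.normal(size=n)
x = rng.normal(size=n)

# Precision form: -k/2 log(2*pi) + 1/2 log|Q| - 1/2 (x - mu)' Q (x - mu),
# with 1/2 log|Q| read off the Cholesky diagonal, as _logdet_from_cholesky does.
L = np.linalg.cholesky(Q)
half_logdet_Q = np.sum(np.log(np.diag(L)))
delta = x - mu
logp_precision = -0.5 * (n * np.log(2 * np.pi) + delta @ Q @ delta) + half_logdet_Q

# Reference value from the usual covariance parameterization.
logp_reference = stats.multivariate_normal.logpdf(x, mean=mu, cov=Sigma)

np.testing.assert_allclose(logp_precision, logp_reference)
```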

pymc/logprob/rewriting.py

Lines changed: 9 additions & 0 deletions
```diff
@@ -60,6 +60,7 @@
     out2in,
 )
 from pytensor.graph.rewriting.db import (
+    EquilibriumDB,
     LocalGroupDB,
     RewriteDatabase,
     RewriteDatabaseQuery,
@@ -379,6 +380,14 @@ def incsubtensor_rv_replace(fgraph, node):
 measurable_ir_rewrites_db.register("subtensor_lift", local_subtensor_rv_lift, "basic")
 measurable_ir_rewrites_db.register("incsubtensor_lift", incsubtensor_rv_replace, "basic")
 
+# These rewrites are used to introduce specialized operations with better logprob graphs
+specialization_ir_rewrites_db = EquilibriumDB()
+specialization_ir_rewrites_db.name = "specialization_ir_rewrites_db"
+logprob_rewrites_db.register(
+    "specialization_ir_rewrites_db", specialization_ir_rewrites_db, "basic"
+)
+
+
 logprob_rewrites_db.register("post-canonicalize", optdb.query("+canonicalize"), "basic")
 
 # Rewrites that remove IR Ops
```
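
For readers wanting to hook into the database added here, a minimal sketch of the registration pattern (the rewrite name and its body are hypothetical; the real example is `mv_normal_to_precision_mv_normal` in the diff above):

```python
from pytensor.graph import node_rewriter
from pytensor.tensor.random.basic import MvNormalRV

from pymc.logprob.rewriting import specialization_ir_rewrites_db


@node_rewriter(tracks=[MvNormalRV])
def my_mv_normal_specialization(fgraph, node):
    # Hypothetical rewrite: inspect the node and return replacement outputs,
    # or None to leave the graph unchanged.
    rng, size, mu, cov = node.inputs
    return None


# Rewrites registered here run during the logprob IR specialization phase
# (e.g. when a model's logp is compiled); pm.logp dispatches directly and
# skips them, as noted in the docstring above.
specialization_ir_rewrites_db.register(
    my_mv_normal_specialization.__name__,
    my_mv_normal_specialization,
    "basic",
)
```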

tests/distributions/test_multivariate.py

Lines changed: 29 additions & 0 deletions
```diff
@@ -26,11 +26,13 @@
 from pytensor import tensor as pt
 from pytensor.tensor import TensorVariable
 from pytensor.tensor.blockwise import Blockwise
+from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.random.utils import broadcast_params
 from pytensor.tensor.slinalg import Cholesky
 
 import pymc as pm
 
+from pymc import Model
 from pymc.distributions.multivariate import (
     MultivariateIntervalTransform,
     _LKJCholeskyCov,
@@ -2468,3 +2470,30 @@ def test_mvstudentt_mu_convenience():
     x = pm.MvStudentT.dist(nu=4, mu=np.ones((10, 1, 1)), scale=np.full((2, 3, 3), np.eye(3)))
     mu = x.owner.inputs[3]
     np.testing.assert_allclose(mu.eval(), np.ones((10, 2, 3)))
+
+
+def test_precision_mv_normal_optimization():
+    rng = np.random.default_rng(sum(map(ord, "be precise")))
+
+    n = 30
+    L = rng.uniform(low=0.1, high=1.0, size=(n, n))
+    Sigma_test = L @ L.T
+    mu_test = np.zeros(n)
+    Q_test = np.linalg.inv(Sigma_test)
+    y_test = rng.normal(size=n)
+
+    with Model() as m:
+        Q = pm.Flat("Q", shape=(n, n))
+        y = pm.MvNormal("y", mu=mu_test, tau=Q)
+
+    y_logp_fn = m.compile_logp(vars=[y]).f
+
+    # Check we don't have any MatrixInverses in the logp
+    assert not any(
+        node for node in y_logp_fn.maker.fgraph.apply_nodes if isinstance(node.op, MatrixInverse)
+    )
+
+    np.testing.assert_allclose(
+        y_logp_fn(y=y_test, Q=Q_test),
+        st.multivariate_normal.logpdf(y_test, mu_test, cov=Sigma_test),
+    )
```
