24
24
import pytensor .tensor as pt
25
25
import scipy
26
26
27
+ from pytensor .graph import node_rewriter
27
28
from pytensor .graph .basic import Apply , Variable
28
29
from pytensor .graph .op import Op
29
30
from pytensor .raise_op import Assert
39
40
from pytensor .tensor .exceptions import NotScalarConstantError
40
41
from pytensor .tensor .linalg import cholesky , det , eigh , solve_triangular , trace
41
42
from pytensor .tensor .linalg import inv as matrix_inverse
42
- from pytensor .tensor .random .basic import dirichlet , multinomial , multivariate_normal
43
+ from pytensor .tensor .random .basic import MvNormalRV , dirichlet , multinomial , multivariate_normal
43
44
from pytensor .tensor .random .op import RandomVariable
44
45
from pytensor .tensor .random .utils import (
45
46
broadcast_params ,
77
78
)
78
79
from pymc .distributions .transforms import Interval , ZeroSumTransform , _default_transform
79
80
from pymc .logprob .abstract import _logprob
81
+ from pymc .logprob .rewriting import (
82
+ specialization_ir_rewrites_db ,
83
+ )
80
84
from pymc .math import kron_diag , kron_dot
81
85
from pymc .pytensorf import normalize_rng_param
82
86
from pymc .util import check_dist_not_registered
@@ -157,6 +161,13 @@ def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
157
161
return cov
158
162
159
163
164
+ def _logdet_from_cholesky (chol : TensorVariable ) -> tuple [TensorVariable , TensorVariable ]:
165
+ diag = pt .diagonal (chol , axis1 = - 2 , axis2 = - 1 )
166
+ logdet = pt .log (diag ).sum (axis = - 1 )
167
+ posdef = pt .all (diag > 0 , axis = - 1 )
168
+ return logdet , posdef
169
+
170
+
160
171
def quaddist_chol (value , mu , cov ):
161
172
"""Compute (x - mu).T @ Sigma^-1 @ (x - mu) and the logdet of Sigma."""
162
173
if value .ndim == 0 :
@@ -167,23 +178,23 @@ def quaddist_chol(value, mu, cov):
167
178
else :
168
179
onedim = False
169
180
170
- delta = value - mu
171
181
chol_cov = nan_lower_cholesky (cov )
182
+ logdet , posdef = _logdet_from_cholesky (chol_cov )
172
183
173
- diag = pt .diagonal (chol_cov , axis1 = - 2 , axis2 = - 1 )
174
- # Check if the covariance matrix is positive definite.
175
- ok = pt .all (diag > 0 , axis = - 1 )
176
- # If not, replace the diagonal. We return -inf later, but
177
- # need to prevent solve_lower from throwing an exception.
178
- chol_cov = pt .switch (ok [..., None , None ], chol_cov , 1 )
184
+ # solve_triangular will raise if there are nans
185
+ # (which happens if the cholesky fails)
186
+ chol_cov .dprint (print_type = True , depth = 1 )
187
+ posdef .dprint (print_type = True , depth = 1 )
188
+ chol_cov = pt .switch (posdef [..., None , None ], chol_cov , 1 )
189
+
190
+ delta = value - mu
179
191
delta_trans = solve_lower (chol_cov , delta , b_ndim = 1 )
180
192
quaddist = (delta_trans ** 2 ).sum (axis = - 1 )
181
- logdet = pt .log (diag ).sum (axis = - 1 )
182
193
183
194
if onedim :
184
- return quaddist [0 ], logdet , ok
195
+ return quaddist [0 ], logdet , posdef
185
196
else :
186
- return quaddist , logdet , ok
197
+ return quaddist , logdet , posdef
187
198
188
199
189
200
class MvNormal (Continuous ):
@@ -283,16 +294,80 @@ def logp(value, mu, cov):
283
294
-------
284
295
TensorVariable
285
296
"""
286
- quaddist , logdet , ok = quaddist_chol (value , mu , cov )
297
+ quaddist , logdet , posdef = quaddist_chol (value , mu , cov )
287
298
k = value .shape [- 1 ].astype ("floatX" )
288
299
norm = - 0.5 * k * np .log (2 * np .pi )
289
300
return check_parameters (
290
301
norm - 0.5 * quaddist - logdet ,
291
- ok ,
292
- msg = "posdef" ,
302
+ posdef ,
303
+ msg = "posdef covariance " ,
293
304
)
294
305
295
306
307
+ class PrecisionMvNormalRV (SymbolicRandomVariable ):
308
+ r"""A specialized multivariate normal random variable defined in terms of precision.
309
+
310
+ This class is introduced during specialization logprob rewrites, and not meant to be used directly.
311
+ """
312
+
313
+ name = "precision_multivariate_normal"
314
+ extended_signature = "[rng],[size],(n),(n,n)->(n)"
315
+ _print_name = ("PrecisionMultivariateNormal" , "\\ operatorname{PrecisionMultivariateNormal}" )
316
+
317
+ @classmethod
318
+ def rv_op (cls , mean , tau , * , rng = None , size = None ):
319
+ rng = normalize_rng_param (rng )
320
+ size = normalize_size_param (size )
321
+ cov = pt .linalg .inv (tau )
322
+ next_rng , draws = multivariate_normal (mean , cov , size = size , rng = rng ).owner .outputs
323
+ return cls (
324
+ inputs = [rng , size , mean , tau ],
325
+ outputs = [next_rng , draws ],
326
+ )(rng , size , mean , tau )
327
+
328
+
329
+ @_logprob .register
330
+ def precision_mv_normal_logp (op : PrecisionMvNormalRV , value , rng , size , mean , tau , ** kwargs ):
331
+ [value ] = value
332
+ k = value .shape [- 1 ].astype ("floatX" )
333
+
334
+ delta = value - mean
335
+ quadratic_form = delta .T @ tau @ delta
336
+ logdet , posdef = _logdet_from_cholesky (nan_lower_cholesky (tau ))
337
+ logp = - 0.5 * (k * pt .log (2 * np .pi ) + quadratic_form ) + logdet
338
+
339
+ return check_parameters (
340
+ logp ,
341
+ posdef ,
342
+ msg = "posdef precision" ,
343
+ )
344
+
345
+
346
+ @node_rewriter (tracks = [MvNormalRV ])
347
+ def mv_normal_to_precision_mv_normal (fgraph , node ):
348
+ """Replaces MvNormal(mu, inv(tau)) -> PrecisionMvNormal(mu, tau)
349
+
350
+ This is introduced in logprob rewrites to provide a more efficient logp for a MvNormal
351
+ that is defined by a precision matrix.
352
+
353
+ Note: This won't be introduced when calling `pm.logp` as that will dispatch directly
354
+ without triggering the logprob rewrites.
355
+ """
356
+
357
+ rng , size , mu , cov = node .inputs
358
+ if cov .owner and cov .owner .op == matrix_inverse :
359
+ tau = cov .owner .inputs [0 ]
360
+ return PrecisionMvNormalRV .rv_op (mu , tau , size = size , rng = rng ).owner .outputs
361
+ return None
362
+
363
+
364
+ specialization_ir_rewrites_db .register (
365
+ mv_normal_to_precision_mv_normal .__name__ ,
366
+ mv_normal_to_precision_mv_normal ,
367
+ "basic" ,
368
+ )
369
+
370
+
296
371
class MvStudentTRV (RandomVariable ):
297
372
name = "multivariate_studentt"
298
373
signature = "(),(n),(n,n)->(n)"
0 commit comments