6
6
from pymc import intX
7
7
from pymc .distributions .dist_math import check_parameters
8
8
from pymc .distributions .distribution import Continuous , SymbolicRandomVariable
9
- from pymc .distributions .multivariate import MvNormal
10
9
from pymc .distributions .shape_utils import get_support_shape_1d
11
10
from pymc .logprob .abstract import _logprob
12
11
from pytensor .graph .basic import Node
13
- from pytensor .tensor .random .basic import MvNormalRV
14
12
15
13
floatX = pytensor .config .floatX
16
14
COV_ZERO_TOL = 0
@@ -49,23 +47,6 @@ def make_signature(sequence_names):
49
47
return f"{ signature } ,[rng]->[rng],({ time } ,{ state_and_obs } )"
50
48
51
49
52
- class MvNormalSVDRV (MvNormalRV ):
53
- name = "multivariate_normal"
54
- signature = "(n),(n,n)->(n)"
55
- dtype = "floatX"
56
- _print_name = ("MultivariateNormal" , "\\ operatorname{MultivariateNormal}" )
57
-
58
-
59
- class MvNormalSVD (MvNormal ):
60
- """Dummy distribution intended to be rewritten into a JAX multivariate_normal with method="svd".
61
-
62
- A JAX MvNormal robust to low-rank covariance matrices
63
- """
64
-
65
- # TODO: Remove this entirely on next PyMC release; method will be exposed directly in MvNormal
66
- rv_op = MvNormalSVDRV (method = "svd" )
67
-
68
-
69
50
class LinearGaussianStateSpaceRV (SymbolicRandomVariable ):
70
51
default_output = 1
71
52
_print_name = ("LinearGuassianStateSpace" , "\\ operatorname{LinearGuassianStateSpace}" )
@@ -223,8 +204,12 @@ def step_fn(*args):
223
204
k = T .shape [0 ]
224
205
a = state [:k ]
225
206
226
- middle_rng , a_innovation = MvNormalSVD .dist (mu = 0 , cov = Q , rng = rng ).owner .outputs
227
- next_rng , y_innovation = MvNormalSVD .dist (mu = 0 , cov = H , rng = middle_rng ).owner .outputs
207
+ middle_rng , a_innovation = pm .MvNormal .dist (
208
+ mu = 0 , cov = Q , rng = rng , method = "svd"
209
+ ).owner .outputs
210
+ next_rng , y_innovation = pm .MvNormal .dist (
211
+ mu = 0 , cov = H , rng = middle_rng , method = "svd"
212
+ ).owner .outputs
228
213
229
214
a_mu = c + T @ a
230
215
a_next = a_mu + R @ a_innovation
@@ -239,8 +224,8 @@ def step_fn(*args):
239
224
Z_init = Z_ if Z_ in non_sequences else Z_ [0 ]
240
225
H_init = H_ if H_ in non_sequences else H_ [0 ]
241
226
242
- init_x_ = MvNormalSVD . dist (a0_ , P0_ , rng = rng )
243
- init_y_ = MvNormalSVD . dist (Z_init @ init_x_ , H_init , rng = rng )
227
+ init_x_ = pm . MvNormal . dist (a0_ , P0_ , rng = rng , method = "svd" )
228
+ init_y_ = pm . MvNormal . dist (Z_init @ init_x_ , H_init , rng = rng , method = "svd" )
244
229
245
230
init_dist_ = pt .concatenate ([init_x_ , init_y_ ], axis = 0 )
246
231
@@ -400,7 +385,7 @@ def rv_op(cls, mus, covs, logp, size=None):
400
385
rng = pytensor .shared (np .random .default_rng ())
401
386
402
387
def step (mu , cov , rng ):
403
- new_rng , mvn = MvNormalSVD . dist (mu = mu , cov = cov , rng = rng ).owner .outputs
388
+ new_rng , mvn = pm . MvNormal . dist (mu = mu , cov = cov , rng = rng , method = "svd" ).owner .outputs
404
389
return mvn , {rng : new_rng }
405
390
406
391
mvn_seq , updates = pytensor .scan (
0 commit comments