@@ -1880,35 +1880,30 @@ def logp(value, mu, rowchol, colchol):
1880
1880
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
1881
1881
1882
1882
1883
- class KroneckerNormalRV (RandomVariable ):
1884
- name = "kroneckernormal"
1883
+ class KroneckerNormalRV (SymbolicRandomVariable ):
1885
1884
ndim_supp = 1
1886
- ndims_params = [1 , 0 , 2 ]
1887
- dtype = "floatX"
1888
1885
_print_name = ("KroneckerNormal" , "\\ operatorname{KroneckerNormal}" )
1889
1886
1890
- def _supp_shape_from_params (self , dist_params , param_shapes = None ):
1891
- return supp_shape_from_ref_param_shape (
1892
- ndim_supp = self .ndim_supp ,
1893
- dist_params = dist_params ,
1894
- param_shapes = param_shapes ,
1895
- ref_param_idx = 0 ,
1896
- )
1897
-
1898
- def rng_fn (self , rng , mu , sigma , * covs , size = None ):
1899
- size = size if size else covs [- 1 ]
1900
- covs = covs [:- 1 ] if covs [- 1 ] == size else covs
1901
-
1902
- cov = reduce (scipy .linalg .kron , covs )
1903
-
1904
- if sigma :
1905
- cov = cov + sigma ** 2 * np .eye (cov .shape [0 ])
1887
+ @classmethod
1888
+ def rv_op (cls , mu , sigma , * covs , size = None , rng = None ):
1889
+ mu = pt .as_tensor (mu )
1890
+ sigma = pt .as_tensor (sigma )
1891
+ covs = [pt .as_tensor (cov ) for cov in covs ]
1892
+ rng = normalize_rng_param (rng )
1893
+ size = normalize_size_param (size )
1906
1894
1907
- x = multivariate_normal .rng_fn (rng = rng , mean = mu , cov = cov , size = size )
1908
- return x
1895
+ cov = reduce (pt .linalg .kron , covs )
1896
+ cov = cov + sigma ** 2 * pt .eye (cov .shape [- 2 ])
1897
+ next_rng , draws = multivariate_normal (mean = mu , cov = cov , size = size , rng = rng ).owner .outputs
1909
1898
1899
+ covs_sig = "," .join (f"(a{ i } ,b{ i } )" for i in range (len (covs )))
1900
+ signature = f"[rng],[size],(m),(),{ covs_sig } ->[rng],(m)"
1910
1901
1911
- kroneckernormal = KroneckerNormalRV ()
1902
+ return KroneckerNormalRV (
1903
+ inputs = [rng , size , mu , sigma , * covs ],
1904
+ outputs = [next_rng , draws ],
1905
+ signature = signature ,
1906
+ )(rng , size , mu , sigma , * covs )
1912
1907
1913
1908
1914
1909
class KroneckerNormal (Continuous ):
@@ -1999,7 +1994,8 @@ class KroneckerNormal(Continuous):
1999
1994
.. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
2000
1995
"""
2001
1996
2002
- rv_op = kroneckernormal
1997
+ rv_type = KroneckerNormalRV
1998
+ rv_op = KroneckerNormalRV .rv_op
2003
1999
2004
2000
@classmethod
2005
2001
def dist (cls , mu , covs = None , chols = None , evds = None , sigma = None , * args , ** kwargs ):
@@ -2024,14 +2020,10 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
2024
2020
2025
2021
return super ().dist ([mu , sigma , * covs ], ** kwargs )
2026
2022
2027
- def support_point (rv , size , mu , covs , chols , evds ):
2028
- mean = mu
2029
- if not rv_size_is_none (size ):
2030
- support_point_size = pt .concatenate ([size , mu .shape ])
2031
- mean = pt .full (support_point_size , mu )
2032
- return mean
2023
+ def support_point (rv , rng , size , mu , sigma , * covs ):
2024
+ return pt .full_like (rv , mu )
2033
2025
2034
- def logp (value , mu , sigma , * covs ):
2026
+ def logp (value , rng , size , mu , sigma , * covs ):
2035
2027
"""
2036
2028
Calculate log-probability of Multivariate Normal distribution
2037
2029
with Kronecker-structured covariance at specified value.
0 commit comments