Skip to content

Commit 2d4f4da

Browse files
committed
Make KroneckerNormal a SymbolicRV with a valid signature
1 parent c1b0f21 commit 2d4f4da

File tree

1 file changed

+23
-31
lines changed

1 file changed

+23
-31
lines changed

pymc/distributions/multivariate.py

Lines changed: 23 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1880,35 +1880,30 @@ def logp(value, mu, rowchol, colchol):
18801880
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet
18811881

18821882

1883-
class KroneckerNormalRV(RandomVariable):
1884-
name = "kroneckernormal"
1883+
class KroneckerNormalRV(SymbolicRandomVariable):
18851884
ndim_supp = 1
1886-
ndims_params = [1, 0, 2]
1887-
dtype = "floatX"
18881885
_print_name = ("KroneckerNormal", "\\operatorname{KroneckerNormal}")
18891886

1890-
def _supp_shape_from_params(self, dist_params, param_shapes=None):
1891-
return supp_shape_from_ref_param_shape(
1892-
ndim_supp=self.ndim_supp,
1893-
dist_params=dist_params,
1894-
param_shapes=param_shapes,
1895-
ref_param_idx=0,
1896-
)
1897-
1898-
def rng_fn(self, rng, mu, sigma, *covs, size=None):
1899-
size = size if size else covs[-1]
1900-
covs = covs[:-1] if covs[-1] == size else covs
1901-
1902-
cov = reduce(scipy.linalg.kron, covs)
1903-
1904-
if sigma:
1905-
cov = cov + sigma**2 * np.eye(cov.shape[0])
1887+
@classmethod
1888+
def rv_op(cls, mu, sigma, *covs, size=None, rng=None):
1889+
mu = pt.as_tensor(mu)
1890+
sigma = pt.as_tensor(sigma)
1891+
covs = [pt.as_tensor(cov) for cov in covs]
1892+
rng = normalize_rng_param(rng)
1893+
size = normalize_size_param(size)
19061894

1907-
x = multivariate_normal.rng_fn(rng=rng, mean=mu, cov=cov, size=size)
1908-
return x
1895+
cov = reduce(pt.linalg.kron, covs)
1896+
cov = cov + sigma**2 * pt.eye(cov.shape[-2])
1897+
next_rng, draws = multivariate_normal(mean=mu, cov=cov, size=size, rng=rng).owner.outputs
19091898

1899+
covs_sig = ",".join(f"(a{i},b{i})" for i in range(len(covs)))
1900+
signature = f"[rng],[size],(m),(),{covs_sig}->[rng],(m)"
19101901

1911-
kroneckernormal = KroneckerNormalRV()
1902+
return KroneckerNormalRV(
1903+
inputs=[rng, size, mu, sigma, *covs],
1904+
outputs=[next_rng, draws],
1905+
signature=signature,
1906+
)(rng, size, mu, sigma, *covs)
19121907

19131908

19141909
class KroneckerNormal(Continuous):
@@ -1999,7 +1994,8 @@ class KroneckerNormal(Continuous):
19991994
.. [1] Saatchi, Y. (2011). "Scalable inference for structured Gaussian process models"
20001995
"""
20011996

2002-
rv_op = kroneckernormal
1997+
rv_type = KroneckerNormalRV
1998+
rv_op = KroneckerNormalRV.rv_op
20031999

20042000
@classmethod
20052001
def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs):
@@ -2024,14 +2020,10 @@ def dist(cls, mu, covs=None, chols=None, evds=None, sigma=None, *args, **kwargs)
20242020

20252021
return super().dist([mu, sigma, *covs], **kwargs)
20262022

2027-
def support_point(rv, size, mu, covs, chols, evds):
2028-
mean = mu
2029-
if not rv_size_is_none(size):
2030-
support_point_size = pt.concatenate([size, mu.shape])
2031-
mean = pt.full(support_point_size, mu)
2032-
return mean
2023+
def support_point(rv, rng, size, mu, sigma, *covs):
2024+
return pt.full_like(rv, mu)
20332025

2034-
def logp(value, mu, sigma, *covs):
2026+
def logp(value, rng, size, mu, sigma, *covs):
20352027
"""
20362028
Calculate log-probability of Multivariate Normal distribution
20372029
with Kronecker-structured covariance at specified value.

0 commit comments

Comments
 (0)