Skip to content

Refactored Wishart and MatrixNormal distribution #4777

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 105 additions & 146 deletions pymc3/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from pymc3.distributions.continuous import ChiSquared, Normal, assert_negative_support
from pymc3.distributions.dist_math import bound, factln, logpow, multigammaln
from pymc3.distributions.distribution import Continuous, Discrete
from pymc3.distributions.shape_utils import broadcast_dist_samples_to, to_tuple
from pymc3.math import kron_diag, kron_dot

__all__ = [
Expand Down Expand Up @@ -739,6 +740,26 @@ def __str__(self):
matrix_pos_def = PosDefMatrix()


class WishartRV(RandomVariable):
    """RandomVariable op for the Wishart distribution.

    A matrix-valued (``ndim_supp = 2``) variable parameterized by the
    scalar degrees of freedom ``nu`` and the matrix-valued scale ``V``
    (``ndims_params = [0, 2]``).
    """

    name = "wishart"
    ndim_supp = 2
    ndims_params = [0, 2]
    dtype = "floatX"
    _print_name = ("Wishart", "\\operatorname{Wishart}")

    def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
        # The shape of the second parameter `V` defines the shape of the output.
        return dist_params[1].shape

    @classmethod
    def rng_fn(cls, rng, nu, V, size=None):
        """Draw Wishart samples via ``scipy.stats.wishart.rvs``.

        Parameters
        ----------
        rng : numpy.random.Generator
            Random state forwarded to SciPy.
        nu : scalar
            Degrees of freedom; coerced to a builtin ``int`` for SciPy.
        V : ndarray
            Scale matrix.
        size : int or tuple, optional
            Number of draws; defaults to 1, matching SciPy's default.
        """
        size = size if size else 1  # Default size for SciPy's wishart.rvs is 1
        # Use the builtin `int` here: `np.int` was deprecated in NumPy 1.20
        # and removed in NumPy 1.24, so `np.int(nu)` raises AttributeError.
        return stats.wishart.rvs(int(nu), V, size=size, random_state=rng)


wishart = WishartRV()


class Wishart(Continuous):
r"""
Wishart log-likelihood.
Expand Down Expand Up @@ -775,9 +796,13 @@ class Wishart(Continuous):
This distribution is unusable in a PyMC3 model. You should instead
use LKJCholeskyCov or LKJCorr.
"""
rv_op = wishart

@classmethod
def dist(cls, nu, V, *args, **kwargs):
nu = at.as_tensor_variable(intX(nu))
V = at.as_tensor_variable(floatX(V))

def __init__(self, nu, V, *args, **kwargs):
super().__init__(*args, **kwargs)
warnings.warn(
"The Wishart distribution can currently not be used "
"for MCMC sampling. The probability of sampling a "
Expand All @@ -787,34 +812,13 @@ def __init__(self, nu, V, *args, **kwargs):
"https://github.com/pymc-devs/pymc3/issues/538.",
UserWarning,
)
self.nu = nu = at.as_tensor_variable(nu)
self.p = p = at.as_tensor_variable(V.shape[0])
self.V = V = at.as_tensor_variable(V)
self.mean = nu * V
self.mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)

def random(self, point=None, size=None):
"""
Draw random values from Wishart distribution.
# mean = nu * V
# p = V.shape[0]
# mode = at.switch(at.ge(nu, p + 1), (nu - p - 1) * V, np.nan)
return super().dist([nu, V], *args, **kwargs)

Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
Returns
-------
array
"""
# nu, V = draw_values([self.nu, self.V], point=point, size=size)
# size = 1 if size is None else size
# return generate_samples(stats.wishart.rvs, nu.item(), V, broadcast_shape=(size,))

def logp(self, X):
def logp(X, nu, V):
"""
Calculate log-probability of Wishart distribution
at specified value.
Expand All @@ -828,9 +832,8 @@ def logp(self, X):
-------
TensorVariable
"""
nu = self.nu
p = self.p
V = self.V

p = V.shape[0]

IVI = det(V)
IXI = det(X)
Expand Down Expand Up @@ -1445,6 +1448,36 @@ def _distr_parameters_for_repr(self):
return ["eta", "n"]


class MatrixNormalRV(RandomVariable):
    """RandomVariable op for the matrix normal distribution.

    Matrix-valued (``ndim_supp = 2``) variable parameterized by three
    matrices (``ndims_params = [2, 2, 2]``): the mean ``mu`` and the
    Cholesky factors of the among-row and among-column covariances.
    """

    name = "matrixnormal"
    ndim_supp = 2
    ndims_params = [2, 2, 2]
    dtype = "floatX"
    _print_name = ("MatrixNormal", "\\operatorname{MatrixNormal}")

    @classmethod
    def rng_fn(cls, rng, mu, rowchol, colchol, size=None):
        """Sample ``mu + rowchol @ Z @ colchol.T`` with ``Z`` standard normal.

        The per-draw matrix shape is ``(rowchol.shape[0], colchol.shape[0])``;
        ``size`` prepends extra sample dimensions.
        """
        size = to_tuple(size)
        event_shape = to_tuple([rowchol.shape[0], colchol.shape[0]])
        output_shape = size + event_shape

        # Broadcast every parameter up to the full sample shape.
        (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
        rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
        # Transpose the column factor so the product applies colchol.T on the right.
        colchol = np.swapaxes(
            np.broadcast_to(colchol, shape=size + colchol.shape[-2:]), -1, -2
        )

        noise = rng.standard_normal(output_shape)
        return mu + np.matmul(rowchol, np.matmul(noise, colchol))


matrixnormal = MatrixNormalRV()


class MatrixNormal(Continuous):
r"""
Matrix-valued normal log-likelihood.
Expand Down Expand Up @@ -1533,175 +1566,101 @@ class MatrixNormal(Continuous):
vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov,
observed=data, shape=(m, n))
"""
rv_op = matrixnormal

def __init__(
self,
mu=0,
@classmethod
def dist(
cls,
mu,
rowcov=None,
rowchol=None,
rowtau=None,
colcov=None,
colchol=None,
coltau=None,
shape=None,
*args,
**kwargs,
):
self._setup_matrices(colcov, colchol, coltau, rowcov, rowchol, rowtau)
if shape is None:
raise TypeError("shape is a required argument")
assert len(shape) == 2, "shape must have length 2: mxn"
self.shape = shape
super().__init__(shape=shape, *args, **kwargs)
self.mu = at.as_tensor_variable(mu)
self.mean = self.median = self.mode = self.mu
self.solve_lower = solve_lower_triangular
self.solve_upper = solve_upper_triangular

def _setup_matrices(self, colcov, colchol, coltau, rowcov, rowchol, rowtau):

cholesky = Cholesky(lower=True, on_error="raise")

if mu.ndim == 1:
raise ValueError(
"1x1 Matrix was provided. Please use Normal distribution for such cases."
)

# Among-row matrices
if len([i for i in [rowtau, rowcov, rowchol] if i is not None]) != 1:
if len([i for i in [rowcov, rowchol] if i is not None]) != 1:
raise ValueError(
"Incompatible parameterization. "
"Specify exactly one of rowtau, rowcov, "
"or rowchol."
"Incompatible parameterization. Specify exactly one of rowcov, or rowchol."
)
if rowcov is not None:
self.m = rowcov.shape[0]
self._rowcov_type = "cov"
rowcov = at.as_tensor_variable(rowcov)
if rowcov.ndim != 2:
raise ValueError("rowcov must be two dimensional.")
self.rowchol_cov = cholesky(rowcov)
self.rowcov = rowcov
elif rowtau is not None:
raise ValueError("rowtau not supported at this time")
self.m = rowtau.shape[0]
self._rowcov_type = "tau"
rowtau = at.as_tensor_variable(rowtau)
if rowtau.ndim != 2:
raise ValueError("rowtau must be two dimensional.")
self.rowchol_tau = cholesky(rowtau)
self.rowtau = rowtau
rowchol_cov = cholesky(rowcov)
else:
self.m = rowchol.shape[0]
self._rowcov_type = "chol"
if rowchol.ndim != 2:
raise ValueError("rowchol must be two dimensional.")
self.rowchol_cov = at.as_tensor_variable(rowchol)
rowchol_cov = at.as_tensor_variable(rowchol)

# Among-column matrices
if len([i for i in [coltau, colcov, colchol] if i is not None]) != 1:
if len([i for i in [colcov, colchol] if i is not None]) != 1:
raise ValueError(
"Incompatible parameterization. "
"Specify exactly one of coltau, colcov, "
"or colchol."
"Incompatible parameterization. Specify exactly one of colcov, or colchol."
)
if colcov is not None:
self.n = colcov.shape[0]
self._colcov_type = "cov"
colcov = at.as_tensor_variable(colcov)
if colcov.ndim != 2:
raise ValueError("colcov must be two dimensional.")
self.colchol_cov = cholesky(colcov)
self.colcov = colcov
elif coltau is not None:
raise ValueError("coltau not supported at this time")
self.n = coltau.shape[0]
self._colcov_type = "tau"
coltau = at.as_tensor_variable(coltau)
if coltau.ndim != 2:
raise ValueError("coltau must be two dimensional.")
self.colchol_tau = cholesky(coltau)
self.coltau = coltau
colchol_cov = cholesky(colcov)
else:
self.n = colchol.shape[0]
self._colcov_type = "chol"
if colchol.ndim != 2:
raise ValueError("colchol must be two dimensional.")
self.colchol_cov = at.as_tensor_variable(colchol)
colchol_cov = at.as_tensor_variable(colchol)

def random(self, point=None, size=None):
mu = at.as_tensor_variable(floatX(mu))
# mean = median = mode = mu

return super().dist([mu, rowchol_cov, colchol_cov], **kwargs)

def logp(value, mu, rowchol, colchol):
"""
Draw random values from Matrix-valued Normal distribution.
Calculate log-probability of Matrix-valued Normal distribution
at specified value.
Parameters
----------
point: dict, optional
Dict of variable values on which random values are to be
conditioned (uses default point if not specified).
size: int, optional
Desired size of random sample (returns one sample if not
specified).
value: numeric
Value for which log-probability is calculated.
Returns
-------
array
TensorVariable
"""
# mu, colchol, rowchol = draw_values(
# [self.mu, self.colchol_cov, self.rowchol_cov], point=point, size=size
# )
# size = to_tuple(size)
# dist_shape = to_tuple(self.shape)
# output_shape = size + dist_shape
#
# # Broadcasting all parameters
# (mu,) = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)
# rowchol = np.broadcast_to(rowchol, shape=size + rowchol.shape[-2:])
#
# colchol = np.broadcast_to(colchol, shape=size + colchol.shape[-2:])
# colchol = np.swapaxes(colchol, -1, -2) # Take transpose
#
# standard_normal = np.random.standard_normal(output_shape)
# samples = mu + np.matmul(rowchol, np.matmul(standard_normal, colchol))
# return samples

def _trquaddist(self, value):
"""Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
the logdet of colcov and rowcov."""

delta = value - self.mu
rowchol_cov = self.rowchol_cov
colchol_cov = self.colchol_cov
# Compute Tr[colcov^-1 @ (x - mu).T @ rowcov^-1 @ (x - mu)] and
# the logdet of colcov and rowcov.
delta = value - mu

# Find exponent piece by piece
right_quaddist = self.solve_lower(rowchol_cov, delta)
right_quaddist = solve_lower_triangular(rowchol, delta)
quaddist = at.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
quaddist = self.solve_lower(colchol_cov, quaddist)
quaddist = self.solve_upper(colchol_cov.T, quaddist)
quaddist = solve_lower_triangular(colchol, quaddist)
quaddist = solve_upper_triangular(colchol.T, quaddist)
trquaddist = at.nlinalg.trace(quaddist)

coldiag = at.diag(colchol_cov)
rowdiag = at.diag(rowchol_cov)
coldiag = at.diag(colchol)
rowdiag = at.diag(rowchol)
half_collogdet = at.sum(at.log(coldiag)) # logdet(M) = 2*Tr(log(L))
half_rowlogdet = at.sum(at.log(rowdiag)) # Using Cholesky: M = L L^T
return trquaddist, half_collogdet, half_rowlogdet

def logp(self, value):
"""
Calculate log-probability of Matrix-valued Normal distribution
at specified value.

Parameters
----------
value: numeric
Value for which log-probability is calculated.
m = rowchol.shape[0]
n = colchol.shape[0]

Returns
-------
TensorVariable
"""
trquaddist, half_collogdet, half_rowlogdet = self._trquaddist(value)
m = self.m
n = self.n
norm = -0.5 * m * n * pm.floatX(np.log(2 * np.pi))
return norm - 0.5 * trquaddist - m * half_collogdet - n * half_rowlogdet

def _distr_parameters_for_repr(self):
mapping = {"tau": "tau", "cov": "cov", "chol": "chol_cov"}
return ["mu", "row" + mapping[self._rowcov_type], "col" + mapping[self._colcov_type]]
return ["mu"]


class KroneckerNormalRV(RandomVariable):
Expand Down
Loading