Skip to content

Commit c25a988

Browse files
committed
Used shapes_utils.broadcast_dist_samples_to function for broadcasting
1 parent b52f1bb commit c25a988

File tree

1 file changed

+15
-41
lines changed

1 file changed

+15
-41
lines changed

pymc3/distributions/multivariate.py

Lines changed: 15 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
from .continuous import ChiSquared, Normal
4444
from .special import gammaln, multigammaln
4545
from .dist_math import bound, logpow, factln
46-
from .shape_utils import to_tuple
46+
from .shape_utils import to_tuple, broadcast_dist_samples_to
4747
from ..math import kron_dot, kron_diag, kron_solve_lower, kronecker
4848

4949

@@ -261,50 +261,23 @@ def random(self, point=None, size=None):
261261

262262
param_attribute = getattr(self, "chol_cov" if self._cov_type == "chol" else self._cov_type)
263263
mu, param = draw_values([self.mu, param_attribute], point=point, size=size)
264-
check_fast_drawable_or_point = lambda param, point: is_fast_drawable(param) or (
265-
point and param.name in point
266-
)
267-
268-
if tuple(self.shape):
269-
dist_shape = tuple(self.shape)
270-
batch_shape = dist_shape[:-1]
271-
else:
272-
if check_fast_drawable_or_point(self.mu, point):
273-
batch_shape = mu.shape[:-1]
274-
else:
275-
batch_shape = mu.shape[len(size) : -1]
276-
dist_shape = batch_shape + param.shape[-1:]
277264

278-
# First, distribution shape (batch+event) is computed and then
279-
# deterministic nature of random method can be obtained by appending it to sample_shape.
280-
output_shape = size + dist_shape
281-
extra_dims = len(output_shape) - mu.ndim
282-
283-
# It was not a good idea to check mu.shape[:len(size)] == size,
284-
# because it can get mixed among batch and event dimensions. Here, we explicitly chop off
285-
# the size (sample_shape) and only broadcast batch and event dimensions.
286-
if check_fast_drawable_or_point(self.mu, point):
287-
mu = mu.reshape((1,) * extra_dims + mu.shape)
288-
else:
289-
mu = mu.reshape(size + (1,) * extra_dims + mu.shape[len(size) :])
290-
291-
# Adding batch dimensions to parametrization
292-
if size and param.shape[:-2] == size:
293-
param = param.reshape(size + (1,) * len(batch_shape) + param.shape[-2:])
294-
295-
mu = np.broadcast_to(mu, output_shape)
296-
param = np.broadcast_to(param, output_shape + param.shape[-1:])
297-
if mu.shape[-1] != param.shape[-1]:
298-
raise ValueError(f"Shapes for mu and {self._cov_type} don't match")
265+
dist_shape = to_tuple(self.shape)
266+
mu = broadcast_dist_samples_to(to_shape=dist_shape, samples=[mu], size=size)[0]
267+
param = broadcast_dist_samples_to(
268+
to_shape=dist_shape + dist_shape[-1:], samples=[param], size=size
269+
)[0]
299270

300271
if self._cov_type == "cov":
301272
chol = np.linalg.cholesky(param)
302273
elif self._cov_type == "chol":
303274
chol = param
304-
else:
305-
inverse = np.linalg.inv(param)
306-
chol = np.linalg.cholesky(inverse)
275+
else: # tau -> chol -> swapaxes (chol, -1, -2) -> inv ...
276+
lower_chol = np.linalg.cholesky(param)
277+
upper_chol = np.swapaxes(lower_chol, -1, -2)
278+
chol = np.linalg.inv(upper_chol)
307279

280+
output_shape = size + dist_shape
308281
standard_normal = np.random.standard_normal(output_shape)
309282
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
310283

@@ -404,13 +377,13 @@ def random(self, point=None, size=None):
404377
nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
405378
if self._cov_type == "cov":
406379
(cov,) = draw_values([self.cov], point=point, size=size)
407-
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov)
380+
dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape)
408381
elif self._cov_type == "tau":
409382
(tau,) = draw_values([self.tau], point=point, size=size)
410-
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau)
383+
dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape)
411384
else:
412385
(chol,) = draw_values([self.chol_cov], point=point, size=size)
413-
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol)
386+
dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape)
414387

415388
samples = dist.random(point, size)
416389

@@ -1920,6 +1893,7 @@ def random(self, point=None, size=None):
19201893
"""
19211894
# Expand params into terms MvNormal can understand to force consistency
19221895
self._setup_random()
1896+
self.mv_params["shape"] = self.shape
19231897
dist = MvNormal.dist(**self.mv_params)
19241898
return dist.random(point, size)
19251899

0 commit comments

Comments
 (0)