|
43 | 43 | from .continuous import ChiSquared, Normal
|
44 | 44 | from .special import gammaln, multigammaln
|
45 | 45 | from .dist_math import bound, logpow, factln
|
46 |
| -from .shape_utils import to_tuple |
| 46 | +from .shape_utils import to_tuple, broadcast_dist_samples_to |
47 | 47 | from ..math import kron_dot, kron_diag, kron_solve_lower, kronecker
|
48 | 48 |
|
49 | 49 |
|
@@ -261,50 +261,23 @@ def random(self, point=None, size=None):
|
261 | 261 |
|
262 | 262 | param_attribute = getattr(self, "chol_cov" if self._cov_type == "chol" else self._cov_type)
|
263 | 263 | mu, param = draw_values([self.mu, param_attribute], point=point, size=size)
|
264 |
| - check_fast_drawable_or_point = lambda param, point: is_fast_drawable(param) or ( |
265 |
| - point and param.name in point |
266 |
| - ) |
267 |
| - |
268 |
| - if tuple(self.shape): |
269 |
| - dist_shape = tuple(self.shape) |
270 |
| - batch_shape = dist_shape[:-1] |
271 |
| - else: |
272 |
| - if check_fast_drawable_or_point(self.mu, point): |
273 |
| - batch_shape = mu.shape[:-1] |
274 |
| - else: |
275 |
| - batch_shape = mu.shape[len(size) : -1] |
276 |
| - dist_shape = batch_shape + param.shape[-1:] |
277 | 264 |
|
278 |
| - # First, distribution shape (batch+event) is computed and then |
279 |
| - # deterministic nature of random method can be obtained by appending it to sample_shape. |
280 |
| - output_shape = size + dist_shape |
281 |
| - extra_dims = len(output_shape) - mu.ndim |
282 |
| - |
283 |
| - # It was not a good idea to check mu.shape[:len(size)] == size, |
284 |
| - # because it can get mixed among batch and event dimensions. Here, we explicitly chop off |
285 |
| - # the size (sample_shape) and only broadcast batch and event dimensions. |
286 |
| - if check_fast_drawable_or_point(self.mu, point): |
287 |
| - mu = mu.reshape((1,) * extra_dims + mu.shape) |
288 |
| - else: |
289 |
| - mu = mu.reshape(size + (1,) * extra_dims + mu.shape[len(size) :]) |
290 |
| - |
291 |
| - # Adding batch dimensions to parametrization |
292 |
| - if size and param.shape[:-2] == size: |
293 |
| - param = param.reshape(size + (1,) * len(batch_shape) + param.shape[-2:]) |
294 |
| - |
295 |
| - mu = np.broadcast_to(mu, output_shape) |
296 |
| - param = np.broadcast_to(param, output_shape + param.shape[-1:]) |
297 |
| - if mu.shape[-1] != param.shape[-1]: |
298 |
| - raise ValueError(f"Shapes for mu and {self._cov_type} don't match") |
| 265 | + dist_shape = to_tuple(self.shape) |
| 266 | + mu = broadcast_dist_samples_to(to_shape=dist_shape, samples=[mu], size=size)[0] |
| 267 | + param = broadcast_dist_samples_to( |
| 268 | + to_shape=dist_shape + dist_shape[-1:], samples=[param], size=size |
| 269 | + )[0] |
299 | 270 |
|
300 | 271 | if self._cov_type == "cov":
|
301 | 272 | chol = np.linalg.cholesky(param)
|
302 | 273 | elif self._cov_type == "chol":
|
303 | 274 | chol = param
|
304 |
| - else: |
305 |
| - inverse = np.linalg.inv(param) |
306 |
| - chol = np.linalg.cholesky(inverse) |
| 275 | + else: # tau -> chol -> swapaxes (chol, -1, -2) -> inv ... |
| 276 | + lower_chol = np.linalg.cholesky(param) |
| 277 | + upper_chol = np.swapaxes(lower_chol, -1, -2) |
| 278 | + chol = np.linalg.inv(upper_chol) |
307 | 279 |
|
| 280 | + output_shape = size + dist_shape |
308 | 281 | standard_normal = np.random.standard_normal(output_shape)
|
309 | 282 | return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
|
310 | 283 |
|
@@ -404,13 +377,13 @@ def random(self, point=None, size=None):
|
404 | 377 | nu, mu = draw_values([self.nu, self.mu], point=point, size=size)
|
405 | 378 | if self._cov_type == "cov":
|
406 | 379 | (cov,) = draw_values([self.cov], point=point, size=size)
|
407 |
| - dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov) |
| 380 | + dist = MvNormal.dist(mu=np.zeros_like(mu), cov=cov, shape=self.shape) |
408 | 381 | elif self._cov_type == "tau":
|
409 | 382 | (tau,) = draw_values([self.tau], point=point, size=size)
|
410 |
| - dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau) |
| 383 | + dist = MvNormal.dist(mu=np.zeros_like(mu), tau=tau, shape=self.shape) |
411 | 384 | else:
|
412 | 385 | (chol,) = draw_values([self.chol_cov], point=point, size=size)
|
413 |
| - dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol) |
| 386 | + dist = MvNormal.dist(mu=np.zeros_like(mu), chol=chol, shape=self.shape) |
414 | 387 |
|
415 | 388 | samples = dist.random(point, size)
|
416 | 389 |
|
@@ -1920,6 +1893,7 @@ def random(self, point=None, size=None):
|
1920 | 1893 | """
|
1921 | 1894 | # Expand params into terms MvNormal can understand to force consistency
|
1922 | 1895 | self._setup_random()
|
| 1896 | + self.mv_params["shape"] = self.shape |
1923 | 1897 | dist = MvNormal.dist(**self.mv_params)
|
1924 | 1898 | return dist.random(point, size)
|
1925 | 1899 |
|
|
0 commit comments