Skip to content

Commit 6a37514

Browse files
committed
Modified logic and added tests
1 parent 3c19688 commit 6a37514

File tree

2 files changed

+123
-5
lines changed

2 files changed

+123
-5
lines changed

pymc3/distributions/multivariate.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -256,10 +256,18 @@ def random(self, point=None, size=None):
256256
mu, param = draw_values([self.mu, param_attribute], point=point, size=size)
257257

258258
dist_shape = to_tuple(self.shape)
259-
mu = broadcast_dist_samples_to(to_shape=dist_shape, samples=[mu], size=size)[0]
260-
param = broadcast_dist_samples_to(
261-
to_shape=dist_shape + dist_shape[-1:], samples=[param], size=size
262-
)[0]
259+
output_shape = size + dist_shape
260+
261+
# Simple, there can be only be 1 batch dimension, only available from `mu`.
262+
# Insert it into `param` before events, if there is a sample shape in front.
263+
if param.ndim > 2 and dist_shape[:-1]:
264+
param = param.reshape(size + (1,) + param.shape[-2:])
265+
266+
mu = broadcast_dist_samples_to(to_shape=output_shape, samples=[mu], size=size)[0]
267+
param = np.broadcast_to(param, shape=output_shape + dist_shape[-1:])
268+
269+
assert mu.shape == output_shape
270+
assert param.shape == output_shape + dist_shape[-1:]
263271

264272
if self._cov_type == "cov":
265273
chol = np.linalg.cholesky(param)
@@ -270,7 +278,6 @@ def random(self, point=None, size=None):
270278
upper_chol = np.swapaxes(lower_chol, -1, -2)
271279
chol = np.linalg.inv(upper_chol)
272280

273-
output_shape = size + dist_shape
274281
standard_normal = np.random.standard_normal(output_shape)
275282
return mu + np.einsum("...ij,...j->...i", chol, standard_normal)
276283

pymc3/tests/test_distributions_random.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import itertools
1516
import pytest
1617
import numpy as np
1718
import numpy.testing as npt
@@ -27,6 +28,7 @@
2728
draw_values,
2829
_DrawValuesContext,
2930
_DrawValuesContextBlocker,
31+
to_tuple,
3032
)
3133
from .helpers import SeededTest
3234
from .test_distributions import (
@@ -1544,3 +1546,112 @@ def test_Triangular(
15441546
prior_samples=prior_samples,
15451547
)
15461548
assert prior["target"].shape == (prior_samples,) + shape
1549+
1550+
1551+
def generate_shapes(include_params=False, xfail=False):
1552+
# fmt: off
1553+
mudim_as_event = [
1554+
[None, 1, 3, 10, (10, 3), 100],
1555+
[(3,)],
1556+
[(1,), (3,)],
1557+
["cov", "chol", "tau"]
1558+
]
1559+
# fmt: on
1560+
mudim_as_dist = [
1561+
[None, 1, 3, 10, (10, 3), 100],
1562+
[(10, 3)],
1563+
[(1,), (3,), (1, 1), (1, 3), (10, 1), (10, 3)],
1564+
["cov", "chol", "tau"],
1565+
]
1566+
if not include_params:
1567+
del mudim_as_event[-1]
1568+
del mudim_as_dist[-1]
1569+
data = itertools.chain(itertools.product(*mudim_as_event), itertools.product(*mudim_as_dist))
1570+
if xfail:
1571+
data = list(data)
1572+
for index in range(len(data)):
1573+
if data[index][0] in (None, 1):
1574+
data[index] = pytest.param(
1575+
*data[index], marks=pytest.mark.xfail(reason="wait for PR #4214")
1576+
)
1577+
return data
1578+
1579+
1580+
class TestMvNormal(SeededTest):
1581+
@pytest.mark.parametrize(
1582+
["sample_shape", "dist_shape", "mu_shape", "param"],
1583+
generate_shapes(include_params=True, xfail=False),
1584+
ids=str,
1585+
)
1586+
def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param):
1587+
dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape)
1588+
output_shape = to_tuple(sample_shape) + dist_shape
1589+
assert dist.random(size=sample_shape).shape == output_shape
1590+
1591+
@pytest.mark.parametrize(
1592+
["sample_shape", "dist_shape", "mu_shape"],
1593+
generate_shapes(include_params=False, xfail=True),
1594+
ids=str,
1595+
)
1596+
def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape):
1597+
with pm.Model() as model:
1598+
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
1599+
sd_dist = pm.Exponential.dist(1.0, shape=3)
1600+
chol, corr, stds = pm.LKJCholeskyCov(
1601+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
1602+
)
1603+
mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape)
1604+
prior = pm.sample_prior_predictive(samples=sample_shape)
1605+
1606+
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1607+
1608+
@pytest.mark.parametrize(
1609+
["sample_shape", "dist_shape", "mu_shape"],
1610+
generate_shapes(include_params=False, xfail=True),
1611+
ids=str,
1612+
)
1613+
def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape):
1614+
with pm.Model() as model:
1615+
mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape)
1616+
sd_dist = pm.Exponential.dist(1.0, shape=3)
1617+
chol, corr, stds = pm.LKJCholeskyCov(
1618+
"chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True
1619+
)
1620+
mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape)
1621+
prior = pm.sample_prior_predictive(samples=sample_shape)
1622+
1623+
assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape
1624+
1625+
def test_issue_3758(self):
1626+
np.random.seed(42)
1627+
ndim = 50
1628+
with pm.Model() as model:
1629+
a = pm.Normal("a", sigma=100, shape=ndim)
1630+
b = pm.Normal("b", mu=a, sigma=1, shape=ndim)
1631+
c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim)
1632+
d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim)
1633+
samples = pm.sample_prior_predictive(1000)
1634+
1635+
for var in "abcd":
1636+
assert not np.isnan(np.std(samples[var]))
1637+
1638+
def test_issue_3829(self):
1639+
with pm.Model() as model:
1640+
x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5))
1641+
trace_pp = pm.sample_prior_predictive(50)
1642+
1643+
assert np.shape(trace_pp["x"][0]) == (2, 5)
1644+
1645+
def test_issue_3706(self):
1646+
N = 10
1647+
Sigma = np.eye(2)
1648+
1649+
with pm.Model() as model:
1650+
1651+
X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2))
1652+
betas = pm.Normal("betas", 0, 1, shape=2)
1653+
y = pm.Deterministic("y", pm.math.dot(X, betas))
1654+
1655+
prior_pred = pm.sample_prior_predictive(1)
1656+
1657+
assert prior_pred["X"].shape == (1, N, 2)

0 commit comments

Comments
 (0)