|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import itertools |
15 | 16 | import pytest
|
16 | 17 | import numpy as np
|
17 | 18 | import numpy.testing as npt
|
|
27 | 28 | draw_values,
|
28 | 29 | _DrawValuesContext,
|
29 | 30 | _DrawValuesContextBlocker,
|
| 31 | + to_tuple, |
30 | 32 | )
|
31 | 33 | from .helpers import SeededTest
|
32 | 34 | from .test_distributions import (
|
@@ -1544,3 +1546,112 @@ def test_Triangular(
|
1544 | 1546 | prior_samples=prior_samples,
|
1545 | 1547 | )
|
1546 | 1548 | assert prior["target"].shape == (prior_samples,) + shape
|
| 1549 | + |
| 1550 | + |
| 1551 | +def generate_shapes(include_params=False, xfail=False): |
| 1552 | + # fmt: off |
| 1553 | + mudim_as_event = [ |
| 1554 | + [None, 1, 3, 10, (10, 3), 100], |
| 1555 | + [(3,)], |
| 1556 | + [(1,), (3,)], |
| 1557 | + ["cov", "chol", "tau"] |
| 1558 | + ] |
| 1559 | + # fmt: on |
| 1560 | + mudim_as_dist = [ |
| 1561 | + [None, 1, 3, 10, (10, 3), 100], |
| 1562 | + [(10, 3)], |
| 1563 | + [(1,), (3,), (1, 1), (1, 3), (10, 1), (10, 3)], |
| 1564 | + ["cov", "chol", "tau"], |
| 1565 | + ] |
| 1566 | + if not include_params: |
| 1567 | + del mudim_as_event[-1] |
| 1568 | + del mudim_as_dist[-1] |
| 1569 | + data = itertools.chain(itertools.product(*mudim_as_event), itertools.product(*mudim_as_dist)) |
| 1570 | + if xfail: |
| 1571 | + data = list(data) |
| 1572 | + for index in range(len(data)): |
| 1573 | + if data[index][0] in (None, 1): |
| 1574 | + data[index] = pytest.param( |
| 1575 | + *data[index], marks=pytest.mark.xfail(reason="wait for PR #4214") |
| 1576 | + ) |
| 1577 | + return data |
| 1578 | + |
| 1579 | + |
| 1580 | +class TestMvNormal(SeededTest): |
| 1581 | + @pytest.mark.parametrize( |
| 1582 | + ["sample_shape", "dist_shape", "mu_shape", "param"], |
| 1583 | + generate_shapes(include_params=True, xfail=False), |
| 1584 | + ids=str, |
| 1585 | + ) |
| 1586 | + def test_with_np_arrays(self, sample_shape, dist_shape, mu_shape, param): |
| 1587 | + dist = pm.MvNormal.dist(mu=np.ones(mu_shape), **{param: np.eye(3)}, shape=dist_shape) |
| 1588 | + output_shape = to_tuple(sample_shape) + dist_shape |
| 1589 | + assert dist.random(size=sample_shape).shape == output_shape |
| 1590 | + |
| 1591 | + @pytest.mark.parametrize( |
| 1592 | + ["sample_shape", "dist_shape", "mu_shape"], |
| 1593 | + generate_shapes(include_params=False, xfail=True), |
| 1594 | + ids=str, |
| 1595 | + ) |
| 1596 | + def test_with_chol_rv(self, sample_shape, dist_shape, mu_shape): |
| 1597 | + with pm.Model() as model: |
| 1598 | + mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape) |
| 1599 | + sd_dist = pm.Exponential.dist(1.0, shape=3) |
| 1600 | + chol, corr, stds = pm.LKJCholeskyCov( |
| 1601 | + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True |
| 1602 | + ) |
| 1603 | + mv = pm.MvNormal("mv", mu, chol=chol, shape=dist_shape) |
| 1604 | + prior = pm.sample_prior_predictive(samples=sample_shape) |
| 1605 | + |
| 1606 | + assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape |
| 1607 | + |
| 1608 | + @pytest.mark.parametrize( |
| 1609 | + ["sample_shape", "dist_shape", "mu_shape"], |
| 1610 | + generate_shapes(include_params=False, xfail=True), |
| 1611 | + ids=str, |
| 1612 | + ) |
| 1613 | + def test_with_cov_rv(self, sample_shape, dist_shape, mu_shape): |
| 1614 | + with pm.Model() as model: |
| 1615 | + mu = pm.Normal("mu", 0.0, 1.0, shape=mu_shape) |
| 1616 | + sd_dist = pm.Exponential.dist(1.0, shape=3) |
| 1617 | + chol, corr, stds = pm.LKJCholeskyCov( |
| 1618 | + "chol_cov", n=3, eta=2, sd_dist=sd_dist, compute_corr=True |
| 1619 | + ) |
| 1620 | + mv = pm.MvNormal("mv", mu, cov=pm.math.dot(chol, chol.T), shape=dist_shape) |
| 1621 | + prior = pm.sample_prior_predictive(samples=sample_shape) |
| 1622 | + |
| 1623 | + assert prior["mv"].shape == to_tuple(sample_shape) + dist_shape |
| 1624 | + |
| 1625 | + def test_issue_3758(self): |
| 1626 | + np.random.seed(42) |
| 1627 | + ndim = 50 |
| 1628 | + with pm.Model() as model: |
| 1629 | + a = pm.Normal("a", sigma=100, shape=ndim) |
| 1630 | + b = pm.Normal("b", mu=a, sigma=1, shape=ndim) |
| 1631 | + c = pm.MvNormal("c", mu=a, chol=np.linalg.cholesky(np.eye(ndim)), shape=ndim) |
| 1632 | + d = pm.MvNormal("d", mu=a, cov=np.eye(ndim), shape=ndim) |
| 1633 | + samples = pm.sample_prior_predictive(1000) |
| 1634 | + |
| 1635 | + for var in "abcd": |
| 1636 | + assert not np.isnan(np.std(samples[var])) |
| 1637 | + |
| 1638 | + def test_issue_3829(self): |
| 1639 | + with pm.Model() as model: |
| 1640 | + x = pm.MvNormal("x", mu=np.zeros(5), cov=np.eye(5), shape=(2, 5)) |
| 1641 | + trace_pp = pm.sample_prior_predictive(50) |
| 1642 | + |
| 1643 | + assert np.shape(trace_pp["x"][0]) == (2, 5) |
| 1644 | + |
| 1645 | + def test_issue_3706(self): |
| 1646 | + N = 10 |
| 1647 | + Sigma = np.eye(2) |
| 1648 | + |
| 1649 | + with pm.Model() as model: |
| 1650 | + |
| 1651 | + X = pm.MvNormal("X", mu=np.zeros(2), cov=Sigma, shape=(N, 2)) |
| 1652 | + betas = pm.Normal("betas", 0, 1, shape=2) |
| 1653 | + y = pm.Deterministic("y", pm.math.dot(X, betas)) |
| 1654 | + |
| 1655 | + prior_pred = pm.sample_prior_predictive(1) |
| 1656 | + |
| 1657 | + assert prior_pred["X"].shape == (1, N, 2) |
0 commit comments