|
9 | 9 | from pytensor.gradient import grad
|
10 | 10 | from pytensor.graph import Apply, Op
|
11 | 11 | from pytensor.graph.replace import vectorize_node
|
12 |
| -from pytensor.tensor import tensor |
| 12 | +from pytensor.tensor import diagonal, log, tensor |
13 | 13 | from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
|
14 | 14 | from pytensor.tensor.nlinalg import MatrixInverse
|
15 |
| -from pytensor.tensor.slinalg import Cholesky, Solve |
| 15 | +from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular |
16 | 16 |
|
17 | 17 |
|
18 | 18 | def test_vectorize_blockwise():
|
@@ -270,3 +270,41 @@ class TestSolveVector(BlockwiseOpTester):
|
270 | 270 | class TestSolveMatrix(BlockwiseOpTester):
|
271 | 271 | core_op = Solve(lower=True, b_ndim=2)
|
272 | 272 | signature = "(m, m),(m, n) -> (m, n)"
|
| 273 | + |
| 274 | + |
| 275 | +@pytest.mark.parametrize( |
| 276 | + "mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}" |
| 277 | +) |
| 278 | +@pytest.mark.parametrize( |
| 279 | + "cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}" |
| 280 | +) |
| 281 | +def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark): |
| 282 | + rng = np.random.default_rng(sum(map(ord, "batched_mvnormal"))) |
| 283 | + |
| 284 | + value_batch_shape = mu_batch_shape |
| 285 | + if len(cov_batch_shape) > len(mu_batch_shape): |
| 286 | + value_batch_shape = cov_batch_shape |
| 287 | + |
| 288 | + value = tensor("value", shape=(*value_batch_shape, 10)) |
| 289 | + mu = tensor("mu", shape=(*mu_batch_shape, 10)) |
| 290 | + cov = tensor("cov", shape=(*cov_batch_shape, 10, 10)) |
| 291 | + |
| 292 | + test_values = [ |
| 293 | + rng.normal(size=value.type.shape), |
| 294 | + rng.normal(size=mu.type.shape), |
| 295 | + np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)), |
| 296 | + ] |
| 297 | + |
| 298 | + chol_cov = cholesky(cov, lower=True, on_error="raise") |
| 299 | + delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1) |
| 300 | + quaddist = (delta_trans**2).sum(axis=-1) |
| 301 | + diag = diagonal(chol_cov, axis1=-2, axis2=-1) |
| 302 | + logdet = log(diag).sum(axis=-1) |
| 303 | + k = value.shape[-1] |
| 304 | + norm = -0.5 * k * (np.log(2 * np.pi)) |
| 305 | + |
| 306 | + logp = norm - 0.5 * quaddist - logdet |
| 307 | + dlogp = grad(logp.sum(), wrt=[value, mu, cov]) |
| 308 | + |
| 309 | + fn = pytensor.function([value, mu, cov], [logp, *dlogp]) |
| 310 | + benchmark(fn, *test_values) |
0 commit comments