Skip to content

Commit 8d254f9

Browse files
committed
Add mvnormal logp dlogp benchmark test
1 parent 2eae804 commit 8d254f9

File tree

1 file changed

+40
-2
lines changed

1 file changed

+40
-2
lines changed

tests/tensor/test_blockwise.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99
from pytensor.gradient import grad
1010
from pytensor.graph import Apply, Op
1111
from pytensor.graph.replace import vectorize_node
12-
from pytensor.tensor import tensor
12+
from pytensor.tensor import diagonal, log, tensor
1313
from pytensor.tensor.blockwise import Blockwise, _parse_gufunc_signature
1414
from pytensor.tensor.nlinalg import MatrixInverse
15-
from pytensor.tensor.slinalg import Cholesky, Solve
15+
from pytensor.tensor.slinalg import Cholesky, Solve, cholesky, solve_triangular
1616

1717

1818
def test_vectorize_blockwise():
@@ -270,3 +270,41 @@ class TestSolveVector(BlockwiseOpTester):
270270
class TestSolveMatrix(BlockwiseOpTester):
271271
core_op = Solve(lower=True, b_ndim=2)
272272
signature = "(m, m),(m, n) -> (m, n)"
273+
274+
275+
@pytest.mark.parametrize(
276+
"mu_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"mu:{arg}"
277+
)
278+
@pytest.mark.parametrize(
279+
"cov_batch_shape", [(), (1000,), (4, 1000)], ids=lambda arg: f"cov:{arg}"
280+
)
281+
def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
282+
rng = np.random.default_rng(sum(map(ord, "batched_mvnormal")))
283+
284+
value_batch_shape = mu_batch_shape
285+
if len(cov_batch_shape) > len(mu_batch_shape):
286+
value_batch_shape = cov_batch_shape
287+
288+
value = tensor("value", shape=(*value_batch_shape, 10))
289+
mu = tensor("mu", shape=(*mu_batch_shape, 10))
290+
cov = tensor("cov", shape=(*cov_batch_shape, 10, 10))
291+
292+
test_values = [
293+
rng.normal(size=value.type.shape),
294+
rng.normal(size=mu.type.shape),
295+
np.eye(cov.type.shape[-1]) * np.abs(rng.normal(size=cov.type.shape)),
296+
]
297+
298+
chol_cov = cholesky(cov, lower=True, on_error="raise")
299+
delta_trans = solve_triangular(chol_cov, value - mu, b_ndim=1)
300+
quaddist = (delta_trans**2).sum(axis=-1)
301+
diag = diagonal(chol_cov, axis1=-2, axis2=-1)
302+
logdet = log(diag).sum(axis=-1)
303+
k = value.shape[-1]
304+
norm = -0.5 * k * (np.log(2 * np.pi))
305+
306+
logp = norm - 0.5 * quaddist - logdet
307+
dlogp = grad(logp.sum(), wrt=[value, mu, cov])
308+
309+
fn = pytensor.function([value, mu, cov], [logp, *dlogp])
310+
benchmark(fn, *test_values)

0 commit comments

Comments
 (0)