Skip to content

Commit f3ce16f

Browse files
committed
Infer logp for elemwise transformations of multivariate variables
1 parent d860f63 commit f3ce16f

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

pymc/logprob/transforms.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -377,12 +377,15 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
377377
else:
378378
input_logprob = _logprob_helper(measurable_input, backward_value)
379379

380-
if input_logprob.ndim < value.ndim:
381-
# Do we just need to sum the jacobian terms across the support dims?
382-
raise NotImplementedError("Transform of multivariate RVs not implemented")
383-
384380
jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs)
385381

382+
if input_logprob.ndim < value.ndim:
383+
# For multivariate variables, the Jacobian is diagonal.
384+
# We can get the right result by summing the last dimensions
385+
# of `transform_elemwise.log_jac_det`
386+
ndim_supp = value.ndim - input_logprob.ndim
387+
jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0)))
388+
386389
# The jacobian is used to ensure a value in the supported domain was provided
387390
return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian)
388391

tests/logprob/test_transforms.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -964,3 +964,28 @@ def scan_step(prev_innov):
964964
"innov": np.full((4,), -0.5),
965965
}
966966
np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point))
967+
968+
969+
@pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])])
970+
@pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])])
971+
def test_multivariate_transform(shift, scale):
972+
mu = np.array([0, 0.9, -2.1])
973+
cov = np.array([[1, 0, 0.9], [0, 1, 0], [0.9, 0, 1]])
974+
x_rv_raw = pt.random.multivariate_normal(mu, cov=cov)
975+
x_rv = shift + x_rv_raw * scale
976+
x_rv.name = "x"
977+
978+
x_vv = x_rv.clone()
979+
logp = factorized_joint_logprob({x_rv: x_vv})[x_vv]
980+
assert_no_rvs(logp)
981+
982+
x_vv_test = np.array([5.0, 4.9, -6.3])
983+
scale_mat = scale * np.eye(x_vv_test.shape[0])
984+
np.testing.assert_almost_equal(
985+
logp.eval({x_vv: x_vv_test}),
986+
sp.stats.multivariate_normal.logpdf(
987+
x_vv_test,
988+
shift + mu * scale,
989+
scale_mat @ cov @ scale_mat.T,
990+
),
991+
)

0 commit comments

Comments
 (0)