Commit c98ff2e

Implement Symmetric Multivariate Laplace distribution
1 parent 0f5a818 commit c98ff2e

File tree: 3 files changed, +200 -0 lines changed


pymc_experimental/distributions/__init__.py (+2)

@@ -24,6 +24,7 @@
     Skellam,
 )
 from pymc_experimental.distributions.histogram_utils import histogram_approximation
+from pymc_experimental.distributions.multivariate.laplace import MvLaplace
 from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
 from pymc_experimental.distributions.timeseries import DiscreteMarkovChain

@@ -32,6 +33,7 @@
     "DiscreteMarkovChain",
     "GeneralizedPoisson",
     "GenExtreme",
+    "MvLaplace",
     "R2D2M2CP",
     "Skellam",
     "histogram_approximation",
pymc_experimental/distributions/multivariate/laplace.py (new file, +126)

@@ -0,0 +1,126 @@
import numpy as np
import pytensor.tensor as pt
import scipy

from pymc.distributions.dist_math import check_parameters
from pymc.distributions.distribution import Continuous, SymbolicRandomVariable
from pymc.distributions.multivariate import quaddist_chol, quaddist_matrix
from pymc.distributions.shape_utils import implicit_size_from_params, rv_size_is_none
from pymc.pytensorf import normalize_rng_param
from pytensor.gradient import grad_not_implemented
from pytensor.scalar import BinaryScalarOp, upgrade_to_float
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.random.utils import normalize_size_param


class Kv(BinaryScalarOp):
    """
    Modified Bessel function of the second kind of real order v.
    """

    nfunc_spec = ("scipy.special.kv", 2, 1)

    @staticmethod
    def st_impl(v, x):
        return scipy.special.kv(v, x)

    def impl(self, v, x):
        return self.st_impl(v, x)

    def grad(self, inputs, grads):
        v, x = inputs
        (gz,) = grads
        return [grad_not_implemented(self, 0, v), gz * scalar_kvp(v, x)]

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


kv = Elemwise(Kv(upgrade_to_float, name="kv"))


class Kvp(BinaryScalarOp):
    """
    First-order derivative of the modified Bessel function of the second kind of real order v, Kv(z).
    """

    nfunc_spec = ("scipy.special.kvp", 2, 1)

    @staticmethod
    def st_impl(v, x):
        return scipy.special.kvp(v, x)

    def impl(self, v, x):
        return self.st_impl(v, x)

    def c_code(self, *args, **kwargs):
        raise NotImplementedError()


scalar_kvp = Kvp(upgrade_to_float, name="kvp")
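The Kv and Kvp Ops above only wrap scipy.special.kv / scipy.special.kvp at the Python level (c_code is left unimplemented), and the derivative of Kv with respect to its argument is routed through scalar_kvp. A minimal sanity check of the elementwise kv wrapper, assuming it is imported from this module, could look like:

    # Hypothetical check, not part of this commit: compare the Elemwise `kv` Op
    # against scipy.special.kv on a small grid of orders and arguments.
    import numpy as np
    import scipy.special

    from pymc_experimental.distributions.multivariate.laplace import kv

    v = np.array([0.0, 0.5, 1.0])
    x = np.array([0.1, 1.0, 10.0])
    np.testing.assert_allclose(kv(v, x).eval(), scipy.special.kv(v, x))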
class MultivariateLaplaceRV(SymbolicRandomVariable):
    name = "multivariate_laplace"
    extended_signature = "[rng],[size],(m),(m,m)->[rng],(m)"
    _print_name = ("MultivariateLaplace", "\\operatorname{MultivariateLaplace}")

    @classmethod
    def rv_op(cls, mu, cov, *, size=None, rng=None):
        mu = pt.as_tensor(mu)
        cov = pt.as_tensor(cov)
        rng = normalize_rng_param(rng)
        size = normalize_size_param(size)

        assert mu.type.ndim >= 1
        assert cov.type.ndim >= 2

        if rv_size_is_none(size):
            size = implicit_size_from_params(mu, cov, ndims_params=(1, 2))

        next_rng, e = pt.random.exponential(size=size, rng=rng).owner.outputs
        next_rng, z = pt.random.multivariate_normal(
            mean=pt.zeros(mu.shape[-1]), cov=cov, size=size, rng=next_rng
        ).owner.outputs
        rv = mu + pt.sqrt(e)[..., None] * z

        return cls(
            inputs=[rng, size, mu, cov],
            outputs=[next_rng, rv],
        )(rng, size, mu, cov)
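For reference, rv_op draws from the Gaussian scale-mixture representation of the symmetric multivariate Laplace distribution, which is exactly what the exponential and multivariate-normal draws above implement:

    X = \mu + \sqrt{E}\, Z, \qquad E \sim \operatorname{Exponential}(1), \quad Z \sim \mathcal{N}(0, \Sigma)

Since E[E] = 1, this gives E[X] = \mu and Cov[X] = \Sigma, which is what test_mvlaplace_random below checks empirically.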
class MvLaplace(Continuous):
    r"""Multivariate (Symmetric) Laplace distribution."""

    rv_type = MultivariateLaplaceRV
    rv_op = MultivariateLaplaceRV.rv_op

    @classmethod
    def dist(cls, mu=0, cov=None, *, tau=None, chol=None, lower=True, **kwargs):
        cov = quaddist_matrix(cov, chol, tau, lower)

        mu = pt.as_tensor_variable(mu)
        if mu.type.broadcastable[-1] != cov.type.broadcastable[-1]:
            mu, _ = pt.broadcast_arrays(mu, cov[..., -1])
        return super().dist([mu, cov], **kwargs)

    def support_point(rv, size, mu, cov):
        if rv_size_is_none(size):
            broadcasted_mu, _ = pt.random.utils.broadcast_params([mu, cov], ndims_params=[1, 2])
        else:
            broadcast_shape = pt.concatenate([size, [mu.shape[-1]]])
            broadcasted_mu = pt.broadcast_to(mu, broadcast_shape)
        return broadcasted_mu

    def logp(value, mu, cov):
        quaddist, logdet, posdef = quaddist_chol(value, mu, cov)

        k = value.shape[-1].astype("floatX")
        norm = np.log(2) - 0.5 * k * np.log(2 * np.pi) - logdet

        v = 1 - (k / 2)
        kernel = ((v / 2) * pt.log(quaddist / 2)) + pt.log(kv(v, pt.sqrt(2 * quaddist)))

        logp_val = norm + kernel
        return check_parameters(logp_val, posdef, msg="posdef scale")
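For reference, logp implements the log of the symmetric multivariate Laplace density (the same density whose bivariate special cases are used in the tests below). Writing q = (x - \mu)^\top \Sigma^{-1} (x - \mu), k for the dimension of x, and v = 1 - k/2, it computes

    \log f(x) = \log 2 - \tfrac{k}{2} \log(2\pi) - \tfrac{1}{2} \log\lvert\Sigma\rvert + \tfrac{v}{2} \log\tfrac{q}{2} + \log K_v\!\left(\sqrt{2q}\right)

where quaddist_chol supplies q and \tfrac{1}{2}\log\lvert\Sigma\rvert (the logdet term), K_v is the Bessel Op defined earlier, and check_parameters guards the requirement that the scale matrix be positive definite.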
Tests for MvLaplace (new file, +72)

@@ -0,0 +1,72 @@
import numpy as np
import pymc as pm
import scipy

from pymc_experimental.distributions.multivariate.laplace import MvLaplace


class TestMvLaplace:
    def test_mvlaplace_support_point(self):
        raise NotImplementedError()

    def test_mvlaplace_mean(self):
        raise NotImplementedError()

    def test_mvlaplace_random(self):
        mu = [-1, np.pi, 1]
        cov = [[1, 0.5, 0.25], [0.5, 2, 0.5], [0.25, 0.5, 3]]
        rv = MvLaplace.dist(mu=mu, cov=cov, size=10_000)

        samples = pm.draw(rv, random_seed=13)
        assert samples.shape == (10_000, 3)
        np.testing.assert_allclose(np.mean(samples, axis=0), mu, rtol=0.05)
        np.testing.assert_allclose(np.cov(samples, rowvar=False), cov, rtol=0.1)

    def test_laplace_logp(self):
        # Testing against special bivariate cases described in:
        # https://en.wikipedia.org/wiki/Multivariate_Laplace_distribution#Probability_density_function

        # Zero mean, non-identity covariance case
        mu = np.zeros(2)
        s1 = 0.5
        s2 = 2.0
        r = -0.25
        cov = np.array(
            [
                [s1**2, r * s1 * s2],
                [r * s1 * s2, s2**2],
            ]
        )
        rv = MvLaplace.dist(mu=mu, cov=cov)
        rv_val = np.random.normal(size=(2,))
        logp_eval = pm.logp(rv, rv_val).eval()

        x1, x2 = rv_val
        logp_expected = np.log(
            (1 / (np.pi * s1 * s2 * np.sqrt(1 - r**2)))
            * scipy.special.kv(
                0,
                np.sqrt(
                    (2 * ((x1**2 / s1**2) - (2 * r * x1 * x2 / (s1 * s2)) + (x2**2 / s2**2)))
                    / (1 - r**2)
                ),
            )
        )
        np.testing.assert_allclose(
            logp_eval,
            logp_expected,
        )

        # Non zero mean, identity covariance case
        mu = np.array([1, 3])
        rv = MvLaplace.dist(mu=mu, cov=np.eye(2))
        rv_val = np.random.normal(size=(2,))
        logp_eval = pm.logp(rv, rv_val).eval()

        logp_expected = np.log(
            1 / np.pi * scipy.special.kv(0, np.sqrt(2 * np.sum((rv_val - mu) ** 2)))
        )
        np.testing.assert_allclose(
            logp_eval,
            logp_expected,
        )
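For completeness, a hypothetical usage sketch (not part of this commit) showing how the new distribution could be used as a likelihood inside a model; observed_data and the priors are placeholders:

    # Hypothetical usage sketch, not part of this commit.
    import numpy as np
    import pymc as pm

    from pymc_experimental.distributions import MvLaplace

    observed_data = np.random.default_rng(0).normal(size=(100, 3))  # placeholder data

    with pm.Model():
        mu = pm.Normal("mu", 0.0, 1.0, shape=3)
        chol, _, _ = pm.LKJCholeskyCov(
            "chol", n=3, eta=2.0, sd_dist=pm.Exponential.dist(1.0)
        )
        MvLaplace("obs", mu=mu, chol=chol, observed=observed_data)
        idata = pm.sample()

The chol argument is accepted because dist routes cov/chol/tau through quaddist_matrix, mirroring MvNormal's parametrization.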
