Skip to content

Commit eb55cd5

Browse files
committed
add r2m2d2cp
1 parent 56d8406 commit eb55cd5

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed
Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from typing import Sequence, Union
2+
3+
import pymc as pm
4+
import pytensor.tensor as pt
5+
6+
7+
def _psivar2musigma(psi: pt.TensorVariable, var: pt.TensorVariable):
8+
pi = pt.erfinv(2 * psi - 1)
9+
f = (1 / (2 * pi**2 + 1)) ** 0.5
10+
sigma = pt.expand_dims(var, -1) ** 0.5 * f
11+
mu = sigma * pi * 2**0.5
12+
return mu, sigma
13+
14+
15+
def _R2D2M2CP_beta(
16+
name: str,
17+
variance: pt.TensorVariable,
18+
param_sigma: pt.TensorVariable,
19+
r2: pt.TensorVariable,
20+
phi: pt.TensorVariable,
21+
psi: pt.TensorVariable,
22+
*,
23+
dims: Union[str, Sequence[str]],
24+
centered=False,
25+
):
26+
"""R2D2M2CP_beta prior.
27+
name: str
28+
Name for the distribution
29+
variance: tensor
30+
standard deviation of the outcome
31+
param_sigma: tensor
32+
standard deviation of the explanatory variables
33+
r2: tensor
34+
expected R2 for the linear regression
35+
phi: tensor
36+
variance weights that sums up to 1
37+
psi: tensor
38+
probability of a coefficients to be positive
39+
"""
40+
tau2 = r2 / (1 - r2)
41+
explained_variance = phi * tau2 * pt.expand_dims(variance, -1)
42+
mu_param, std_param = _psivar2musigma(psi, explained_variance)
43+
if not centered:
44+
with pm.Model(name):
45+
raw = pm.Normal("raw", dims=dims)
46+
beta = pm.Deterministic(name, (raw * std_param + mu_param) / param_sigma, dims=dims)
47+
else:
48+
beta = pm.Normal(name, mu_param / param_sigma, std_param / param_sigma, dims=dims)
49+
return beta
50+
51+
52+
def R2D2M2CP(
53+
name,
54+
variance,
55+
param_sigma,
56+
*,
57+
dims,
58+
r2,
59+
variables_importance=None,
60+
variance_explained=None,
61+
r2_std=None,
62+
positive_probs=0.5,
63+
positive_probs_std=None,
64+
centered=False,
65+
):
66+
"""R2D2M2CP Prior.
67+
68+
Parameters
69+
----------
70+
name : str
71+
Name for the distribution
72+
variance : tensor
73+
Output variance
74+
param_sigma : tensor
75+
Input standard deviation
76+
dims : Union[str, Sequence[str]]
77+
Dims for the distribution
78+
r2 : tensor
79+
:math:`R^2` estimate
80+
variables_importance : tensor, optional
81+
Optional estimate for variables importance, positive, , by default None
82+
variance_explained : tensor, optional
83+
Alternative estimate for variables importance which is point estimate of
84+
variance explained, should sum up to one, by default None
85+
r2_std : tensor, optional
86+
Optional uncertainty over :math:`R^2`, by default None
87+
positive_probs : tensor, optional
88+
Optional probability of variables contribution to be positive, by default 0.5
89+
positive_probs_std : tensor, optional
90+
Optional uncertainty over effect direction probability, by default None
91+
centered : bool, optional
92+
Centered or Non-Centered parametrization of the distribution, by default Non-Centered. Advised to check both
93+
94+
Returns
95+
-------
96+
residual_variance, coefficients
97+
Output variance is split in residual variance and explained variance.
98+
99+
Raises
100+
------
101+
TypeError
102+
If parametrization is wrong.
103+
104+
Notes
105+
-----
106+
- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
107+
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
108+
"""
109+
if not isinstance(dims, (list, tuple)):
110+
dims = (dims,)
111+
*hierarchy, dim = dims
112+
param_sigma = pt.as_tensor(param_sigma)
113+
variance = pt.as_tensor(variance)
114+
with pm.Model(name) as model:
115+
if variables_importance is not None and len(model.coords[dim]) > 1:
116+
if variance_explained is not None:
117+
raise TypeError("Can't use variable importance with variance explained")
118+
phi = pm.Dirichlet("phi", pt.as_tensor(variables_importance), dims=hierarchy + [dim])
119+
elif variance_explained:
120+
phi = pt.as_tensor(variance_explained)
121+
else:
122+
phi = 1 / len(model.coords[dim])
123+
if r2_std is not None:
124+
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=hierarchy)
125+
if positive_probs_std is not None:
126+
psi = pm.Beta(
127+
"psi",
128+
mu=pt.as_tensor(positive_probs),
129+
sigma=pt.as_tensor(positive_probs_std),
130+
dims=hierarchy + [dim],
131+
)
132+
else:
133+
psi = pt.as_tensor(positive_probs)
134+
beta = _R2D2M2CP_beta(
135+
name, variance, param_sigma, r2, phi, psi, dims=hierarchy + [dim], centered=centered
136+
)
137+
variance_resid = (1 - r2) * variance
138+
return variance_resid, beta

0 commit comments

Comments
 (0)