Skip to content

Commit 770a16a

Browse files
committed
split initialization functions into helpers
1 parent dbad516 commit 770a16a

File tree

1 file changed

+69
-30
lines changed
  • pymc_experimental/distributions/multivariate

1 file changed

+69
-30
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+69-30
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,18 @@
2222
__all__ = ["R2D2M2CP"]
2323

2424

25-
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable):
25+
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
2626
pi = pt.erfinv(2 * psi - 1)
2727
f = (1 / (2 * pi**2 + 1)) ** 0.5
2828
sigma = explained_var**0.5 * f
2929
mu = sigma * pi * 2**0.5
30-
return mu, sigma
30+
if psi_mask is not None:
31+
return (
32+
pt.where(psi_mask, mu, explained_var**0.5),
33+
pt.where(psi_mask, sigma, 0),
34+
)
35+
else:
36+
return mu, sigma
3137

3238

3339
def _R2D2M2CP_beta(
@@ -38,6 +44,7 @@ def _R2D2M2CP_beta(
3844
phi: pt.TensorVariable,
3945
psi: pt.TensorVariable,
4046
*,
47+
psi_mask,
4148
dims: Union[str, Sequence[str]],
4249
centered=False,
4350
):
@@ -60,7 +67,7 @@ def _R2D2M2CP_beta(
6067
"""
6168
tau2 = r2 / (1 - r2)
6269
explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1)
63-
mu_param, std_param = _psivar2musigma(psi, explained_variance)
70+
mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
6471
if not centered:
6572
with pm.Model(name):
6673
raw = pm.Normal("raw", dims=dims)
@@ -102,9 +109,53 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
102109
psi = pm.Deterministic("psi", psi, dims=dims)
103110
else:
104111
psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims)
112+
mask = None
105113
return mask, psi
106114

107115

116+
def _psi(positive_probs, positive_probs_std, *, dims):
117+
if positive_probs_std is not None:
118+
mask, psi = _psi_masked(
119+
"psi",
120+
mu=pt.as_tensor(positive_probs),
121+
sigma=pt.as_tensor(positive_probs_std),
122+
dims=dims,
123+
)
124+
else:
125+
positive_probs = pt.as_tensor(positive_probs)
126+
if not isinstance(positive_probs, pt.Constant):
127+
raise TypeError("Only constant values for positive_probs are allowed")
128+
psi = _broadcast_as_dims(positive_probs.data, dims=dims)
129+
mask = psi != 1
130+
if (mask).all():
131+
mask = None
132+
return mask, psi
133+
134+
135+
def _phi(
136+
variables_importance,
137+
variance_explained,
138+
variance_explained_concentration,
139+
*,
140+
dims,
141+
):
142+
*broadcast_dims, dim = dims
143+
model = pm.modelcontext(None)
144+
if variables_importance is not None:
145+
if variance_explained is not None:
146+
raise TypeError("Can't use variable importance with variance explained")
147+
if len(model.coords[dim]) <= 1:
148+
raise TypeError("Can't use variable importance with less than two variables")
149+
phi = pm.Dirichlet("phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim])
150+
elif variance_explained is not None:
151+
if len(model.coords[dim]) <= 1:
152+
raise TypeError("Can't use variance explained with less than two variables")
153+
phi = pt.as_tensor(variance_explained)
154+
else:
155+
phi = pt.as_tensor(1 / len(model.coords[dim]))
156+
return phi
157+
158+
108159
def R2D2M2CP(
109160
name,
110161
output_sigma,
@@ -114,6 +165,7 @@ def R2D2M2CP(
114165
r2,
115166
variables_importance=None,
116167
variance_explained=None,
168+
variance_explained_concentration=None,
117169
r2_std=None,
118170
positive_probs=0.5,
119171
positive_probs_std=None,
@@ -138,6 +190,8 @@ def R2D2M2CP(
138190
variance_explained : tensor, optional
139191
Alternative estimate for variables importance which is point estimate of
140192
variance explained, should sum up to one, by default None
193+
variance_explained_concentration : tensor, optional
194+
Confidence around variance explained estimate
141195
r2_std : tensor, optional
142196
Optional uncertainty over :math:`R^2`, by default None
143197
positive_probs : tensor, optional
@@ -161,8 +215,8 @@ def R2D2M2CP(
161215
-----
162216
The R2D2M2CP prior is a modification of R2D2M2 prior.
163217
164-
- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
165-
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
218+
- ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
219+
- R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
166220
167221
Examples
168222
--------
@@ -294,33 +348,17 @@ def R2D2M2CP(
294348
*broadcast_dims, dim = dims
295349
input_sigma = pt.as_tensor(input_sigma)
296350
output_sigma = pt.as_tensor(output_sigma)
297-
positive_probs = pt.as_tensor(positive_probs)
298-
with pm.Model(name) as model:
299-
if variables_importance is not None:
300-
if variance_explained is not None:
301-
raise TypeError("Can't use variable importance with variance explained")
302-
if len(model.coords[dim]) <= 1:
303-
raise TypeError("Can't use variable importance with less than two variables")
304-
phi = pm.Dirichlet(
305-
"phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]
306-
)
307-
elif variance_explained is not None:
308-
if len(model.coords[dim]) <= 1:
309-
raise TypeError("Can't use variance explained with less than two variables")
310-
phi = pt.as_tensor(variance_explained)
311-
else:
312-
phi = pt.as_tensor(1 / len(model.coords[dim]))
351+
with pm.Model(name):
313352
if r2_std is not None:
314353
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
315-
if positive_probs_std is not None:
316-
psi = pm.Beta(
317-
"psi",
318-
mu=pt.as_tensor(positive_probs),
319-
sigma=pt.as_tensor(positive_probs_std),
320-
dims=broadcast_dims + [dim],
321-
)
322-
else:
323-
psi = pt.as_tensor(positive_probs)
354+
phi = _phi(
355+
variables_importance=variables_importance,
356+
variance_explained=variance_explained,
357+
)
358+
mask, psi = _psi(
359+
positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims
360+
)
361+
324362
beta = _R2D2M2CP_beta(
325363
name,
326364
output_sigma,
@@ -330,6 +368,7 @@ def R2D2M2CP(
330368
psi,
331369
dims=broadcast_dims + [dim],
332370
centered=centered,
371+
psi_mask=mask,
333372
)
334373
resid_sigma = (1 - r2) ** 0.5 * output_sigma
335374
return resid_sigma, beta

0 commit comments

Comments
 (0)