Skip to content

Commit 356232c

Browse files
authored
Implement limit case of R2D2M2CP for P (#182)
* add helper function to initialize masked psi
* split initialization functions into helpers
* add importance concentration parameter
* rework non centered case for beta init
* fix sign for mean sigma helper
* add centered parametrization for the limit case
* fix typo
* fix typo
* fix implementation for non limit cases
* make positive tests pass
* check limit case requires std 0
* assert masked variables are created
* add failing test
* fix nans
* fix corner cases
* fix bug with missed variable
* add skipif float32
* add requirement for dims to be immutable for the prior as it is required for the limit case masking
* fix error message
1 parent 2da3c81 commit 356232c

File tree

2 files changed

+284
-42
lines changed

2 files changed

+284
-42
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+156-30
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,25 @@
1515

1616
from typing import Sequence, Union
1717

18+
import numpy as np
1819
import pymc as pm
1920
import pytensor.tensor as pt
2021

2122
__all__ = ["R2D2M2CP"]
2223

2324

24-
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable):
25+
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
    """Turn sign probabilities and explained variance into Normal parameters.

    Parameters
    ----------
    psi : tensor
        Probability of a positive sign for each coefficient.
    explained_var : tensor
        Variance assigned to each coefficient.
    psi_mask : array of bool or None
        When given, entries that are ``False`` are treated as the limit case:
        their mean is set to ``sign * explained_var**0.5`` and their standard
        deviation to ``0``. ``None`` disables the masking entirely.

    Returns
    -------
    tuple
        ``(mu, sigma)`` chosen so that ``mu**2 + sigma**2 == explained_var``.
    """
    root_var = explained_var**0.5
    z = pt.erfinv(2 * psi - 1)
    # shrink factor that keeps mu**2 + sigma**2 equal to the explained variance
    shrink = (1 / (2 * z**2 + 1)) ** 0.5
    sigma = root_var * shrink
    mu = sigma * z * 2**0.5
    if psi_mask is None:
        return mu, sigma
    # masked-out entries are replaced by their closed-form limit values
    return (
        pt.where(psi_mask, mu, pt.sign(z) * root_var),
        pt.where(psi_mask, sigma, 0),
    )
3037

3138

3239
def _R2D2M2CP_beta(
@@ -37,6 +44,7 @@ def _R2D2M2CP_beta(
3744
phi: pt.TensorVariable,
3845
psi: pt.TensorVariable,
3946
*,
47+
psi_mask,
4048
dims: Union[str, Sequence[str]],
4149
centered=False,
4250
):
@@ -59,16 +67,141 @@ def _R2D2M2CP_beta(
5967
"""
6068
tau2 = r2 / (1 - r2)
6169
explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1)
62-
mu_param, std_param = _psivar2musigma(psi, explained_variance)
70+
mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
6371
if not centered:
6472
with pm.Model(name):
65-
raw = pm.Normal("raw", dims=dims)
73+
if psi_mask is not None and psi_mask.any():
74+
# limit case where some probs are not 1 or 0
75+
# setsubtensor is required
76+
r_idx = psi_mask.nonzero()
77+
with pm.Model("raw"):
78+
raw = pm.Normal("masked", shape=len(r_idx[0]))
79+
raw = pt.set_subtensor(pt.zeros_like(mu_param)[r_idx], raw)
80+
raw = pm.Deterministic("raw", raw, dims=dims)
81+
elif psi_mask is not None:
82+
# all variables are deterministic
83+
raw = pt.zeros_like(mu_param)
84+
else:
85+
raw = pm.Normal("raw", dims=dims)
6686
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
6787
else:
68-
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
88+
if psi_mask is not None and psi_mask.any():
89+
# limit case where some probs are not 1 or 0
90+
# setsubtensor is required
91+
r_idx = psi_mask.nonzero()
92+
with pm.Model(name):
93+
mean = (mu_param / input_sigma)[r_idx]
94+
sigma = (std_param / input_sigma)[r_idx]
95+
masked = pm.Normal(
96+
"masked",
97+
mean,
98+
sigma,
99+
shape=len(r_idx[0]),
100+
)
101+
beta = pt.set_subtensor(mean, masked)
102+
beta = pm.Deterministic(name, beta, dims=dims)
103+
elif psi_mask is not None:
104+
# all variables are deterministic
105+
beta = pm.Deterministic(name, (mu_param / input_sigma), dims=dims)
106+
else:
107+
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
69108
return beta
70109

71110

111+
def _broadcast_as_dims(*values, dims):
    """Broadcast every value to the shape given by ``dims`` in the current model.

    Parameters
    ----------
    *values
        Objects accepted by :func:`numpy.broadcast_to`.
    dims : sequence of str
        Dimension names; each one is looked up in the coords of the model on
        the context stack and contributes ``len(coords[dim])`` to the shape.

    Returns
    -------
    numpy.ndarray or tuple of numpy.ndarray
        The single broadcast array when exactly one value is passed,
        otherwise a tuple with one broadcast array per value.
    """
    coords = pm.modelcontext(None).coords
    target_shape = [len(coords[d]) for d in dims]
    broadcasted = tuple(np.broadcast_to(value, target_shape) for value in values)
    # a lone input comes back bare instead of wrapped in a 1-tuple
    return broadcasted[0] if len(values) == 1 else broadcasted
119+
120+
121+
def _psi_masked(positive_probs, positive_probs_std, *, dims):
    """Create ``psi`` with uncertainty, keeping entries at exactly 0 or 1 fixed.

    Parameters
    ----------
    positive_probs : pt.Constant
        Mean probability of a positive sign.
    positive_probs_std : pt.Constant
        Standard deviation around ``positive_probs``; must be ``0`` wherever
        ``positive_probs`` is exactly ``0`` or ``1``.
    dims : sequence of str
        Dimension names both inputs are broadcast to.

    Returns
    -------
    mask : numpy.ndarray of bool or None
        ``True`` where ``psi`` is a random variable. ``None`` when no entry
        equals ``0`` or ``1``, i.e. the whole ``psi`` is one Beta variable.
    psi : tensor
        The sign probabilities.

    Raises
    ------
    TypeError
        If either input is not a constant tensor.
    ValueError
        If an entry with probability ``0`` or ``1`` has a non-zero std.
    """
    both_constant = isinstance(positive_probs, pt.Constant) and isinstance(
        positive_probs_std, pt.Constant
    )
    if not both_constant:
        raise TypeError(
            "Only constant values for positive_probs and positive_probs_std are accepted"
        )
    positive_probs, positive_probs_std = _broadcast_as_dims(
        positive_probs.data, positive_probs_std.data, dims=dims
    )
    degenerate = np.bitwise_or(positive_probs == 1, positive_probs == 0)
    mask = ~degenerate
    if np.bitwise_and(degenerate, positive_probs_std != 0).any():
        raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")

    if not mask.any():
        # every probability sits at 0 or 1: nothing to sample
        return mask, pt.as_tensor(positive_probs)

    if mask.all():
        # no limit entries at all: one ordinary Beta variable, no mask needed
        free_psi = pm.Beta("psi", mu=positive_probs, sigma=positive_probs_std, dims=dims)
        return None, free_psi

    # mixed case: sample only the non-degenerate entries and write them
    # into the constant array with set_subtensor
    live_idx = mask.nonzero()
    with pm.Model("psi"):
        sampled = pm.Beta(
            "masked",
            mu=positive_probs[live_idx],
            sigma=positive_probs_std[live_idx],
            shape=len(live_idx[0]),
        )
    combined = pt.set_subtensor(pt.as_tensor(positive_probs)[live_idx], sampled)
    return mask, pm.Deterministic("psi", combined, dims=dims)
154+
155+
156+
def _psi(positive_probs, positive_probs_std, *, dims):
    """Build the sign-probability term ``psi`` and its limit-case mask.

    Parameters
    ----------
    positive_probs : tensor-like
        Probability of a positive sign; must convert to a constant tensor
        when ``positive_probs_std`` is ``None``.
    positive_probs_std : tensor-like or None
        Optional uncertainty; when given the work is delegated to
        ``_psi_masked``.
    dims : sequence of str
        Dimension names ``psi`` is broadcast to.

    Returns
    -------
    mask : numpy.ndarray of bool or None
        ``True`` where ``psi`` is neither ``0`` nor ``1``; ``None`` when that
        holds for every entry of the fixed ``psi``.
    psi : tensor or numpy.ndarray
        The sign probabilities.

    Raises
    ------
    TypeError
        If ``positive_probs`` is not constant in the fixed (no std) case.
    """
    if positive_probs_std is not None:
        mask, psi = _psi_masked(
            positive_probs=pt.as_tensor(positive_probs),
            positive_probs_std=pt.as_tensor(positive_probs_std),
            dims=dims,
        )
        return mask, psi

    probs = pt.as_tensor(positive_probs)
    if not isinstance(probs, pt.Constant):
        raise TypeError("Only constant values for positive_probs are allowed")
    psi = _broadcast_as_dims(probs.data, dims=dims)
    at_limit = np.bitwise_or(psi == 1, psi == 0)
    mask = np.atleast_1d(~at_limit)
    # a mask that is True everywhere carries no information, so drop it
    return (None if mask.all() else mask), psi
172+
173+
174+
def _phi(
    variables_importance,
    variance_explained,
    importance_concentration,
    *,
    dims,
):
    """Build ``phi``, the split of explained variance across variables.

    Parameters
    ----------
    variables_importance : tensor-like or None
        Dirichlet concentration describing variable importance.
    variance_explained : tensor-like or None
        Point estimate of the variance split; mutually exclusive with
        ``variables_importance``.
    importance_concentration : tensor-like or None
        Multiplier applied to the concentration; when given together with
        ``variance_explained`` (or with neither estimate) ``phi`` becomes a
        Dirichlet variable instead of a fixed value.
    dims : sequence of str
        Dimension names; the last one indexes the variables.

    Returns
    -------
    tensor or numpy.ndarray
        A Dirichlet random variable named ``"phi"``, or the fixed split.

    Raises
    ------
    TypeError
        If both estimates are passed, or an estimate is passed while the last
        dimension has fewer than two entries.
    """
    *broadcast_dims, dim = dims
    phi_dims = broadcast_dims + [dim]
    model = pm.modelcontext(None)

    if variables_importance is not None:
        if variance_explained is not None:
            raise TypeError("Can't use variable importance with variance explained")
        if len(model.coords[dim]) <= 1:
            raise TypeError("Can't use variable importance with less than two variables")
        alpha = pt.as_tensor(variables_importance)
        if importance_concentration is not None:
            alpha = alpha * importance_concentration
        return pm.Dirichlet("phi", alpha, dims=phi_dims)

    if variance_explained is None:
        # no estimate supplied: spread the variance uniformly over variables
        phi = _broadcast_as_dims(1 / len(model.coords[dim]), dims=dims)
    else:
        if len(model.coords[dim]) <= 1:
            raise TypeError("Can't use variance explained with less than two variables")
        phi = pt.as_tensor(variance_explained)

    if importance_concentration is None:
        return phi
    return pm.Dirichlet("phi", importance_concentration * phi, dims=phi_dims)
203+
204+
72205
def R2D2M2CP(
73206
name,
74207
output_sigma,
@@ -78,6 +211,7 @@ def R2D2M2CP(
78211
r2,
79212
variables_importance=None,
80213
variance_explained=None,
214+
importance_concentration=None,
81215
r2_std=None,
82216
positive_probs=0.5,
83217
positive_probs_std=None,
@@ -102,6 +236,8 @@ def R2D2M2CP(
102236
variance_explained : tensor, optional
103237
Alternative estimate for variables importance which is point estimate of
104238
variance explained, should sum up to one, by default None
239+
importance_concentration : tensor, optional
240+
Confidence around variance explained or variable importance estimate
105241
r2_std : tensor, optional
106242
Optional uncertainty over :math:`R^2`, by default None
107243
positive_probs : tensor, optional
@@ -125,8 +261,8 @@ def R2D2M2CP(
125261
-----
126262
The R2D2M2CP prior is a modification of R2D2M2 prior.
127263
128-
- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
129-
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
264+
- ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
265+
- R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
130266
131267
Examples
132268
--------
@@ -259,31 +395,20 @@ def R2D2M2CP(
259395
input_sigma = pt.as_tensor(input_sigma)
260396
output_sigma = pt.as_tensor(output_sigma)
261397
with pm.Model(name) as model:
262-
if variables_importance is not None:
263-
if variance_explained is not None:
264-
raise TypeError("Can't use variable importance with variance explained")
265-
if len(model.coords[dim]) <= 1:
266-
raise TypeError("Can't use variable importance with less than two variables")
267-
phi = pm.Dirichlet(
268-
"phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]
269-
)
270-
elif variance_explained is not None:
271-
if len(model.coords[dim]) <= 1:
272-
raise TypeError("Can't use variance explained with less than two variables")
273-
phi = pt.as_tensor(variance_explained)
274-
else:
275-
phi = 1 / len(model.coords[dim])
398+
if not all(isinstance(model.dim_lengths[d], pt.TensorConstant) for d in dims):
399+
raise ValueError(f"{dims!r} should be constant length immutable dims")
276400
if r2_std is not None:
277401
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
278-
if positive_probs_std is not None:
279-
psi = pm.Beta(
280-
"psi",
281-
mu=pt.as_tensor(positive_probs),
282-
sigma=pt.as_tensor(positive_probs_std),
283-
dims=broadcast_dims + [dim],
284-
)
285-
else:
286-
psi = pt.as_tensor(positive_probs)
402+
phi = _phi(
403+
variables_importance=variables_importance,
404+
variance_explained=variance_explained,
405+
importance_concentration=importance_concentration,
406+
dims=dims,
407+
)
408+
mask, psi = _psi(
409+
positive_probs=positive_probs, positive_probs_std=positive_probs_std, dims=dims
410+
)
411+
287412
beta = _R2D2M2CP_beta(
288413
name,
289414
output_sigma,
@@ -293,6 +418,7 @@ def R2D2M2CP(
293418
psi,
294419
dims=broadcast_dims + [dim],
295420
centered=centered,
421+
psi_mask=mask,
296422
)
297423
resid_sigma = (1 - r2) ** 0.5 * output_sigma
298424
return resid_sigma, beta

0 commit comments

Comments
 (0)