22
22
__all__ = ["R2D2M2CP" ]
23
23
24
24
25
- def _psivar2musigma (psi : pt .TensorVariable , explained_var : pt .TensorVariable ):
25
+ def _psivar2musigma (psi : pt .TensorVariable , explained_var : pt .TensorVariable , psi_mask ):
26
26
pi = pt .erfinv (2 * psi - 1 )
27
27
f = (1 / (2 * pi ** 2 + 1 )) ** 0.5
28
28
sigma = explained_var ** 0.5 * f
29
29
mu = sigma * pi * 2 ** 0.5
30
- return mu , sigma
30
+ if psi_mask is not None :
31
+ return (
32
+ pt .where (psi_mask , mu , explained_var ** 0.5 ),
33
+ pt .where (psi_mask , sigma , 0 ),
34
+ )
35
+ else :
36
+ return mu , sigma
31
37
32
38
33
39
def _R2D2M2CP_beta (
@@ -38,6 +44,7 @@ def _R2D2M2CP_beta(
38
44
phi : pt .TensorVariable ,
39
45
psi : pt .TensorVariable ,
40
46
* ,
47
+ psi_mask ,
41
48
dims : Union [str , Sequence [str ]],
42
49
centered = False ,
43
50
):
@@ -60,7 +67,7 @@ def _R2D2M2CP_beta(
60
67
"""
61
68
tau2 = r2 / (1 - r2 )
62
69
explained_variance = phi * pt .expand_dims (tau2 * output_sigma ** 2 , - 1 )
63
- mu_param , std_param = _psivar2musigma (psi , explained_variance )
70
+ mu_param , std_param = _psivar2musigma (psi , explained_variance , psi_mask = psi_mask )
64
71
if not centered :
65
72
with pm .Model (name ):
66
73
raw = pm .Normal ("raw" , dims = dims )
@@ -102,9 +109,53 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
102
109
psi = pm .Deterministic ("psi" , psi , dims = dims )
103
110
else :
104
111
psi = pm .Beta ("psi" , mu = positive_probs , sigma = positive_probs_std , dims = dims )
112
+ mask = None
105
113
return mask , psi
106
114
107
115
116
+ def _psi (positive_probs , positive_probs_std , * , dims ):
117
+ if positive_probs_std is not None :
118
+ mask , psi = _psi_masked (
119
+ "psi" ,
120
+ mu = pt .as_tensor (positive_probs ),
121
+ sigma = pt .as_tensor (positive_probs_std ),
122
+ dims = dims ,
123
+ )
124
+ else :
125
+ positive_probs = pt .as_tensor (positive_probs )
126
+ if not isinstance (positive_probs , pt .Constant ):
127
+ raise TypeError ("Only constant values for positive_probs are allowed" )
128
+ psi = _broadcast_as_dims (positive_probs .data , dims = dims )
129
+ mask = psi != 1
130
+ if (mask ).all ():
131
+ mask = None
132
+ return mask , psi
133
+
134
+
135
+ def _phi (
136
+ variables_importance ,
137
+ variance_explained ,
138
+ variance_explained_concentration ,
139
+ * ,
140
+ dims ,
141
+ ):
142
+ * broadcast_dims , dim = dims
143
+ model = pm .modelcontext (None )
144
+ if variables_importance is not None :
145
+ if variance_explained is not None :
146
+ raise TypeError ("Can't use variable importance with variance explained" )
147
+ if len (model .coords [dim ]) <= 1 :
148
+ raise TypeError ("Can't use variable importance with less than two variables" )
149
+ phi = pm .Dirichlet ("phi" , pt .as_tensor (variables_importance ), dims = broadcast_dims + [dim ])
150
+ elif variance_explained is not None :
151
+ if len (model .coords [dim ]) <= 1 :
152
+ raise TypeError ("Can't use variance explained with less than two variables" )
153
+ phi = pt .as_tensor (variance_explained )
154
+ else :
155
+ phi = pt .as_tensor (1 / len (model .coords [dim ]))
156
+ return phi
157
+
158
+
108
159
def R2D2M2CP (
109
160
name ,
110
161
output_sigma ,
@@ -114,6 +165,7 @@ def R2D2M2CP(
114
165
r2 ,
115
166
variables_importance = None ,
116
167
variance_explained = None ,
168
+ variance_explained_concentration = None ,
117
169
r2_std = None ,
118
170
positive_probs = 0.5 ,
119
171
positive_probs_std = None ,
@@ -138,6 +190,8 @@ def R2D2M2CP(
138
190
variance_explained : tensor, optional
139
191
Alternative estimate for variables importance which is point estimate of
140
192
variance explained, should sum up to one, by default None
193
+ variance_explained_concentration : tensor, optional
194
+ Confidence around variance explained estimate
141
195
r2_std : tensor, optional
142
196
Optional uncertainty over :math:`R^2`, by default None
143
197
positive_probs : tensor, optional
@@ -161,8 +215,8 @@ def R2D2M2CP(
161
215
-----
162
216
The R2D2M2CP prior is a modification of R2D2M2 prior.
163
217
164
- - ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
165
- - R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
218
+ - ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
219
+ - R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
166
220
167
221
Examples
168
222
--------
@@ -294,33 +348,17 @@ def R2D2M2CP(
294
348
* broadcast_dims , dim = dims
295
349
input_sigma = pt .as_tensor (input_sigma )
296
350
output_sigma = pt .as_tensor (output_sigma )
297
- positive_probs = pt .as_tensor (positive_probs )
298
- with pm .Model (name ) as model :
299
- if variables_importance is not None :
300
- if variance_explained is not None :
301
- raise TypeError ("Can't use variable importance with variance explained" )
302
- if len (model .coords [dim ]) <= 1 :
303
- raise TypeError ("Can't use variable importance with less than two variables" )
304
- phi = pm .Dirichlet (
305
- "phi" , pt .as_tensor (variables_importance ), dims = broadcast_dims + [dim ]
306
- )
307
- elif variance_explained is not None :
308
- if len (model .coords [dim ]) <= 1 :
309
- raise TypeError ("Can't use variance explained with less than two variables" )
310
- phi = pt .as_tensor (variance_explained )
311
- else :
312
- phi = pt .as_tensor (1 / len (model .coords [dim ]))
351
+ with pm .Model (name ):
313
352
if r2_std is not None :
314
353
r2 = pm .Beta ("r2" , mu = r2 , sigma = r2_std , dims = broadcast_dims )
315
- if positive_probs_std is not None :
316
- psi = pm .Beta (
317
- "psi" ,
318
- mu = pt .as_tensor (positive_probs ),
319
- sigma = pt .as_tensor (positive_probs_std ),
320
- dims = broadcast_dims + [dim ],
321
- )
322
- else :
323
- psi = pt .as_tensor (positive_probs )
354
+ phi = _phi (
355
+ variables_importance = variables_importance ,
356
+ variance_explained = variance_explained ,
357
+ )
358
+ mask , psi = _psi (
359
+ positive_probs = positive_probs , positive_probs_std = positive_probs_std , dims = dims
360
+ )
361
+
324
362
beta = _R2D2M2CP_beta (
325
363
name ,
326
364
output_sigma ,
@@ -330,6 +368,7 @@ def R2D2M2CP(
330
368
psi ,
331
369
dims = broadcast_dims + [dim ],
332
370
centered = centered ,
371
+ psi_mask = mask ,
333
372
)
334
373
resid_sigma = (1 - r2 ) ** 0.5 * output_sigma
335
374
return resid_sigma , beta
0 commit comments