13
13
# limitations under the License.
14
14
15
15
16
- from typing import Sequence , Union
16
+ from collections import namedtuple
17
+ from typing import Sequence , Tuple , Union
17
18
18
19
import numpy as np
19
20
import pymc as pm
22
23
__all__ = ["R2D2M2CP" ]
23
24
24
25
25
- def _psivar2musigma (psi : pt .TensorVariable , explained_var : pt .TensorVariable , psi_mask ):
26
+ def _psivar2musigma (
27
+ psi : pt .TensorVariable ,
28
+ explained_var : pt .TensorVariable ,
29
+ psi_mask : Union [pt .TensorLike , None ],
30
+ ) -> Tuple [pt .TensorVariable , pt .TensorVariable ]:
31
+ sign = pt .sign (psi - 0.5 )
32
+ if psi_mask is not None :
33
+ # any computation might be ignored for ~psi_mask
34
+ # sign and explained_var are used
35
+ psi = pt .where (psi_mask , psi , 0.5 )
26
36
pi = pt .erfinv (2 * psi - 1 )
27
37
f = (1 / (2 * pi ** 2 + 1 )) ** 0.5
28
38
sigma = explained_var ** 0.5 * f
29
39
mu = sigma * pi * 2 ** 0.5
30
40
if psi_mask is not None :
31
41
return (
32
- pt .where (psi_mask , mu , pt . sign ( pi ) * explained_var ** 0.5 ),
42
+ pt .where (psi_mask , mu , sign * explained_var ** 0.5 ),
33
43
pt .where (psi_mask , sigma , 0 ),
34
44
)
35
45
else :
@@ -47,7 +57,7 @@ def _R2D2M2CP_beta(
47
57
psi_mask ,
48
58
dims : Union [str , Sequence [str ]],
49
59
centered = False ,
50
- ):
60
+ ) -> pt . TensorVariable :
51
61
"""R2D2M2CP beta prior.
52
62
53
63
Parameters
@@ -65,7 +75,7 @@ def _R2D2M2CP_beta(
65
75
psi: tensor
66
76
probability of a coefficients to be positive
67
77
"""
68
- explained_variance = phi * pt .expand_dims (r2 * output_sigma ** 2 , - 1 )
78
+ explained_variance = phi * pt .expand_dims (r2 * output_sigma ** 2 , ( - 1 ,) )
69
79
mu_param , std_param = _psivar2musigma (psi , explained_variance , psi_mask = psi_mask )
70
80
if not centered :
71
81
with pm .Model (name ):
@@ -107,7 +117,10 @@ def _R2D2M2CP_beta(
107
117
return beta
108
118
109
119
110
- def _broadcast_as_dims (* values , dims ):
120
+ def _broadcast_as_dims (
121
+ * values : np .ndarray ,
122
+ dims : Sequence [str ],
123
+ ) -> Union [Tuple [np .ndarray , ...], np .ndarray ]:
111
124
model = pm .modelcontext (None )
112
125
shape = [len (model .coords [d ]) for d in dims ]
113
126
ret = tuple (np .broadcast_to (v , shape ) for v in values )
@@ -117,7 +130,12 @@ def _broadcast_as_dims(*values, dims):
117
130
return ret
118
131
119
132
120
- def _psi_masked (positive_probs , positive_probs_std , * , dims ):
133
+ def _psi_masked (
134
+ positive_probs : pt .TensorLike ,
135
+ positive_probs_std : pt .TensorLike ,
136
+ * ,
137
+ dims : Sequence [str ],
138
+ ) -> Tuple [Union [pt .TensorLike , None ], pt .TensorVariable ]:
121
139
if not (
122
140
isinstance (positive_probs , pt .Constant ) and isinstance (positive_probs_std , pt .Constant )
123
141
):
@@ -152,7 +170,12 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
152
170
return mask , psi
153
171
154
172
155
- def _psi (positive_probs , positive_probs_std , * , dims ):
173
+ def _psi (
174
+ positive_probs : pt .TensorLike ,
175
+ positive_probs_std : Union [pt .TensorLike , None ],
176
+ * ,
177
+ dims : Sequence [str ],
178
+ ) -> Tuple [Union [pt .TensorLike , None ], pt .TensorVariable ]:
156
179
if positive_probs_std is not None :
157
180
mask , psi = _psi_masked (
158
181
positive_probs = pt .as_tensor (positive_probs ),
@@ -171,12 +194,12 @@ def _psi(positive_probs, positive_probs_std, *, dims):
171
194
172
195
173
196
def _phi (
174
- variables_importance ,
175
- variance_explained ,
176
- importance_concentration ,
197
+ variables_importance : Union [ pt . TensorLike , None ] ,
198
+ variance_explained : Union [ pt . TensorLike , None ] ,
199
+ importance_concentration : Union [ pt . TensorLike , None ] ,
177
200
* ,
178
- dims ,
179
- ):
201
+ dims : Sequence [ str ] ,
202
+ ) -> pt . TensorVariable :
180
203
* broadcast_dims , dim = dims
181
204
model = pm .modelcontext (None )
182
205
if variables_importance is not None :
@@ -200,47 +223,50 @@ def _phi(
200
223
return phi
201
224
202
225
226
+ R2D2M2CPOut = namedtuple ("R2D2M2CPOut" , ["eps" , "beta" ])
227
+
228
+
203
229
def R2D2M2CP (
204
- name ,
205
- output_sigma ,
206
- input_sigma ,
230
+ name : str ,
231
+ output_sigma : pt . TensorLike ,
232
+ input_sigma : pt . TensorLike ,
207
233
* ,
208
- dims ,
209
- r2 ,
210
- variables_importance = None ,
211
- variance_explained = None ,
212
- importance_concentration = None ,
213
- r2_std = None ,
214
- positive_probs = 0.5 ,
215
- positive_probs_std = None ,
216
- centered = False ,
217
- ):
234
+ dims : Sequence [ str ] ,
235
+ r2 : pt . TensorLike ,
236
+ variables_importance : Union [ pt . TensorLike , None ] = None ,
237
+ variance_explained : Union [ pt . TensorLike , None ] = None ,
238
+ importance_concentration : Union [ pt . TensorLike , None ] = None ,
239
+ r2_std : Union [ pt . TensorLike , None ] = None ,
240
+ positive_probs : Union [ pt . TensorLike , None ] = 0.5 ,
241
+ positive_probs_std : Union [ pt . TensorLike , None ] = None ,
242
+ centered : bool = False ,
243
+ ) -> R2D2M2CPOut :
218
244
"""R2D2M2CP Prior.
219
245
220
246
Parameters
221
247
----------
222
248
name : str
223
249
Name for the distribution
224
- output_sigma : tensor
250
+ output_sigma : Tensor
225
251
Output standard deviation
226
- input_sigma : tensor
252
+ input_sigma : Tensor
227
253
Input standard deviation
228
254
dims : Union[str, Sequence[str]]
229
255
Dims for the distribution
230
- r2 : tensor
256
+ r2 : Tensor
231
257
:math:`R^2` estimate
232
- variables_importance : tensor , optional
258
+ variables_importance : Tensor , optional
233
259
Optional estimate for variables importance, positive, by default None
234
- variance_explained : tensor , optional
260
+ variance_explained : Tensor , optional
235
261
Alternative estimate for variables importance which is point estimate of
236
262
variance explained, should sum up to one, by default None
237
- importance_concentration : tensor , optional
263
+ importance_concentration : Tensor , optional
238
264
Confidence around variance explained or variable importance estimate
239
- r2_std : tensor , optional
265
+ r2_std : Tensor , optional
240
266
Optional uncertainty over :math:`R^2`, by default None
241
- positive_probs : tensor , optional
267
+ positive_probs : Tensor , optional
242
268
Optional probability of variables contribution to be positive, by default 0.5
243
- positive_probs_std : tensor , optional
269
+ positive_probs_std : Tensor , optional
244
270
Optional uncertainty over effect direction probability, by default None
245
271
centered : bool, optional
246
272
Centered or Non-Centered parametrization of the distribution, by default Non-Centered. Advised to check both
@@ -419,4 +445,4 @@ def R2D2M2CP(
419
445
psi_mask = mask ,
420
446
)
421
447
resid_sigma = (1 - r2 ) ** 0.5 * output_sigma
422
- return resid_sigma , beta
448
+ return R2D2M2CPOut ( resid_sigma , beta )
0 commit comments