15
15
16
16
from typing import Sequence , Union
17
17
18
+ import numpy as np
18
19
import pymc as pm
19
20
import pytensor .tensor as pt
20
21
21
22
__all__ = ["R2D2M2CP" ]
22
23
23
24
24
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
    """Map sign probability ``psi`` and explained variance to a (mu, sigma) pair.

    Parameters
    ----------
    psi : pt.TensorVariable
        Probability-like quantity in (0, 1) controlling the sign/location
        of the coefficient.
    explained_var : pt.TensorVariable
        Variance attributed to the coefficient.
    psi_mask
        Boolean mask (or None). Where the mask is False the entry is a
        limit case (psi exactly 0 or 1) and collapses to a deterministic
        signed magnitude with zero spread.

    Returns
    -------
    (mu, sigma)
        Location and scale tensors.

    Notes
    -----
    NOTE(review): the erfinv-based formulas appear chosen so the resulting
    normal has the requested variance and sign probability — confirm
    against the R2D2M2 derivation.
    """
    pi = pt.erfinv(2 * psi - 1)
    f = (1 / (2 * pi**2 + 1)) ** 0.5
    sigma = explained_var**0.5 * f
    mu = sigma * pi * 2**0.5
    # No limit cases: return the smooth parameterization directly.
    if psi_mask is None:
        return mu, sigma
    # Limit-case entries (mask False) become deterministic: full magnitude
    # with the sign implied by pi, and zero standard deviation.
    mu_limited = pt.where(psi_mask, mu, pt.sign(pi) * explained_var**0.5)
    sigma_limited = pt.where(psi_mask, sigma, 0)
    return mu_limited, sigma_limited
30
37
31
38
32
39
def _R2D2M2CP_beta (
@@ -37,6 +44,7 @@ def _R2D2M2CP_beta(
37
44
phi : pt .TensorVariable ,
38
45
psi : pt .TensorVariable ,
39
46
* ,
47
+ psi_mask ,
40
48
dims : Union [str , Sequence [str ]],
41
49
centered = False ,
42
50
):
@@ -59,16 +67,141 @@ def _R2D2M2CP_beta(
59
67
"""
60
68
tau2 = r2 / (1 - r2 )
61
69
explained_variance = phi * pt .expand_dims (tau2 * output_sigma ** 2 , - 1 )
62
- mu_param , std_param = _psivar2musigma (psi , explained_variance )
70
+ mu_param , std_param = _psivar2musigma (psi , explained_variance , psi_mask = psi_mask )
63
71
if not centered :
64
72
with pm .Model (name ):
65
- raw = pm .Normal ("raw" , dims = dims )
73
+ if psi_mask is not None and psi_mask .any ():
74
+ # limit case where some probs are not 1 or 0
75
+ # setsubtensor is required
76
+ r_idx = psi_mask .nonzero ()
77
+ with pm .Model ("raw" ):
78
+ raw = pm .Normal ("masked" , shape = len (r_idx [0 ]))
79
+ raw = pt .set_subtensor (pt .zeros_like (mu_param )[r_idx ], raw )
80
+ raw = pm .Deterministic ("raw" , raw , dims = dims )
81
+ elif psi_mask is not None :
82
+ # all variables are deterministic
83
+ raw = pt .zeros_like (mu_param )
84
+ else :
85
+ raw = pm .Normal ("raw" , dims = dims )
66
86
beta = pm .Deterministic (name , (raw * std_param + mu_param ) / input_sigma , dims = dims )
67
87
else :
68
- beta = pm .Normal (name , mu_param / input_sigma , std_param / input_sigma , dims = dims )
88
+ if psi_mask is not None and psi_mask .any ():
89
+ # limit case where some probs are not 1 or 0
90
+ # setsubtensor is required
91
+ r_idx = psi_mask .nonzero ()
92
+ with pm .Model (name ):
93
+ mean = (mu_param / input_sigma )[r_idx ]
94
+ sigma = (std_param / input_sigma )[r_idx ]
95
+ masked = pm .Normal (
96
+ "masked" ,
97
+ mean ,
98
+ sigma ,
99
+ shape = len (r_idx [0 ]),
100
+ )
101
+ beta = pt .set_subtensor (mean , masked )
102
+ beta = pm .Deterministic (name , beta , dims = dims )
103
+ elif psi_mask is not None :
104
+ # all variables are deterministic
105
+ beta = pm .Deterministic (name , (mu_param / input_sigma ), dims = dims )
106
+ else :
107
+ beta = pm .Normal (name , mu_param / input_sigma , std_param / input_sigma , dims = dims )
69
108
return beta
70
109
71
110
111
def _broadcast_as_dims(*values, dims):
    """Broadcast every value to the shape implied by the model coords for ``dims``.

    Looks up the active model context and uses the length of each named
    coordinate as the target shape. A single input is returned unwrapped
    rather than as a 1-tuple.
    """
    model = pm.modelcontext(None)
    target_shape = tuple(len(model.coords[d]) for d in dims)
    broadcasted = [np.broadcast_to(v, target_shape) for v in values]
    # strip output: mirror the arity of the input
    return broadcasted[0] if len(broadcasted) == 1 else tuple(broadcasted)
119
+
120
+
121
def _psi_masked(positive_probs, positive_probs_std, *, dims):
    """Construct ``psi`` when an uncertainty (std) over positive_probs is given.

    Entries where the probability is exactly 0 or 1 are limit cases and stay
    deterministic; all other entries receive a Beta prior.

    Returns
    -------
    (mask, psi)
        ``mask`` is a boolean array marking the stochastic (non-limit)
        entries, or None when every entry is stochastic.

    Raises
    ------
    TypeError
        If either argument is not a constant tensor.
    ValueError
        If a limit-case entry has a nonzero std.
    """
    if not (
        isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
    ):
        raise TypeError(
            "Only constant values for positive_probs and positive_probs_std are accepted"
        )
    probs, probs_std = _broadcast_as_dims(positive_probs.data, positive_probs_std.data, dims=dims)
    is_limit = np.bitwise_or(probs == 1, probs == 0)
    mask = ~is_limit
    # A degenerate probability cannot carry uncertainty.
    if np.bitwise_and(is_limit, probs_std != 0).any():
        raise ValueError("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0")
    if is_limit.any() and mask.any():
        # limit case where some probs are not 1 or 0
        # setsubtensor is required
        r_idx = mask.nonzero()
        # Submodel only namespaces the raw Beta RV as "psi::masked".
        with pm.Model("psi"):
            masked = pm.Beta(
                "masked",
                mu=probs[r_idx],
                sigma=probs_std[r_idx],
                shape=len(r_idx[0]),
            )
        # Splice the stochastic entries back into the constant tensor.
        psi = pt.set_subtensor(pt.as_tensor(probs)[r_idx], masked)
        psi = pm.Deterministic("psi", psi, dims=dims)
    elif is_limit.all():
        # limit case where all the probs are limit case
        psi = pt.as_tensor(probs)
    else:
        # No limit cases at all: plain Beta prior, no mask needed.
        psi = pm.Beta("psi", mu=probs, sigma=probs_std, dims=dims)
        mask = None
    return mask, psi
154
+
155
+
156
def _psi(positive_probs, positive_probs_std, *, dims):
    """Build ``psi``, dispatching on whether a std over the probabilities was given.

    With a std, delegates to :func:`_psi_masked`. Without one, the constant
    probabilities are broadcast to the dims and a mask of non-limit entries
    (not exactly 0 or 1) is computed; the mask collapses to None when every
    entry is a regular probability.

    Returns
    -------
    (mask, psi)
    """
    if positive_probs_std is not None:
        return _psi_masked(
            positive_probs=pt.as_tensor(positive_probs),
            positive_probs_std=pt.as_tensor(positive_probs_std),
            dims=dims,
        )
    probs = pt.as_tensor(positive_probs)
    if not isinstance(probs, pt.Constant):
        raise TypeError("Only constant values for positive_probs are allowed")
    psi = _broadcast_as_dims(probs.data, dims=dims)
    mask = np.atleast_1d(~np.bitwise_or(psi == 1, psi == 0))
    if mask.all():
        # No limit cases present: callers treat None as "no masking needed".
        mask = None
    return mask, psi
172
+
173
+
174
def _phi(
    variables_importance,
    variance_explained,
    importance_concentration,
    *,
    dims,
):
    """Build ``phi``, the per-variable share of explained variance.

    Exactly one of three sources is used:

    * ``variables_importance`` — Dirichlet concentration parameters
      (optionally scaled by ``importance_concentration``);
    * ``variance_explained`` — a fixed point estimate;
    * neither — a uniform split ``1 / n_vars``.

    Raises
    ------
    TypeError
        If both importance and variance-explained are given, or if the
        last dim has fewer than two variables.
    """
    *broadcast_dims, dim = dims
    model = pm.modelcontext(None)
    n_vars = len(model.coords[dim])
    if variables_importance is not None:
        if variance_explained is not None:
            raise TypeError("Can't use variable importance with variance explained")
        if n_vars <= 1:
            raise TypeError("Can't use variable importance with less than two variables")
        alpha = pt.as_tensor(variables_importance)
        if importance_concentration is not None:
            # Concentration sharpens/flattens the Dirichlet around the importances.
            alpha = alpha * importance_concentration
        return pm.Dirichlet("phi", alpha, dims=broadcast_dims + [dim])
    if variance_explained is not None:
        if n_vars <= 1:
            raise TypeError("Can't use variance explained with less than two variables")
        phi = pt.as_tensor(variance_explained)
    else:
        # Uniform importance across the variables of the last dim.
        phi = 1 / n_vars
    phi = _broadcast_as_dims(phi, dims=dims)
    if importance_concentration is not None:
        # Point estimate plus concentration becomes a Dirichlet prior around it.
        return pm.Dirichlet("phi", importance_concentration * phi, dims=broadcast_dims + [dim])
    return phi
203
+
204
+
72
205
def R2D2M2CP (
73
206
name ,
74
207
output_sigma ,
@@ -78,6 +211,7 @@ def R2D2M2CP(
78
211
r2 ,
79
212
variables_importance = None ,
80
213
variance_explained = None ,
214
+ importance_concentration = None ,
81
215
r2_std = None ,
82
216
positive_probs = 0.5 ,
83
217
positive_probs_std = None ,
@@ -102,6 +236,8 @@ def R2D2M2CP(
102
236
variance_explained : tensor, optional
103
237
Alternative estimate for variables importance which is point estimate of
104
238
variance explained, should sum up to one, by default None
239
+ importance_concentration : tensor, optional
240
+ Confidence around variance explained or variable importance estimate
105
241
r2_std : tensor, optional
106
242
Optional uncertainty over :math:`R^2`, by default None
107
243
positive_probs : tensor, optional
@@ -125,8 +261,8 @@ def R2D2M2CP(
125
261
-----
126
262
The R2D2M2CP prior is a modification of R2D2M2 prior.
127
263
128
- - ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
129
- - R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
264
+ - ``(R2D2M2)`` CP is taken from https://arxiv.org/abs/2208.07132
265
+ - R2D2M2 ``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
130
266
131
267
Examples
132
268
--------
@@ -259,31 +395,20 @@ def R2D2M2CP(
259
395
input_sigma = pt .as_tensor (input_sigma )
260
396
output_sigma = pt .as_tensor (output_sigma )
261
397
with pm .Model (name ) as model :
262
- if variables_importance is not None :
263
- if variance_explained is not None :
264
- raise TypeError ("Can't use variable importance with variance explained" )
265
- if len (model .coords [dim ]) <= 1 :
266
- raise TypeError ("Can't use variable importance with less than two variables" )
267
- phi = pm .Dirichlet (
268
- "phi" , pt .as_tensor (variables_importance ), dims = broadcast_dims + [dim ]
269
- )
270
- elif variance_explained is not None :
271
- if len (model .coords [dim ]) <= 1 :
272
- raise TypeError ("Can't use variance explained with less than two variables" )
273
- phi = pt .as_tensor (variance_explained )
274
- else :
275
- phi = 1 / len (model .coords [dim ])
398
+ if not all (isinstance (model .dim_lengths [d ], pt .TensorConstant ) for d in dims ):
399
+ raise ValueError (f"{ dims !r} should be constant length immutable dims" )
276
400
if r2_std is not None :
277
401
r2 = pm .Beta ("r2" , mu = r2 , sigma = r2_std , dims = broadcast_dims )
278
- if positive_probs_std is not None :
279
- psi = pm .Beta (
280
- "psi" ,
281
- mu = pt .as_tensor (positive_probs ),
282
- sigma = pt .as_tensor (positive_probs_std ),
283
- dims = broadcast_dims + [dim ],
284
- )
285
- else :
286
- psi = pt .as_tensor (positive_probs )
402
+ phi = _phi (
403
+ variables_importance = variables_importance ,
404
+ variance_explained = variance_explained ,
405
+ importance_concentration = importance_concentration ,
406
+ dims = dims ,
407
+ )
408
+ mask , psi = _psi (
409
+ positive_probs = positive_probs , positive_probs_std = positive_probs_std , dims = dims
410
+ )
411
+
287
412
beta = _R2D2M2CP_beta (
288
413
name ,
289
414
output_sigma ,
@@ -293,6 +418,7 @@ def R2D2M2CP(
293
418
psi ,
294
419
dims = broadcast_dims + [dim ],
295
420
centered = centered ,
421
+ psi_mask = mask ,
296
422
)
297
423
resid_sigma = (1 - r2 ) ** 0.5 * output_sigma
298
424
return resid_sigma , beta
0 commit comments