@@ -70,17 +70,19 @@ def _R2D2M2CP_beta(
70
70
mu_param , std_param = _psivar2musigma (psi , explained_variance , psi_mask = psi_mask )
71
71
if not centered :
72
72
with pm .Model (name ):
73
- if psi_mask is not None :
73
+ if psi_mask is not None and psi_mask . any () :
74
74
r_idx = psi_mask .nonzero ()
75
75
with pm .Model ("raw" ):
76
76
raw = pm .Normal ("masked" , shape = len (r_idx [0 ]))
77
77
raw = pt .set_subtensor (pt .zeros_like (mu_param )[r_idx ], raw )
78
78
raw = pm .Deterministic ("raw" , raw , dims = dims )
79
+ elif psi_mask is not None :
80
+ raw = pt .zeros_like (mu_param )
79
81
else :
80
82
raw = pm .Normal ("raw" , dims = dims )
81
83
beta = pm .Deterministic (name , (raw * std_param + mu_param ) / input_sigma , dims = dims )
82
84
else :
83
- if psi_mask is not None :
85
+ if psi_mask is not None and psi_mask . any () :
84
86
r_idx = psi_mask .nonzero ()
85
87
with pm .Model (name ):
86
88
mean = (mu_param / input_sigma )[r_idx ]
@@ -93,6 +95,8 @@ def _R2D2M2CP_beta(
93
95
)
94
96
beta = pt .set_subtensor (mean , masked )
95
97
beta = pm .Deterministic (name , beta , dims = dims )
98
+ elif psi_mask is not None :
99
+ beta = mean
96
100
else :
97
101
beta = pm .Normal (name , mu_param / input_sigma , std_param / input_sigma , dims = dims )
98
102
return beta
@@ -121,7 +125,9 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
121
125
mask = ~ np .bitwise_or (positive_probs == 1 , positive_probs == 0 )
122
126
if np .bitwise_and (~ mask , positive_probs_std != 0 ).any ():
123
127
raise ValueError ("Can't have both positive_probs == '1 or 0' and positive_probs_std != 0" )
124
- if (~ mask ).any ():
128
+ if (~ mask ).any () and mask .any ():
129
+ # limit case where some probs are not 1 or 0
130
+ # setsubtensor is required
125
131
r_idx = mask .nonzero ()
126
132
with pm .Model ("psi" ):
127
133
psi = pm .Beta (
@@ -132,6 +138,9 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
132
138
)
133
139
psi = pt .set_subtensor (pt .as_tensor (positive_probs )[r_idx ], psi )
134
140
psi = pm .Deterministic ("psi" , psi , dims = dims )
141
+ elif (~ mask ).all ():
142
+ # limit case where all the probs are limit case
143
+ psi = pt .as_tensor (positive_probs )
135
144
else :
136
145
psi = pm .Beta ("psi" , mu = positive_probs , sigma = positive_probs_std , dims = dims )
137
146
mask = None
0 commit comments