Skip to content

Commit f1ece1c

Browse files
ferrinericardoV94
andauthored
Fix testing and remove warnings for r2d2m2cp (#284)
* add more typehints * Update pymc_experimental/tests/distributions/test_multivariate.py Co-authored-by: Ricardo Vieira <[email protected]> * Update pymc_experimental/tests/distributions/test_multivariate.py Co-authored-by: Ricardo Vieira <[email protected]> * change a letter * change a letter * wrap results into a named tuple --------- Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 00d7a2b commit f1ece1c

File tree

2 files changed

+123
-45
lines changed

2 files changed

+123
-45
lines changed

pymc_experimental/distributions/multivariate/r2d2m2cp.py

+62-36
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@
1313
# limitations under the License.
1414

1515

16-
from typing import Sequence, Union
16+
from collections import namedtuple
17+
from typing import Sequence, Tuple, Union
1718

1819
import numpy as np
1920
import pymc as pm
@@ -22,14 +23,23 @@
2223
__all__ = ["R2D2M2CP"]
2324

2425

25-
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable, psi_mask):
26+
def _psivar2musigma(
27+
psi: pt.TensorVariable,
28+
explained_var: pt.TensorVariable,
29+
psi_mask: Union[pt.TensorLike, None],
30+
) -> Tuple[pt.TensorVariable, pt.TensorVariable]:
31+
sign = pt.sign(psi - 0.5)
32+
if psi_mask is not None:
33+
# any computation might be ignored for ~psi_mask
34+
# sign and explained_var are used
35+
psi = pt.where(psi_mask, psi, 0.5)
2636
pi = pt.erfinv(2 * psi - 1)
2737
f = (1 / (2 * pi**2 + 1)) ** 0.5
2838
sigma = explained_var**0.5 * f
2939
mu = sigma * pi * 2**0.5
3040
if psi_mask is not None:
3141
return (
32-
pt.where(psi_mask, mu, pt.sign(pi) * explained_var**0.5),
42+
pt.where(psi_mask, mu, sign * explained_var**0.5),
3343
pt.where(psi_mask, sigma, 0),
3444
)
3545
else:
@@ -47,7 +57,7 @@ def _R2D2M2CP_beta(
4757
psi_mask,
4858
dims: Union[str, Sequence[str]],
4959
centered=False,
50-
):
60+
) -> pt.TensorVariable:
5161
"""R2D2M2CP beta prior.
5262
5363
Parameters
@@ -65,7 +75,7 @@ def _R2D2M2CP_beta(
6575
psi: tensor
6676
probability of a coefficients to be positive
6777
"""
68-
explained_variance = phi * pt.expand_dims(r2 * output_sigma**2, -1)
78+
explained_variance = phi * pt.expand_dims(r2 * output_sigma**2, (-1,))
6979
mu_param, std_param = _psivar2musigma(psi, explained_variance, psi_mask=psi_mask)
7080
if not centered:
7181
with pm.Model(name):
@@ -107,7 +117,10 @@ def _R2D2M2CP_beta(
107117
return beta
108118

109119

110-
def _broadcast_as_dims(*values, dims):
120+
def _broadcast_as_dims(
121+
*values: np.ndarray,
122+
dims: Sequence[str],
123+
) -> Union[Tuple[np.ndarray, ...], np.ndarray]:
111124
model = pm.modelcontext(None)
112125
shape = [len(model.coords[d]) for d in dims]
113126
ret = tuple(np.broadcast_to(v, shape) for v in values)
@@ -117,7 +130,12 @@ def _broadcast_as_dims(*values, dims):
117130
return ret
118131

119132

120-
def _psi_masked(positive_probs, positive_probs_std, *, dims):
133+
def _psi_masked(
134+
positive_probs: pt.TensorLike,
135+
positive_probs_std: pt.TensorLike,
136+
*,
137+
dims: Sequence[str],
138+
) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
121139
if not (
122140
isinstance(positive_probs, pt.Constant) and isinstance(positive_probs_std, pt.Constant)
123141
):
@@ -152,7 +170,12 @@ def _psi_masked(positive_probs, positive_probs_std, *, dims):
152170
return mask, psi
153171

154172

155-
def _psi(positive_probs, positive_probs_std, *, dims):
173+
def _psi(
174+
positive_probs: pt.TensorLike,
175+
positive_probs_std: Union[pt.TensorLike, None],
176+
*,
177+
dims: Sequence[str],
178+
) -> Tuple[Union[pt.TensorLike, None], pt.TensorVariable]:
156179
if positive_probs_std is not None:
157180
mask, psi = _psi_masked(
158181
positive_probs=pt.as_tensor(positive_probs),
@@ -171,12 +194,12 @@ def _psi(positive_probs, positive_probs_std, *, dims):
171194

172195

173196
def _phi(
174-
variables_importance,
175-
variance_explained,
176-
importance_concentration,
197+
variables_importance: Union[pt.TensorLike, None],
198+
variance_explained: Union[pt.TensorLike, None],
199+
importance_concentration: Union[pt.TensorLike, None],
177200
*,
178-
dims,
179-
):
201+
dims: Sequence[str],
202+
) -> pt.TensorVariable:
180203
*broadcast_dims, dim = dims
181204
model = pm.modelcontext(None)
182205
if variables_importance is not None:
@@ -200,47 +223,50 @@ def _phi(
200223
return phi
201224

202225

226+
R2D2M2CPOut = namedtuple("R2D2M2CPOut", ["eps", "beta"])
227+
228+
203229
def R2D2M2CP(
204-
name,
205-
output_sigma,
206-
input_sigma,
230+
name: str,
231+
output_sigma: pt.TensorLike,
232+
input_sigma: pt.TensorLike,
207233
*,
208-
dims,
209-
r2,
210-
variables_importance=None,
211-
variance_explained=None,
212-
importance_concentration=None,
213-
r2_std=None,
214-
positive_probs=0.5,
215-
positive_probs_std=None,
216-
centered=False,
217-
):
234+
dims: Sequence[str],
235+
r2: pt.TensorLike,
236+
variables_importance: Union[pt.TensorLike, None] = None,
237+
variance_explained: Union[pt.TensorLike, None] = None,
238+
importance_concentration: Union[pt.TensorLike, None] = None,
239+
r2_std: Union[pt.TensorLike, None] = None,
240+
positive_probs: Union[pt.TensorLike, None] = 0.5,
241+
positive_probs_std: Union[pt.TensorLike, None] = None,
242+
centered: bool = False,
243+
) -> R2D2M2CPOut:
218244
"""R2D2M2CP Prior.
219245
220246
Parameters
221247
----------
222248
name : str
223249
Name for the distribution
224-
output_sigma : tensor
250+
output_sigma : Tensor
225251
Output standard deviation
226-
input_sigma : tensor
252+
input_sigma : Tensor
227253
Input standard deviation
228254
dims : Union[str, Sequence[str]]
229255
Dims for the distribution
230-
r2 : tensor
256+
r2 : Tensor
231257
:math:`R^2` estimate
232-
variables_importance : tensor, optional
258+
variables_importance : Tensor, optional
233259
Optional estimate for variables importance, positive, by default None
234-
variance_explained : tensor, optional
260+
variance_explained : Tensor, optional
235261
Alternative estimate for variables importance which is point estimate of
236262
variance explained, should sum up to one, by default None
237-
importance_concentration : tensor, optional
263+
importance_concentration : Tensor, optional
238264
Confidence around variance explained or variable importance estimate
239-
r2_std : tensor, optional
265+
r2_std : Tensor, optional
240266
Optional uncertainty over :math:`R^2`, by default None
241-
positive_probs : tensor, optional
267+
positive_probs : Tensor, optional
242268
Optional probability of variables contribution to be positive, by default 0.5
243-
positive_probs_std : tensor, optional
269+
positive_probs_std : Tensor, optional
244270
Optional uncertainty over effect direction probability, by default None
245271
centered : bool, optional
246272
Centered or Non-Centered parametrization of the distribution, by default Non-Centered. Advised to check both
@@ -419,4 +445,4 @@ def R2D2M2CP(
419445
psi_mask=mask,
420446
)
421447
resid_sigma = (1 - r2) ** 0.5 * output_sigma
422-
return resid_sigma, beta
448+
return R2D2M2CPOut(resid_sigma, beta)

pymc_experimental/tests/distributions/test_multivariate.py

+61-9
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
import numpy as np
22
import pymc as pm
3+
import pytensor
34
import pytest
45

56
import pymc_experimental as pmx
67

78

89
class TestR2D2M2CP:
10+
@pytest.fixture(autouse=True)
11+
def fast_compile(self):
12+
with pytensor.config.change_flags(mode="FAST_COMPILE", exception_verbosity="high"):
13+
yield
14+
915
@pytest.fixture(autouse=True)
1016
def model(self):
1117
# every method is within a model
@@ -95,17 +101,13 @@ def phi_args(self, request, phi_args_base):
95101
phi_args_base["importance_concentration"] = 10
96102
return phi_args_base
97103

98-
def test_init(
104+
def test_init_r2(
99105
self,
100106
dims,
101-
centered,
102107
input_std,
103108
output_std,
104109
r2,
105110
r2_std,
106-
positive_probs,
107-
positive_probs_std,
108-
phi_args,
109111
model: pm.Model,
110112
):
111113
eps, beta = pmx.distributions.R2D2M2CP(
@@ -115,10 +117,6 @@ def test_init(
115117
dims=dims,
116118
r2=r2,
117119
r2_std=r2_std,
118-
centered=centered,
119-
positive_probs_std=positive_probs_std,
120-
positive_probs=positive_probs,
121-
**phi_args
122120
)
123121
assert not np.isnan(beta.eval()).any()
124122
assert eps.eval().shape == output_std.shape
@@ -127,9 +125,63 @@ def test_init(
127125
assert "beta" in model.named_vars
128126
assert ("beta::r2" in model.named_vars) == (r2_std is not None), set(model.named_vars)
129127
# phi is only created if variable importance is not None and there is more than one var
128+
assert np.isfinite(model.compile_logp()(model.initial_point()))
129+
130+
def test_init_importance(
131+
self,
132+
dims,
133+
centered,
134+
input_std,
135+
output_std,
136+
phi_args,
137+
model: pm.Model,
138+
):
139+
eps, beta = pmx.distributions.R2D2M2CP(
140+
"beta",
141+
output_std,
142+
input_std,
143+
dims=dims,
144+
r2=1,
145+
centered=centered,
146+
**phi_args,
147+
)
148+
assert not np.isnan(beta.eval()).any()
149+
assert eps.eval().shape == output_std.shape
150+
assert beta.eval().shape == input_std.shape
151+
# r2 rv is only created if r2 std is not None
152+
assert "beta" in model.named_vars
153+
# phi is only created if variable importance is not None and there is more than one var
130154
assert ("beta::phi" in model.named_vars) == (
131155
"variables_importance" in phi_args or "importance_concentration" in phi_args
132156
), set(model.named_vars)
157+
assert np.isfinite(model.compile_logp()(model.initial_point()))
158+
159+
def test_init_positive_probs(
160+
self,
161+
dims,
162+
centered,
163+
input_std,
164+
output_std,
165+
positive_probs,
166+
positive_probs_std,
167+
model: pm.Model,
168+
):
169+
eps, beta = pmx.distributions.R2D2M2CP(
170+
"beta",
171+
output_std,
172+
input_std,
173+
dims=dims,
174+
r2=1.0,
175+
centered=centered,
176+
positive_probs_std=positive_probs_std,
177+
positive_probs=positive_probs,
178+
)
179+
assert not np.isnan(beta.eval()).any()
180+
assert eps.eval().shape == output_std.shape
181+
assert beta.eval().shape == input_std.shape
182+
# r2 rv is only created if r2 std is not None
183+
assert "beta" in model.named_vars
184+
# phi is only created if variable importance is not None and there is more than one var
133185
assert ("beta::psi" in model.named_vars) == (
134186
positive_probs_std is not None and positive_probs_std.any()
135187
), set(model.named_vars)

0 commit comments

Comments
 (0)