Skip to content

Commit 4a683b3

Browse files
committed
add init test
1 parent 1b89daa commit 4a683b3

File tree

3 files changed

+117
-4
lines changed

3 files changed

+117
-4
lines changed

pymc_experimental/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
from pymc_experimental.distributions.continuous import GenExtreme
from pymc_experimental.distributions.discrete import GeneralizedPoisson
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
# NOTE(review): "multivatiate" is a typo for "multivariate", but it matches the
# actual module file name created in this commit — renaming requires changing
# both the file and this import together.
from pymc_experimental.distributions.multivatiate import R2D2M2CP

# Public API of the distributions subpackage.
__all__ = [
    "DiscreteMarkovChain",
    "GeneralizedPoisson",
    "GenExtreme",
    "R2D2M2CP",
]

pymc_experimental/distributions/multivatiate.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,28 @@
1+
# Copyright 2022 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
116
from typing import Sequence, Union
217

318
import pymc as pm
419
import pytensor.tensor as pt
520

621

def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable):
    """Convert a positivity probability and an explained variance into the
    ``(mu, sigma)`` parameters of a normal prior.

    ``psi`` is the probability that the coefficient is positive. It is pushed
    through the inverse error function so a normal variable with the returned
    parameters is positive with probability ``psi``, while its second moment
    ``mu**2 + sigma**2`` equals ``explained_var``.
    """
    # Probit-style transform: Phi(sqrt(2) * probit) == psi.
    probit = pt.erfinv(2 * psi - 1)
    # Normalizer chosen so that mu**2 + sigma**2 == explained_var,
    # since mu == sigma * probit * sqrt(2) below.
    shrink = (1 / (2 * probit**2 + 1)) ** 0.5
    sigma = explained_var**0.5 * shrink
    mu = sigma * probit * 2**0.5
    return mu, sigma

@@ -38,7 +53,7 @@ def _R2D2M2CP_beta(
3853
probability of a coefficients to be positive
3954
"""
4055
tau2 = r2 / (1 - r2)
41-
explained_variance = phi * tau2 * pt.expand_dims(output_sigma**2, -1)
56+
explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1)
4257
mu_param, std_param = _psivar2musigma(psi, explained_variance)
4358
if not centered:
4459
with pm.Model(name):
@@ -103,6 +118,8 @@ def R2D2M2CP(
103118
104119
Notes
105120
-----
121+
The R2D2M2CP prior is a modification of R2D2M2 prior.
122+
106123
- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
107124
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
108125
"""
@@ -116,7 +133,7 @@ def R2D2M2CP(
116133
if variance_explained is not None:
117134
raise TypeError("Can't use variable importance with variance explained")
118135
phi = pm.Dirichlet("phi", pt.as_tensor(variables_importance), dims=hierarchy + [dim])
119-
elif variance_explained:
136+
elif variance_explained is not None:
120137
phi = pt.as_tensor(variance_explained)
121138
else:
122139
phi = 1 / len(model.coords[dim])
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
import numpy as np
import pymc as pm
import pytest

import pymc_experimental as pmx


class TestR2D2M2CP:
    """Smoke tests for the R2D2M2CP prior: build it under every combination of
    the parametrized fixtures below and check the returned tensors' shapes."""

    @pytest.fixture(autouse=True)
    def model(self):
        # every method is within a model
        with pm.Model() as model:
            yield model

    @pytest.fixture(params=[True, False])
    def centered(self, request):
        # Whether the prior uses the centered parametrization.
        return request.param

    @pytest.fixture(params=[["a"], ["a", "b"], ["one"]])
    def dims(self, model: pm.Model, request):
        # Register model coords for the requested dims. "one" exercises the
        # degenerate length-1 dimension; other dims get lengths 4, 9, ...
        for i, c in enumerate(request.param):
            if c == "one":
                model.add_coord(c, range(1))
            else:
                model.add_coord(c, range((i + 2) ** 2))
        return request.param

    @pytest.fixture
    def input_std(self, dims, model):
        # Unit standard deviations with one entry per coordinate along `dims`.
        input_shape = [int(model.dim_lengths[d].eval()) for d in dims]
        return np.ones(input_shape)

    @pytest.fixture
    def output_std(self, dims, model):
        # Output std spans only the hierarchy dims (all but the last dim).
        *hierarchy, _ = dims
        output_shape = [int(model.dim_lengths[d].eval()) for d in hierarchy]
        return np.ones(output_shape)

    @pytest.fixture
    def r2(self):
        # Fixed prior mean for the explained-variance fraction R^2.
        return 0.8

    @pytest.fixture(params=[None, 0.1])
    def r2_std(self, request):
        # None -> point R^2; float -> R^2 gets its own prior spread.
        return request.param

    @pytest.fixture(params=[True, False])
    def positive_probs(self, input_std, request):
        # Exercise both array-valued and scalar positivity probabilities.
        if request.param:
            return np.full_like(input_std, 0.5)
        else:
            return 0.5

    @pytest.fixture(params=[True, False])
    def positive_probs_std(self, positive_probs, request):
        # Optional uncertainty on the positivity probabilities.
        if request.param:
            return np.full_like(positive_probs, 0.1)
        else:
            return None

    @pytest.fixture(params=["importance", "explained"])
    def phi_args(self, request, input_std):
        # The two mutually exclusive ways of specifying phi: relative variable
        # importances, or an explicit variance-explained simplex (rows sum to 1).
        if request.param == "importance":
            return {"variables_importance": np.full_like(input_std, 2)}
        else:
            val = np.full_like(input_std, 2)
            return {"variance_explained": val / val.sum(-1, keepdims=True)}

    def test_init(
        self,
        dims,
        centered,
        input_std,
        output_std,
        r2,
        r2_std,
        positive_probs,
        positive_probs_std,
        phi_args,
    ):
        # Construct the prior inside the autouse model and verify that the
        # returned residual std and coefficients have the expected shapes.
        eps, beta = pmx.distributions.R2D2M2CP(
            "beta",
            output_std,
            input_std,
            dims=dims,
            r2=r2,
            r2_std=r2_std,
            centered=centered,
            positive_probs_std=positive_probs_std,
            positive_probs=positive_probs,
            **phi_args
        )
        assert eps.eval().shape == output_std.shape
        assert beta.eval().shape == input_std.shape

0 commit comments

Comments
 (0)