Skip to content

Commit 1b89daa

Browse files
committed
change user facing parametrization
1 parent eb55cd5 commit 1b89daa

File tree

1 file changed

+20
-20
lines changed

1 file changed

+20
-20
lines changed

pymc_experimental/distributions/multivatiate.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def _psivar2musigma(psi: pt.TensorVariable, var: pt.TensorVariable):
1414

1515
def _R2D2M2CP_beta(
1616
name: str,
17-
variance: pt.TensorVariable,
18-
param_sigma: pt.TensorVariable,
17+
output_sigma: pt.TensorVariable,
18+
input_sigma: pt.TensorVariable,
1919
r2: pt.TensorVariable,
2020
phi: pt.TensorVariable,
2121
psi: pt.TensorVariable,
@@ -26,9 +26,9 @@ def _R2D2M2CP_beta(
2626
"""R2D2M2CP_beta prior.
2727
name: str
2828
Name for the distribution
29-
variance: tensor
29+
output_sigma: tensor
3030
standard deviation of the outcome
31-
param_sigma: tensor
31+
input_sigma: tensor
3232
standard deviation of the explanatory variables
3333
r2: tensor
3434
expected R2 for the linear regression
@@ -38,21 +38,21 @@ def _R2D2M2CP_beta(
3838
probability of a coefficients to be positive
3939
"""
4040
tau2 = r2 / (1 - r2)
41-
explained_variance = phi * tau2 * pt.expand_dims(variance, -1)
41+
explained_variance = phi * tau2 * pt.expand_dims(output_sigma**2, -1)
4242
mu_param, std_param = _psivar2musigma(psi, explained_variance)
4343
if not centered:
4444
with pm.Model(name):
4545
raw = pm.Normal("raw", dims=dims)
46-
beta = pm.Deterministic(name, (raw * std_param + mu_param) / param_sigma, dims=dims)
46+
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
4747
else:
48-
beta = pm.Normal(name, mu_param / param_sigma, std_param / param_sigma, dims=dims)
48+
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
4949
return beta
5050

5151

5252
def R2D2M2CP(
5353
name,
54-
variance,
55-
param_sigma,
54+
output_sigma,
55+
input_sigma,
5656
*,
5757
dims,
5858
r2,
@@ -69,16 +69,16 @@ def R2D2M2CP(
6969
----------
7070
name : str
7171
Name for the distribution
72-
variance : tensor
73-
Output variance
74-
param_sigma : tensor
72+
output_sigma : tensor
73+
Output standard deviation
74+
input_sigma : tensor
7575
Input standard deviation
7676
dims : Union[str, Sequence[str]]
7777
Dims for the distribution
7878
r2 : tensor
7979
:math:`R^2` estimate
8080
variables_importance : tensor, optional
81-
Optional estimate for variables importance, positive, , by default None
81+
Optional estimate for variables importance, positive, by default None
8282
variance_explained : tensor, optional
8383
Alternative estimate for variables importance which is point estimate of
8484
variance explained, should sum up to one, by default None
@@ -93,8 +93,8 @@ def R2D2M2CP(
9393
9494
Returns
9595
-------
96-
residual_variance, coefficients
97-
Output variance is split in residual variance and explained variance.
96+
residual_sigma, coefficients
97+
Output variance (sigma squared) is split in residual variance and explained variance.
9898
9999
Raises
100100
------
@@ -109,8 +109,8 @@ def R2D2M2CP(
109109
if not isinstance(dims, (list, tuple)):
110110
dims = (dims,)
111111
*hierarchy, dim = dims
112-
param_sigma = pt.as_tensor(param_sigma)
113-
variance = pt.as_tensor(variance)
112+
input_sigma = pt.as_tensor(input_sigma)
113+
output_sigma = pt.as_tensor(output_sigma)
114114
with pm.Model(name) as model:
115115
if variables_importance is not None and len(model.coords[dim]) > 1:
116116
if variance_explained is not None:
@@ -132,7 +132,7 @@ def R2D2M2CP(
132132
else:
133133
psi = pt.as_tensor(positive_probs)
134134
beta = _R2D2M2CP_beta(
135-
name, variance, param_sigma, r2, phi, psi, dims=hierarchy + [dim], centered=centered
135+
name, output_sigma, input_sigma, r2, phi, psi, dims=hierarchy + [dim], centered=centered
136136
)
137-
variance_resid = (1 - r2) * variance
138-
return variance_resid, beta
137+
resid_sigma = (1 - r2) ** 0.5 * output_sigma
138+
return resid_sigma, beta

0 commit comments

Comments
 (0)