Skip to content

Commit b3f6581

Browse files
authored
Add r2m2d2cp prior (#137)
* add r2m2d2cp * change user facing parametrization * add init test * run with doctest * fix year in the copyright notice, fix filename * add R2D2M2CP to docs * add more functional tests * pre commit run * restructure docs * fix docs additions * add bibtex entry * add comments to docstrings * restructure modules * rename hierarchy dim to broadcast dim * add a comment about named dims * improve docstrings
1 parent f9d1c55 commit b3f6581

File tree

7 files changed

+500
-19
lines changed

7 files changed

+500
-19
lines changed

docs/api_reference.rst

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,9 @@ Distributions
3030

3131
GenExtreme
3232
GeneralizedPoisson
33-
histogram_utils.histogram_approximation
3433
DiscreteMarkovChain
34+
R2D2M2CP
35+
histogram_approximation
3536

3637

3738
Utils

pymc_experimental/distributions/__init__.py

+4
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919

2020
from pymc_experimental.distributions.continuous import GenExtreme
2121
from pymc_experimental.distributions.discrete import GeneralizedPoisson
22+
from pymc_experimental.distributions.histogram_utils import histogram_approximation
23+
from pymc_experimental.distributions.multivariate import R2D2M2CP
2224
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
2325

2426
__all__ = [
2527
"DiscreteMarkovChain",
2628
"GeneralizedPoisson",
2729
"GenExtreme",
30+
"R2D2M2CP",
31+
"histogram_approximation",
2832
]

pymc_experimental/distributions/histogram_utils.py

+23-17
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,21 @@
1919
import pymc as pm
2020
from numpy.typing import ArrayLike
2121

22-
try:
23-
import dask.array
24-
import dask.dataframe
25-
except ImportError:
26-
dask = None
27-
28-
try:
29-
import xhistogram.core
30-
except ImportError:
31-
xhistogram = None
32-
33-
3422
__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
3523

3624

3725
def quantile_histogram(
3826
data: ArrayLike, n_quantiles=1000, zero_inflation=False
3927
) -> Dict[str, ArrayLike]:
40-
if xhistogram is None:
41-
raise RuntimeError("quantile_histogram requires xhistogram package")
42-
28+
try:
29+
import xhistogram.core
30+
except ImportError as e:
31+
raise RuntimeError("quantile_histogram requires xhistogram package") from e
32+
try:
33+
import dask.array
34+
import dask.dataframe
35+
except ImportError:
36+
dask = None
4337
if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
4438
data = data.to_dask_array(lengths=True)
4539
if zero_inflation:
@@ -74,8 +68,15 @@ def quantile_histogram(
7468

7569

7670
def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
77-
if xhistogram is None:
78-
raise RuntimeError("discrete_histogram requires xhistogram package")
71+
try:
72+
import xhistogram.core
73+
except ImportError as e:
74+
raise RuntimeError("discrete_histogram requires xhistogram package") from e
75+
try:
76+
import dask.array
77+
import dask.dataframe
78+
except ImportError:
79+
dask = None
7980

8081
if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
8182
data = data.to_dask_array(lengths=True)
@@ -147,6 +148,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
147148
... observed=measurements, n_quantiles=50, zero_inflation=True
148149
... )
149150
"""
151+
try:
152+
import dask.array
153+
import dask.dataframe
154+
except ImportError:
155+
dask = None
150156
if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)):
151157
observed = observed.to_dask_array(lengths=True)
152158
if np.issubdtype(observed.dtype, np.integer):
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from pymc_experimental.distributions.multivariate.r2d2m2cp import R2D2M2CP
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
# Copyright 2023 The PyMC Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from typing import Sequence, Union
17+
18+
import pymc as pm
19+
import pytensor.tensor as pt
20+
21+
__all__ = ["R2D2M2CP"]
22+
23+
24+
def _psivar2musigma(psi: pt.TensorVariable, explained_var: pt.TensorVariable):
25+
pi = pt.erfinv(2 * psi - 1)
26+
f = (1 / (2 * pi**2 + 1)) ** 0.5
27+
sigma = explained_var**0.5 * f
28+
mu = sigma * pi * 2**0.5
29+
return mu, sigma
30+
31+
32+
def _R2D2M2CP_beta(
33+
name: str,
34+
output_sigma: pt.TensorVariable,
35+
input_sigma: pt.TensorVariable,
36+
r2: pt.TensorVariable,
37+
phi: pt.TensorVariable,
38+
psi: pt.TensorVariable,
39+
*,
40+
dims: Union[str, Sequence[str]],
41+
centered=False,
42+
):
43+
"""R2D2M2CP beta prior.
44+
45+
Parameters
46+
----------
47+
name: str
48+
Name for the distribution
49+
output_sigma: tensor
50+
standard deviation of the outcome
51+
input_sigma: tensor
52+
standard deviation of the explanatory variables
53+
r2: tensor
54+
expected R2 for the linear regression
55+
phi: tensor
56+
variance weights that sums up to 1
57+
psi: tensor
58+
probability of a coefficients to be positive
59+
"""
60+
tau2 = r2 / (1 - r2)
61+
explained_variance = phi * pt.expand_dims(tau2 * output_sigma**2, -1)
62+
mu_param, std_param = _psivar2musigma(psi, explained_variance)
63+
if not centered:
64+
with pm.Model(name):
65+
raw = pm.Normal("raw", dims=dims)
66+
beta = pm.Deterministic(name, (raw * std_param + mu_param) / input_sigma, dims=dims)
67+
else:
68+
beta = pm.Normal(name, mu_param / input_sigma, std_param / input_sigma, dims=dims)
69+
return beta
70+
71+
72+
def R2D2M2CP(
73+
name,
74+
output_sigma,
75+
input_sigma,
76+
*,
77+
dims,
78+
r2,
79+
variables_importance=None,
80+
variance_explained=None,
81+
r2_std=None,
82+
positive_probs=0.5,
83+
positive_probs_std=None,
84+
centered=False,
85+
):
86+
"""R2D2M2CP Prior.
87+
88+
Parameters
89+
----------
90+
name : str
91+
Name for the distribution
92+
output_sigma : tensor
93+
Output standard deviation
94+
input_sigma : tensor
95+
Input standard deviation
96+
dims : Union[str, Sequence[str]]
97+
Dims for the distribution
98+
r2 : tensor
99+
:math:`R^2` estimate
100+
variables_importance : tensor, optional
101+
Optional estimate for variables importance, positive, by default None
102+
variance_explained : tensor, optional
103+
Alternative estimate for variables importance which is point estimate of
104+
variance explained, should sum up to one, by default None
105+
r2_std : tensor, optional
106+
Optional uncertainty over :math:`R^2`, by default None
107+
positive_probs : tensor, optional
108+
Optional probability of variables contribution to be positive, by default 0.5
109+
positive_probs_std : tensor, optional
110+
Optional uncertainty over effect direction probability, by default None
111+
centered : bool, optional
112+
Centered or Non-Centered parametrization of the distribution, by default Non-Centered. Advised to check both
113+
114+
Returns
115+
-------
116+
residual_sigma, coefficients
117+
Output variance (sigma squared) is split in residual variance and explained variance.
118+
119+
Raises
120+
------
121+
TypeError
122+
If parametrization is wrong.
123+
124+
Notes
125+
-----
126+
The R2D2M2CP prior is a modification of R2D2M2 prior.
127+
128+
- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
129+
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
130+
131+
Examples
132+
--------
133+
Here are arguments explained in a synthetic example
134+
135+
.. warning::
136+
137+
To use the prior in a linear regression
138+
139+
- make sure :math:`X` is centered around zero
140+
- intercept represents prior predictive mean when :math:`X` is centered
141+
- setting named dims is required
142+
143+
.. code-block:: python
144+
145+
import pymc_experimental as pmx
146+
import pymc as pm
147+
import numpy as np
148+
X = np.random.randn(10, 3)
149+
b = np.random.randn(3)
150+
y = X @ b + np.random.randn(10) * 0.04 + 5
151+
with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
152+
eps, beta = pmx.distributions.R2D2M2CP(
153+
"beta",
154+
y.std(),
155+
X.std(0),
156+
dims="variables",
157+
# NOTE: global shrinkage
158+
r2=0.8,
159+
# NOTE: if you are unsure about r2
160+
r2_std=0.2,
161+
# NOTE: if you know where a variable should go
162+
# if you do not know, leave as 0.5
163+
positive_probs=[0.8, 0.5, 0.1],
164+
# NOTE: if you have different opinions about
165+
# where a variable should go.
166+
# NOTE: if you put 0.5 previously,
167+
# just put 0.1 there, but other
168+
# sigmas should work fine too
169+
positive_probs_std=[0.3, 0.1, 0.2],
170+
# NOTE: variable importances are relative to each other,
171+
# but larget numbers put "more" weight in the relation
172+
# use
173+
# * 1-10 for small confidence
174+
# * 10-30 for moderate confidence
175+
# * 30+ for high confidence
176+
# EXAMPLE:
177+
# "a" - is likely to be useful
178+
# "b" - no idea if it is useful
179+
# "c" - a must have in the relation
180+
variables_importance=[10, 1, 34],
181+
# NOTE: try both
182+
centered=True
183+
)
184+
# intercept prior centering should be around prior predictive mean
185+
intercept = y.mean()
186+
# regressors should be centered around zero
187+
Xc = X - X.mean(0)
188+
obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
189+
190+
There can be special cases by choosing specific set of arguments
191+
192+
Here the prior distribution of beta is ``Normal(0, y.std() * r2 ** .5)``
193+
194+
.. code-block:: python
195+
196+
with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
197+
eps, beta = pmx.distributions.R2D2M2CP(
198+
"beta",
199+
y.std(),
200+
X.std(0),
201+
dims="variables",
202+
# NOTE: global shrinkage
203+
r2=0.8,
204+
# NOTE: if you are unsure about r2
205+
r2_std=0.2,
206+
# NOTE: if you know where a variable should go
207+
# if you do not know, leave as 0.5
208+
centered=False
209+
)
210+
# intercept prior centering should be around prior predictive mean
211+
intercept = y.mean()
212+
# regressors should be centered around zero
213+
Xc = X - X.mean(0)
214+
obs = pm.Normal("obs", intercept + Xc @ beta, eps, observed=y)
215+
216+
217+
It is fine to leave some of the ``_std`` arguments unspecified.
218+
You can also specify only ``positive_probs``, and all
219+
the variables are assumed to explain same amount of variance (same importance)
220+
221+
.. code-block:: python
222+
223+
with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
224+
eps, beta = pmx.distributions.R2D2M2CP(
225+
"beta",
226+
y.std(),
227+
X.std(0),
228+
dims="variables",
229+
# NOTE: global shrinkage
230+
r2=0.8,
231+
# NOTE: if you are unsure about r2
232+
r2_std=0.2,
233+
# NOTE: if you know where a variable should go
234+
# if you do not know, leave as 0.5
235+
positive_probs=[0.8, 0.5, 0.1],
236+
# NOTE: try both
237+
centered=True
238+
)
239+
intercept = y.mean()
240+
obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
241+
242+
Notes
243+
-----
244+
To reference R2D2M2CP implementation, you can use the following bibtex entry:
245+
246+
.. code-block::
247+
248+
@misc{pymc-experimental-r2d2m2cp,
249+
title = {pymc-devs/pymc-experimental: {P}ull {R}equest 137, {R2D2M2CP}},
250+
url = {https://github.com/pymc-devs/pymc-experimental/pull/137},
251+
author = {Max Kochurov},
252+
howpublished = {GitHub},
253+
year = {2023}
254+
}
255+
"""
256+
if not isinstance(dims, (list, tuple)):
257+
dims = (dims,)
258+
*broadcast_dims, dim = dims
259+
input_sigma = pt.as_tensor(input_sigma)
260+
output_sigma = pt.as_tensor(output_sigma)
261+
with pm.Model(name) as model:
262+
if variables_importance is not None:
263+
if variance_explained is not None:
264+
raise TypeError("Can't use variable importance with variance explained")
265+
if len(model.coords[dim]) <= 1:
266+
raise TypeError("Can't use variable importance with less than two variables")
267+
phi = pm.Dirichlet(
268+
"phi", pt.as_tensor(variables_importance), dims=broadcast_dims + [dim]
269+
)
270+
elif variance_explained is not None:
271+
if len(model.coords[dim]) <= 1:
272+
raise TypeError("Can't use variance explained with less than two variables")
273+
phi = pt.as_tensor(variance_explained)
274+
else:
275+
phi = 1 / len(model.coords[dim])
276+
if r2_std is not None:
277+
r2 = pm.Beta("r2", mu=r2, sigma=r2_std, dims=broadcast_dims)
278+
if positive_probs_std is not None:
279+
psi = pm.Beta(
280+
"psi",
281+
mu=pt.as_tensor(positive_probs),
282+
sigma=pt.as_tensor(positive_probs_std),
283+
dims=broadcast_dims + [dim],
284+
)
285+
else:
286+
psi = pt.as_tensor(positive_probs)
287+
beta = _R2D2M2CP_beta(
288+
name,
289+
output_sigma,
290+
input_sigma,
291+
r2,
292+
phi,
293+
psi,
294+
dims=broadcast_dims + [dim],
295+
centered=centered,
296+
)
297+
resid_sigma = (1 - r2) ** 0.5 * output_sigma
298+
return resid_sigma, beta

0 commit comments

Comments
 (0)