Skip to content

Commit c4c69eb

Browse files
committed
run with doctest
1 parent 4a683b3 commit c4c69eb

File tree

4 files changed

+123
-18
lines changed

4 files changed

+123
-18
lines changed

pymc_experimental/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
from pymc_experimental.distributions.continuous import GenExtreme
2121
from pymc_experimental.distributions.discrete import GeneralizedPoisson
2222
from pymc_experimental.distributions.timeseries import DiscreteMarkovChain
23+
from pymc_experimental.distributions.histogram_utils import histogram_approximation
2324
from pymc_experimental.distributions.multivatiate import R2D2M2CP
2425

2526
__all__ = [
2627
"DiscreteMarkovChain",
2728
"GeneralizedPoisson",
2829
"GenExtreme",
2930
"R2D2M2CP",
31+
"histogram_approximation",
3032
]

pymc_experimental/distributions/histogram_utils.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,27 +19,21 @@
1919
import pymc as pm
2020
from numpy.typing import ArrayLike
2121

22-
try:
23-
import dask.array
24-
import dask.dataframe
25-
except ImportError:
26-
dask = None
27-
28-
try:
29-
import xhistogram.core
30-
except ImportError:
31-
xhistogram = None
32-
33-
3422
__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
3523

3624

3725
def quantile_histogram(
3826
data: ArrayLike, n_quantiles=1000, zero_inflation=False
3927
) -> Dict[str, ArrayLike]:
40-
if xhistogram is None:
41-
raise RuntimeError("quantile_histogram requires xhistogram package")
42-
28+
try:
29+
import xhistogram.core
30+
except ImportError as e:
31+
raise RuntimeError("quantile_histogram requires xhistogram package") from e
32+
try:
33+
import dask.array
34+
import dask.dataframe
35+
except ImportError:
36+
dask = None
4337
if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
4438
data = data.to_dask_array(lengths=True)
4539
if zero_inflation:
@@ -74,8 +68,15 @@ def quantile_histogram(
7468

7569

7670
def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
77-
if xhistogram is None:
78-
raise RuntimeError("discrete_histogram requires xhistogram package")
71+
try:
72+
import xhistogram.core
73+
except ImportError as e:
74+
raise RuntimeError("discrete_histogram requires xhistogram package") from e
75+
try:
76+
import dask.array
77+
import dask.dataframe
78+
except ImportError:
79+
dask = None
7980

8081
if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
8182
data = data.to_dask_array(lengths=True)
@@ -147,6 +148,11 @@ def histogram_approximation(name, dist, *, observed, **h_kwargs):
147148
... observed=measurements, n_quantiles=50, zero_inflation=True
148149
... )
149150
"""
151+
try:
152+
import dask.array
153+
import dask.dataframe
154+
except ImportError:
155+
dask = None
150156
if dask and isinstance(observed, (dask.dataframe.Series, dask.dataframe.DataFrame)):
151157
observed = observed.to_dask_array(lengths=True)
152158
if np.issubdtype(observed.dtype, np.integer):

pymc_experimental/distributions/multivatiate.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,97 @@ def R2D2M2CP(
122122
123123
- ``(R2D2M2)``CP is taken from https://arxiv.org/abs/2208.07132
124124
- R2D2M2``(CP)``, (Correlation Probability) is proposed and implemented by Max Kochurov (@ferrine)
125+
126+
Examples
127+
--------
128+
Here are arguments explained in a synthetic example
129+
130+
>>> import pymc_experimental as pmx
131+
>>> import pymc as pm
132+
>>> import numpy as np
133+
>>> X = np.random.randn(10, 3)
134+
>>> b = np.random.randn(3)
135+
>>> y = X @ b + np.random.randn(10) * 0.04 + 5
136+
>>> with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
137+
... eps, beta = pmx.distributions.R2D2M2CP(
138+
... "beta",
139+
... y.std(),
140+
... X.std(0),
141+
... dims="variables",
142+
... # NOTE: global shrinkage
143+
... r2=0.8,
144+
... # NOTE: if you are unsure about r2
145+
... r2_std=0.2,
146+
... # NOTE: if you know where a variable should go
147+
... # if you do not know, leave as 0.5
148+
... positive_probs=[0.8, 0.5, 0.1],
149+
... # NOTE: if you have different opinions about
150+
... # where a variable should go.
151+
... # NOTE: if you put 0.5 previously,
152+
... # just put 0.1 there, but other
153+
... # sigmas should work fine too
154+
... positive_probs_std=[0.3, 0.1, 0.2],
155+
... # NOTE: variable importances are relative to each other,
156+
... # but larget numbers put "more" weight in the relation
157+
... # use
158+
... # * 1-10 for small confidence
159+
... # * 10-30 for moderate confidence
160+
... # * 30+ for high confidence
161+
... # EXAMPLE:
162+
... # "a" - is likely to be useful
163+
... # "b" - no idea if it is useful
164+
... # "c" - a must have in the relation
165+
... variables_importance=[10, 1, 34],
166+
... # NOTE: try both
167+
... centered=True
168+
... )
169+
... intercept = y.mean()
170+
... obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
171+
172+
There can be special cases by choosing specific set of arguments
173+
174+
Here the prior distribution of beta is ``Normal(0, y.std() * r2 ** .5)``
175+
176+
>>> with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
177+
... eps, beta = pmx.distributions.R2D2M2CP(
178+
... "beta",
179+
... y.std(),
180+
... X.std(0),
181+
... dims="variables",
182+
... # NOTE: global shrinkage
183+
... r2=0.8,
184+
... # NOTE: if you are unsure about r2
185+
... r2_std=0.2,
186+
... # NOTE: if you know where a variable should go
187+
... # if you do not know, leave as 0.5
188+
... centered=False
189+
... )
190+
... intercept = y.mean()
191+
... obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
192+
193+
194+
It is fine to leave some of the ``_std`` arguments unspecified.
195+
You can also specify only ``positive_probs``, and all
196+
the variables are assumed to explain same amount of variance (same importance)
197+
198+
>>> with pm.Model(coords=dict(variables=["a", "b", "c"])) as model:
199+
... eps, beta = pmx.distributions.R2D2M2CP(
200+
... "beta",
201+
... y.std(),
202+
... X.std(0),
203+
... dims="variables",
204+
... # NOTE: global shrinkage
205+
... r2=0.8,
206+
... # NOTE: if you are unsure about r2
207+
... r2_std=0.2,
208+
... # NOTE: if you know where a variable should go
209+
... # if you do not know, leave as 0.5
210+
... positive_probs=[0.8, 0.5, 0.1],
211+
... # NOTE: try both
212+
... centered=True
213+
... )
214+
... intercept = y.mean()
215+
... obs = pm.Normal("obs", intercept + X @ beta, eps, observed=y)
125216
"""
126217
if not isinstance(dims, (list, tuple)):
127218
dims = (dims,)

pyproject.toml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
[tool.pytest.ini_options]
22
minversion = "6.0"
3-
xfail_strict=true
3+
xfail_strict = true
4+
addopts = [
5+
"-v",
6+
"--doctest-modules",
7+
"--ignore=pymc_experimental/model_builder.py"
8+
]
9+
410

511
[tool.black]
612
line-length = 100

0 commit comments

Comments
 (0)