
Commit dea6bc9

ferrine and OriolAbril authored
Initialize a prior from a fitted posterior (#56)
* add argument parser
* extend argument parser
* prepare a valid fixture
* improve fixture
* improve fixture
* use simplex transform for the test case
* add parse args
* add flatten util
* fix typo
* refactor flattening
* add mean chol
* add test for mvn_prior
* test final api
* add additional argument
* add type hints
* fix tests
* add a docstring
* add to docs
* simplify implementation
* Update pymc_experimental/utils/prior.py (Co-authored-by: Oriol Abril-Pla <[email protected]>)
* Update pymc_experimental/utils/prior.py (Co-authored-by: Oriol Abril-Pla <[email protected]>)
* update the docstring
* update the docstring

Co-authored-by: Oriol Abril-Pla <[email protected]>
1 parent 47e2e2b commit dea6bc9
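
In short, the new utility approximates a fitted posterior with an MvNormal distribution and re-creates the selected variables as priors in a follow-up model. A minimal sketch of the intended workflow (the model and variable names here are illustrative; the full option set is shown in the prior_from_idata docstring below):

import pymc as pm
import pymc.distributions.transforms as transforms
import pymc_experimental as pmx

# Fit any model and keep the posterior trace.
with pm.Model() as model1:
    a = pm.Normal("a")
    sigma = pm.HalfNormal("sigma")
    y = pm.Normal("y", a, sigma, observed=[0.1, -0.3, 0.2])
    trace = pm.sample(progressbar=False)

# Reuse the posterior as a correlated prior in a new model.
with pm.Model() as model2:
    priors = pmx.utils.prior.prior_from_idata(
        trace,
        var_names=["a"],       # take "a" from the posterior as is
        sigma=transforms.log,  # constrained variable: approximate on the log scale
    )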

4 files changed: +326 −0


docs/api_reference.rst

Lines changed: 3 additions & 0 deletions
@@ -29,3 +29,6 @@ methods in the current release of PyMC experimental.
 .. automodule:: pymc_experimental.utils.spline
    :members: bspline_interpolation
 
+.. automodule:: pymc_experimental.utils.prior
+   :members: prior_from_idata
+
Lines changed: 141 additions & 0 deletions
@@ -0,0 +1,141 @@
import pymc_experimental as pmx
from pymc.distributions import transforms
import pytest
import arviz as az
import numpy as np
import pymc as pm


@pytest.mark.parametrize(
    "case",
    [
        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
        (("a", None), dict(name="a", transform=None, dims=None)),
        (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
        (
            ("a", dict(transform=transforms.log)),
            dict(name="a", transform=transforms.log, dims=None),
        ),
        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
        (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
        (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
    ],
)
def test_parsing_arguments(case):
    inp, out = case
    test = pmx.utils.prior._arg_to_param_cfg(*inp)
    assert test == out


@pytest.fixture
def coords():
    return dict(test=range(3), simplex=range(4))


@pytest.fixture
def user_param_cfg():
    return ("t",), dict(
        a="d",
        b=dict(transform=transforms.log, dims=("test",)),
        c=dict(transform=transforms.simplex, dims=("simplex",)),
    )


@pytest.fixture
def param_cfg(user_param_cfg):
    return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1])


@pytest.fixture
def transformed_data(param_cfg, coords):
    vars = dict()
    for k, cfg in param_cfg.items():
        if cfg["dims"] is not None:
            extra_dims = [len(coords[d]) for d in cfg["dims"]]
            if cfg["transform"] is not None:
                t = np.random.randn(*extra_dims)
                extra_dims = tuple(cfg["transform"].forward(t).shape.eval())
        else:
            extra_dims = []
        orig = np.random.randn(4, 100, *extra_dims)
        vars[k] = orig
    return vars


@pytest.fixture
def idata(transformed_data, param_cfg):
    vars = dict()
    for k, orig in transformed_data.items():
        cfg = param_cfg[k]
        if cfg["transform"] is not None:
            var = cfg["transform"].backward(orig).eval()
        else:
            var = orig
        assert not np.isnan(var).any()
        vars[k] = var
    return az.convert_to_inference_data(vars)


def test_idata_for_tests(idata, param_cfg):
    assert set(idata.posterior.keys()) == set(param_cfg)
    assert len(idata.posterior.coords["chain"]) == 4
    assert len(idata.posterior.coords["draw"]) == 100


def test_args_compose():
    cfg = pmx.utils.prior._parse_args(
        var_names=["a"],
        b=("test",),
        c=transforms.log,
        d="e",
        f=dict(dims="test"),
        g=dict(name="h", dims="test", transform=transforms.log),
    )
    assert cfg == dict(
        a=dict(name="a", dims=None, transform=None),
        b=dict(name="b", dims=("test",), transform=None),
        c=dict(name="c", dims=None, transform=transforms.log),
        d=dict(name="e", dims=None, transform=None),
        f=dict(name="f", dims="test", transform=None),
        g=dict(name="h", dims="test", transform=transforms.log),
    )


def test_transform_idata(transformed_data, idata, param_cfg):
    flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
    expected_shape = 0
    for v in transformed_data.values():
        expected_shape += int(np.prod(v.shape[2:]))
    assert flat_info["data"].shape[1] == expected_shape
    assert len(flat_info["info"]) == len(param_cfg)
    assert "sinfo" in flat_info["info"][0]
    assert "vinfo" in flat_info["info"][0]


@pytest.fixture
def flat_info(idata, param_cfg):
    return pmx.utils.prior._flatten(idata, **param_cfg)


def test_mean_chol(flat_info):
    mean, chol = pmx.utils.prior._mean_chol(flat_info["data"])
    assert mean.shape == (flat_info["data"].shape[1],)
    assert chol.shape == (flat_info["data"].shape[1],) * 2


def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
    with pm.Model(coords=coords) as model:
        priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
        test_prior = pm.sample_prior_predictive(1)
    names = [p["name"] for p in param_cfg.values()]
    assert set(model.named_vars) == {"trace_prior_", *names}


def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
    with pm.Model(coords=coords) as model:
        priors = pmx.utils.prior.prior_from_idata(
            idata, var_names=user_param_cfg[0], **user_param_cfg[1]
        )
        test_prior = pm.sample_prior_predictive(1)
    names = [p["name"] for p in param_cfg.values()]
    assert set(model.named_vars) == {"trace_prior_", *names}

pymc_experimental/utils/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1 +1,2 @@
 from pymc_experimental.utils import spline
+from pymc_experimental.utils import prior

pymc_experimental/utils/prior.py

Lines changed: 181 additions & 0 deletions
@@ -0,0 +1,181 @@
from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List
import aeppl.transforms
import arviz
import pymc as pm
import aesara.tensor as at
import numpy as np


class ParamCfg(TypedDict):
    name: str
    transform: Optional[aeppl.transforms.RVTransform]
    dims: Optional[Union[str, Tuple[str]]]


class ShapeInfo(TypedDict):
    # shape might not match slice due to a transform
    shape: Tuple[int]  # transformed shape
    slice: slice


class VarInfo(TypedDict):
    sinfo: ShapeInfo
    vinfo: ParamCfg


class FlatInfo(TypedDict):
    data: np.ndarray
    info: List[VarInfo]


def _arg_to_param_cfg(
    key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None
):
    if value is None:
        cfg = ParamCfg(name=key, transform=None, dims=None)
    elif isinstance(value, Tuple):
        cfg = ParamCfg(name=key, transform=None, dims=value)
    elif isinstance(value, str):
        cfg = ParamCfg(name=value, transform=None, dims=None)
    elif isinstance(value, aeppl.transforms.RVTransform):
        cfg = ParamCfg(name=key, transform=value, dims=None)
    else:
        cfg = value.copy()
        cfg.setdefault("name", key)
        cfg.setdefault("transform", None)
        cfg.setdefault("dims", None)
    return cfg


def _parse_args(
    var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, ParamCfg]:
    results = dict()
    for var in var_names:
        results[var] = _arg_to_param_cfg(var)
    for key, val in kwargs.items():
        results[key] = _arg_to_param_cfg(key, val)
    return results


def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
    posterior = idata.posterior
    vars = list()
    info = list()
    begin = 0
    for key, cfg in kwargs.items():
        data = (
            posterior[key]
            # combine all draws from all chains
            .stack(__sample__=["chain", "draw"])
            # move sample dim to the first position
            # no matter where it was before
            .transpose("__sample__", ...)
            # we need numpy data for all the rest functionality
            .values
        )
        # omitting __sample__
        # we need shape in the untransformed space
        if cfg["transform"] is not None:
            # some transforms need original shape
            data = cfg["transform"].forward(data).eval()
        shape = data.shape[1:]
        # now we can get rid of shape
        data = data.reshape(data.shape[0], -1)
        end = begin + data.shape[1]
        vars.append(data)
        sinfo = dict(shape=shape, slice=slice(begin, end))
        info.append(dict(sinfo=sinfo, vinfo=cfg))
        begin = end
    return dict(data=np.concatenate(vars, axis=-1), info=info)


def _mean_chol(flat_array: np.ndarray):
    mean = flat_array.mean(0)
    cov = np.cov(flat_array, rowvar=False)
    chol = np.linalg.cholesky(cov)
    return mean, chol


def _mvn_prior_from_flat_info(name, flat_info: FlatInfo):
    mean, chol = _mean_chol(flat_info["data"])
    base_dist = pm.Normal(name, np.zeros_like(mean))
    interim = mean + chol @ base_dist
    result = dict()
    for var_info in flat_info["info"]:
        sinfo = var_info["sinfo"]
        vinfo = var_info["vinfo"]
        var = interim[sinfo["slice"]].reshape(sinfo["shape"])
        if vinfo["transform"] is not None:
            var = vinfo["transform"].backward(var)
        var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"])
        result[vinfo["name"]] = var
    return result


def prior_from_idata(
    idata: arviz.InferenceData,
    name="trace_prior_",
    *,
    var_names: Sequence[str],
    **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, at.TensorVariable]:
    """
    Create a prior from posterior using MvNormal approximation.

    The approximation uses MvNormal distribution.
    Keep in mind that this function will only work well for unimodal
    posteriors and will fail when complicated interactions happen.

    Moreover, if a retrieved variable is constrained, you
    should specify a transform for the variable, e.g.
    ``pymc.distributions.transforms.log`` for standard
    deviation posterior.

    Parameters
    ----------
    idata: arviz.InferenceData
        Inference data with posterior group
    var_names: Sequence[str]
        names of variables to take as is from the posterior
    kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
        names of variables with additional configuration, see more in Examples

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc.distributions.transforms as transforms
    >>> import numpy as np
    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1:
    ...     a = pm.Normal("a")
    ...     b = pm.Normal("b", dims="test")
    ...     c = pm.HalfNormal("c")
    ...     d = pm.Normal("d")
    ...     e = pm.Normal("e")
    ...     f = pm.Dirichlet("f", np.ones(3), dims="options")
    ...     trace = pm.sample(progressbar=False)

    You can reuse the posterior in the new model.

    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
    ...     priors = prior_from_idata(
    ...         trace,  # the old trace (posterior)
    ...         var_names=["a", "d"],  # take variables as is
    ...
    ...         e="new_e",  # assign new name "new_e" for a variable
    ...         # similar to dict(name="new_e")
    ...
    ...         b=("test", ),  # set a dim to "test"
    ...         # similar to dict(dims=("test", ))
    ...
    ...         c=transforms.log,  # apply log transform to a positive variable
    ...         # similar to dict(transform=transforms.log)
    ...
    ...         # set a name, assign a dim and apply simplex transform
    ...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
    ...     )
    ...     trace1 = pm.sample_prior_predictive(100)
    """
    param_cfg = _parse_args(var_names=var_names, **kwargs)
    flat_info = _flatten(idata, **param_cfg)
    return _mvn_prior_from_flat_info(name, flat_info)
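
The core of _mvn_prior_from_flat_info is a standard location-scale reparameterization: with mean and Cholesky factor chol estimated from the flattened posterior draws, mean + chol @ z for z drawn from a standard normal reproduces the posterior's mean and covariance. A minimal NumPy-only sketch of that identity on synthetic draws (illustrative, not part of the commit):

import numpy as np

rng = np.random.default_rng(0)
# fake flattened posterior draws with correlated columns
draws = rng.standard_normal((400, 3)) @ np.array(
    [[1.0, 0.0, 0.0], [0.5, 1.0, 0.0], [0.2, 0.3, 1.0]]
)

mean = draws.mean(0)
chol = np.linalg.cholesky(np.cov(draws, rowvar=False))  # same computation as _mean_chol

z = rng.standard_normal((100_000, 3))  # i.i.d. standard normal base draws
samples = mean + z @ chol.T            # mean + chol @ z, vectorized over rows

# the reparameterized samples recover the empirical mean and covariance of the draws
assert np.allclose(samples.mean(0), mean, atol=0.05)
assert np.allclose(np.cov(samples, rowvar=False), np.cov(draws, rowvar=False), atol=0.05)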
