-
-
Notifications
You must be signed in to change notification settings - Fork 65
Initialize a prior from a fitted posterior #56
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
f41ffcd
70d1b10
5fdb819
45582f5
8ea35db
8d047fc
20c76bf
661eacb
58ff72a
2ba7a63
7b09ff6
bb5cfd1
1eeb82d
83beb1d
0c2a3a7
a2e0db2
c10e880
71eb6b7
f9b5632
88ee0a9
678c849
c2baf4c
27ec67b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
import pymc_experimental as pmx | ||
from pymc.distributions import transforms | ||
import pytest | ||
import arviz as az | ||
import numpy as np | ||
import pymc as pm | ||
|
||
|
||
@pytest.mark.parametrize(
    "case",
    [
        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
        (("a", None), dict(name="a", transform=None, dims=None)),
        (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
        (
            ("a", dict(transform=transforms.log)),
            dict(name="a", transform=transforms.log, dims=None),
        ),
        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
        (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
        (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
    ],
)
def test_parsing_arguments(case):
    # every accepted shorthand must normalize to a full ParamCfg dict
    args, expected = case
    assert pmx.utils.prior._arg_to_param_cfg(*args) == expected
|
||
|
||
@pytest.fixture
def coords():
    # dimension name -> coordinate values shared by the test model fixtures
    return {"test": range(3), "simplex": range(4)}
|
||
|
||
@pytest.fixture
def user_param_cfg():
    # (var_names, keyword overrides) exactly as a user would pass them
    var_names = ("t",)
    overrides = dict(
        a="d",
        b=dict(transform=transforms.log, dims=("test",)),
        c=dict(transform=transforms.simplex, dims=("simplex",)),
    )
    return var_names, overrides
|
||
|
||
@pytest.fixture
def param_cfg(user_param_cfg):
    # Normalized configuration: every user shorthand expanded to a full ParamCfg.
    return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1])
|
||
|
||
@pytest.fixture
def transformed_data(param_cfg, coords):
    """Simulate posterior draws in the transformed (unconstrained) space.

    Returns a dict mapping variable name to an array of shape
    (chain=4, draw=100, *transformed_event_dims).
    """
    samples = dict()  # renamed from ``vars`` to avoid shadowing the builtin
    for key, cfg in param_cfg.items():
        if cfg["dims"] is not None:
            extra_dims = [len(coords[d]) for d in cfg["dims"]]
            if cfg["transform"] is not None:
                # a transform may change the event shape (e.g. simplex drops a
                # dimension), so probe it with a dummy draw of the raw shape
                probe = np.random.randn(*extra_dims)
                extra_dims = tuple(cfg["transform"].forward(probe).shape.eval())
        else:
            extra_dims = []
        samples[key] = np.random.randn(4, 100, *extra_dims)
    return samples
|
||
|
||
@pytest.fixture
def idata(transformed_data, param_cfg):
    """Map transformed draws back to the original space and wrap as InferenceData."""
    posterior = dict()  # renamed from ``vars`` to avoid shadowing the builtin
    for key, draws in transformed_data.items():
        cfg = param_cfg[key]
        if cfg["transform"] is not None:
            # backward() returns to the constrained space the user sees
            var = cfg["transform"].backward(draws).eval()
        else:
            var = draws
        # guard against transforms producing invalid values for the fixture
        assert not np.isnan(var).any()
        posterior[key] = var
    return az.convert_to_inference_data(posterior)
|
||
|
||
def test_idata_for_tests(idata, param_cfg):
    # sanity-check the fixture itself: right variables, right sample counts
    posterior = idata.posterior
    assert set(posterior.keys()) == set(param_cfg)
    assert (len(posterior.coords["chain"]), len(posterior.coords["draw"])) == (4, 100)
|
||
|
||
def test_args_compose():
    # every shorthand form composes into the same normalized mapping
    parsed = pmx.utils.prior._parse_args(
        var_names=["a"],
        b=("test",),
        c=transforms.log,
        d="e",
        f=dict(dims="test"),
        g=dict(name="h", dims="test", transform=transforms.log),
    )
    expected = dict(
        a=dict(name="a", dims=None, transform=None),
        b=dict(name="b", dims=("test",), transform=None),
        c=dict(name="c", dims=None, transform=transforms.log),
        d=dict(name="e", dims=None, transform=None),
        f=dict(name="f", dims="test", transform=None),
        g=dict(name="h", dims="test", transform=transforms.log),
    )
    assert parsed == expected
|
||
|
||
def test_transform_idata(transformed_data, idata, param_cfg):
    flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
    # total flattened width == sum of per-variable event sizes (dims after chain/draw)
    expected_width = sum(int(np.prod(v.shape[2:])) for v in transformed_data.values())
    assert flat_info["data"].shape[1] == expected_width
    assert len(flat_info["info"]) == len(param_cfg)
    first = flat_info["info"][0]
    assert "sinfo" in first
    assert "vinfo" in first
|
||
|
||
@pytest.fixture
def flat_info(idata, param_cfg):
    # Flattened posterior matrix plus per-variable slice/shape bookkeeping.
    return pmx.utils.prior._flatten(idata, **param_cfg)
|
||
|
||
def test_mean_chol(flat_info):
    data = flat_info["data"]
    width = data.shape[1]
    mean, chol = pmx.utils.prior._mean_chol(data)
    # mean is a vector over columns; chol is the square Cholesky factor
    assert mean.shape == (width,)
    assert chol.shape == (width, width)
|
||
|
||
def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
    with pm.Model(coords=coords) as model:
        pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
        pm.sample_prior_predictive(1)
    # the base MVN plus one Deterministic per configured variable
    expected_names = {"trace_prior_"} | {cfg["name"] for cfg in param_cfg.values()}
    assert set(model.named_vars) == expected_names
|
||
|
||
def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
    var_names, overrides = user_param_cfg
    with pm.Model(coords=coords) as model:
        pmx.utils.prior.prior_from_idata(idata, var_names=var_names, **overrides)
        pm.sample_prior_predictive(1)
    # the public entry point registers the same variables as the internals
    expected_names = {"trace_prior_"} | {cfg["name"] for cfg in param_cfg.values()}
    assert set(model.named_vars) == expected_names
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
from pymc_experimental.utils import spline | ||
from pymc_experimental.utils import prior |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List | ||
import aeppl.transforms | ||
import arviz | ||
import pymc as pm | ||
import aesara.tensor as at | ||
import numpy as np | ||
|
||
|
||
class ParamCfg(TypedDict):
    """Normalized per-variable configuration for ``prior_from_idata``."""

    # new name for the variable in the target model
    name: str
    # optional transform mapping the variable to the unconstrained space,
    # where a multivariate normal approximates the posterior better
    transform: Optional[aeppl.transforms.RVTransform]
    # a single dim name or a tuple of dim names; ``Tuple[str, ...]`` because
    # dims may have any length (``Tuple[str]`` would mean exactly one)
    dims: Optional[Union[str, Tuple[str, ...]]]
|
||
|
||
class ShapeInfo(TypedDict):
    """Location and shape of one variable inside the flattened matrix."""

    # shape might not match the slice length due to a transform;
    # ``Tuple[int, ...]`` because the event shape may have any rank
    shape: Tuple[int, ...]  # transformed shape
    # column range of this variable within the flat data matrix
    slice: slice
|
||
|
||
class VarInfo(TypedDict):
    """Pairs the flattening bookkeeping with the user configuration."""

    # where the variable lives in the flat matrix and its transformed shape
    sinfo: ShapeInfo
    # the normalized user configuration (name / transform / dims)
    vinfo: ParamCfg
|
||
|
||
class FlatInfo(TypedDict):
    """Result of flattening a posterior: the matrix plus per-variable info."""

    # 2d array of shape (n_samples, total flattened width of all variables)
    data: np.ndarray
    # one VarInfo per variable, in the order the variables were flattened
    info: List[VarInfo]
|
||
|
||
def _arg_to_param_cfg(
    key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None
):
    """Normalize one user-supplied shorthand into a full ``ParamCfg``.

    Accepted shorthands for ``value``:
    * ``None`` — keep the name, no transform, no dims
    * a tuple — interpreted as ``dims``
    * a string — interpreted as the new ``name``
    * an ``RVTransform`` — interpreted as ``transform``
    * a (possibly partial) mapping — missing keys are filled with defaults
    """
    if value is None:
        cfg = ParamCfg(name=key, transform=None, dims=None)
    elif isinstance(value, tuple):
        # builtin ``tuple`` here: ``typing.Tuple`` as an isinstance target
        # is deprecated since Python 3.9
        cfg = ParamCfg(name=key, transform=None, dims=value)
    elif isinstance(value, str):
        cfg = ParamCfg(name=value, transform=None, dims=None)
    elif isinstance(value, aeppl.transforms.RVTransform):
        cfg = ParamCfg(name=key, transform=value, dims=None)
    else:
        # partial dict config: copy so the caller's dict is not mutated
        cfg = value.copy()
        cfg.setdefault("name", key)
        cfg.setdefault("transform", None)
        cfg.setdefault("dims", None)
    return cfg
|
||
|
||
def _parse_args(
    var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, ParamCfg]:
    """Expand plain names and keyword shorthands into full configs."""
    # plain names first ...
    results: Dict[str, ParamCfg] = {name: _arg_to_param_cfg(name) for name in var_names}
    # ... then keyword entries, so an override for the same key wins
    for key, value in kwargs.items():
        results[key] = _arg_to_param_cfg(key, value)
    return results
|
||
|
||
def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
    """Flatten selected posterior variables into a single 2d matrix.

    Each variable is (optionally) mapped to the unconstrained space by its
    configured transform, reshaped to ``(n_samples, -1)`` and concatenated
    along the last axis.  The returned ``FlatInfo`` also records, per
    variable, the column slice and transformed shape needed to invert this.
    """
    posterior = idata.posterior
    chunks = list()  # renamed from ``vars`` to avoid shadowing the builtin
    info = list()
    begin = 0
    for key, cfg in kwargs.items():
        data = (
            posterior[key]
            # combine all draws from all chains
            .stack(__sample__=["chain", "draw"])
            # move sample dim to the first position
            # no matter where it was before
            .transpose("__sample__", ...)
            # we need numpy data for all the rest functionality
            .values
        )
        # omitting __sample__
        # we need shape in the untransformed space
        if cfg["transform"] is not None:
            # some transforms need original shape
            data = cfg["transform"].forward(data).eval()
        shape = data.shape[1:]
        # now we can get rid of shape
        data = data.reshape(data.shape[0], -1)
        end = begin + data.shape[1]
        chunks.append(data)
        sinfo = dict(shape=shape, slice=slice(begin, end))
        info.append(dict(sinfo=sinfo, vinfo=cfg))
        begin = end
    return dict(data=np.concatenate(chunks, axis=-1), info=info)
|
||
|
||
def _mean_chol(flat_array: np.ndarray): | ||
mean = flat_array.mean(0) | ||
cov = np.cov(flat_array, rowvar=False) | ||
chol = np.linalg.cholesky(cov) | ||
return mean, chol | ||
|
||
|
||
def _mvn_prior_from_flat_info(name, flat_info: FlatInfo):
    """Rebuild per-variable priors from flattened posterior draws.

    Creates one standard-normal base RV called ``name`` and shifts/correlates
    it with the empirical mean and Cholesky factor (a non-centered MvNormal
    parametrization), then slices it back into per-variable Deterministics.
    Must be called inside a ``pm.Model`` context; returns a dict mapping the
    configured variable name to its Deterministic.
    """
    mean, chol = _mean_chol(flat_info["data"])
    # non-centered parametrization: iid N(0, 1) base, then mean + chol @ base
    base_dist = pm.Normal(name, np.zeros_like(mean))
    interim = mean + chol @ base_dist
    result = dict()
    for var_info in flat_info["info"]:
        sinfo = var_info["sinfo"]
        vinfo = var_info["vinfo"]
        # slice this variable's columns out and restore its transformed shape
        var = interim[sinfo["slice"]].reshape(sinfo["shape"])
        if vinfo["transform"] is not None:
            # map back from the unconstrained space to the original support
            var = vinfo["transform"].backward(var)
        var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"])
        result[vinfo["name"]] = var
    return result
|
||
|
||
def prior_from_idata(
    idata: arviz.InferenceData,
    name="trace_prior_",
    *,
    var_names: Sequence[str],
    **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, at.TensorVariable]:
    """
    Create a prior from posterior.

    The selected posterior variables are (optionally) mapped to the
    unconstrained space, flattened, and approximated by a single multivariate
    normal with the empirical mean and covariance.  Inside a new model the
    MvNormal is reconstructed (non-centered) and sliced back into named
    Deterministic variables.

    .. note::
        The approximation is only as good as the MvNormal fit in the
        transformed space.  Compare the obtained prior against the original
        posterior (e.g. with ``az.plot_pair``) — the posterior may not be
        recovered well, and it will generally fail if the wrong transform is
        chosen for a constrained variable (e.g. forgetting ``log`` for a
        half-distribution).

    Parameters
    ----------
    idata: arviz.InferenceData
        Inference data with posterior group
    name: str
        name of the base random variable added to the model,
        defaults to ``"trace_prior_"``
    var_names: Sequence[str]
        names of variables to take as is from the posterior
    kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
        names of variables with additional configuration, see more in Examples

    Returns
    -------
    Dict[str, at.TensorVariable]
        mapping from (possibly renamed) variable name to the created prior

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc.distributions.transforms as transforms
    >>> import numpy as np
    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1:
    ...     a = pm.Normal("a")
    ...     b = pm.Normal("b", dims="test")
    ...     c = pm.HalfNormal("c")
    ...     d = pm.Normal("d")
    ...     e = pm.Normal("e")
    ...     f = pm.Dirichlet("f", np.ones(3), dims="options")
    ...     trace = pm.sample(progressbar=False)

    You can reuse the posterior in the new model.

    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
    ...     priors = prior_from_idata(
    ...         trace,                  # the old trace (posterior)
    ...         var_names=["a", "d"],   # take variables as is
    ...
    ...         e="new_e",              # assign new name "new_e" for a variable
    ...                                 # similar to dict(name="new_e")
    ...
    ...         b=("test", ),           # set a coord to "test"
    ...                                 # similar to dict(dims=("test", ))
    ...
    ...         c=transforms.log,       # apply log transform to a positive variable
    ...                                 # similar to dict(transform=transforms.log)
    ...
    ...         # set a name, assign a coord and apply simplex transform
    ...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
    ...     )
    ...     trace1 = pm.sample_prior_predictive(100)
    """
    param_cfg = _parse_args(var_names=var_names, **kwargs)
    flat_info = _flatten(idata, **param_cfg)
    return _mvn_prior_from_flat_info(name, flat_info)
Uh oh!
There was an error while loading. Please reload this page.