Skip to content

Initialize a prior from a fitted posterior #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/api_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,6 @@ methods in the current release of PyMC experimental.
.. automodule:: pymc_experimental.utils.spline
:members: bspline_interpolation

.. automodule:: pymc_experimental.utils.prior
:members: prior_from_idata

141 changes: 141 additions & 0 deletions pymc_experimental/tests/test_prior_from_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
import pymc_experimental as pmx
from pymc.distributions import transforms
import pytest
import arviz as az
import numpy as np
import pymc as pm


@pytest.mark.parametrize(
    "case",
    [
        # (input args for _arg_to_param_cfg, expected normalized ParamCfg)
        (("a", None), dict(name="a", transform=None, dims=None)),
        (("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
        (
            ("a", dict(transform=transforms.log)),
            dict(name="a", transform=transforms.log, dims=None),
        ),
        (("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
        (("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
        (("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
    ],
)
def test_parsing_arguments(case):
    """Every supported shorthand normalizes to a complete ParamCfg dict.

    NOTE: the original parametrization listed the ``("a", dict(name="b"))``
    case twice; the duplicate entry has been removed.
    """
    inp, out = case
    test = pmx.utils.prior._arg_to_param_cfg(*inp)
    assert test == out


@pytest.fixture
def coords():
    """Coordinate values shared by the model/idata fixtures below."""
    return {"test": range(3), "simplex": range(4)}


@pytest.fixture
def user_param_cfg():
    """User-facing arguments for prior_from_idata: plain names + keyword shorthands."""
    var_names = ("t",)
    overrides = {
        "a": "d",
        "b": {"transform": transforms.log, "dims": ("test",)},
        "c": {"transform": transforms.simplex, "dims": ("simplex",)},
    }
    return var_names, overrides


@pytest.fixture
def param_cfg(user_param_cfg):
    # Normalized per-variable configuration, as produced by the internal parser.
    return pmx.utils.prior._parse_args(user_param_cfg[0], **user_param_cfg[1])


@pytest.fixture
def transformed_data(param_cfg, coords):
    """Fake posterior draws in the transformed (unconstrained) space.

    Each variable gets shape ``(chain=4, draw=100, *extra_dims)`` where
    ``extra_dims`` comes from the configured dims and, when a transform is
    set, from the shape of a forward-transformed sample (transforms such as
    simplex change the trailing shape).
    """
    vars = dict()
    for k, cfg in param_cfg.items():
        # Initialize unconditionally: the original code left ``extra_dims``
        # undefined for dims=None + transform set (NameError), and its
        # ``else`` branch discarded dims for variables without a transform.
        if cfg["dims"] is not None:
            extra_dims = [len(coords[d]) for d in cfg["dims"]]
        else:
            extra_dims = []
        if cfg["transform"] is not None:
            # Take the post-transform shape from an actual forward pass.
            t = np.random.randn(*extra_dims)
            extra_dims = tuple(cfg["transform"].forward(t).shape.eval())
        orig = np.random.randn(4, 100, *extra_dims)
        vars[k] = orig
    return vars


@pytest.fixture
def idata(transformed_data, param_cfg):
    """InferenceData whose posterior holds the draws mapped back to the
    original (constrained) space via each variable's backward transform."""
    posterior_vars = dict()
    for key, draws in transformed_data.items():
        transform = param_cfg[key]["transform"]
        if transform is None:
            values = draws
        else:
            values = transform.backward(draws).eval()
        # Backward transform must not produce NaNs for the fake data.
        assert not np.isnan(values).any()
        posterior_vars[key] = values
    return az.convert_to_inference_data(posterior_vars)


def test_idata_for_tests(idata, param_cfg):
    """Sanity-check the synthetic posterior before other tests rely on it."""
    posterior = idata.posterior
    assert set(posterior.keys()) == set(param_cfg)
    assert len(posterior.coords["chain"]) == 4
    assert len(posterior.coords["draw"]) == 100


def test_args_compose():
    """Each supported shorthand expands into a complete per-variable config."""
    parsed = pmx.utils.prior._parse_args(
        var_names=["a"],
        b=("test",),
        c=transforms.log,
        d="e",
        f=dict(dims="test"),
        g=dict(name="h", dims="test", transform=transforms.log),
    )
    expected = {
        "a": {"name": "a", "dims": None, "transform": None},
        "b": {"name": "b", "dims": ("test",), "transform": None},
        "c": {"name": "c", "dims": None, "transform": transforms.log},
        "d": {"name": "e", "dims": None, "transform": None},
        "f": {"name": "f", "dims": "test", "transform": None},
        "g": {"name": "h", "dims": "test", "transform": transforms.log},
    }
    assert parsed == expected


def test_transform_idata(transformed_data, idata, param_cfg):
    """Flattening concatenates every variable along a single feature axis."""
    flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
    # Feature count = product of each variable's trailing dims (after chain/draw).
    expected_shape = sum(int(np.prod(v.shape[2:])) for v in transformed_data.values())
    assert flat_info["data"].shape[1] == expected_shape
    assert len(flat_info["info"]) == len(param_cfg)
    first = flat_info["info"][0]
    assert "sinfo" in first
    assert "vinfo" in first


@pytest.fixture
def flat_info(idata, param_cfg):
    # Flattened posterior: 2-D (samples, features) data plus per-variable metadata.
    return pmx.utils.prior._flatten(idata, **param_cfg)


def test_mean_chol(flat_info):
    """Mean vector and Cholesky factor must match the feature count."""
    data = flat_info["data"]
    n_features = data.shape[1]
    mean, chol = pmx.utils.prior._mean_chol(data)
    assert mean.shape == (n_features,)
    assert chol.shape == (n_features, n_features)


def test_mvn_prior_from_flat_info(flat_info, coords, param_cfg):
    """The helper registers the base MvN variable plus one per configured name."""
    with pm.Model(coords=coords) as model:
        priors = pmx.utils.prior._mvn_prior_from_flat_info("trace_prior_", flat_info)
        test_prior = pm.sample_prior_predictive(1)
    expected_names = {"trace_prior_"} | {cfg["name"] for cfg in param_cfg.values()}
    assert set(model.named_vars) == expected_names


def test_prior_from_idata(idata, user_param_cfg, coords, param_cfg):
    """End-to-end: priors rebuilt from idata register the expected variables."""
    var_names, overrides = user_param_cfg
    with pm.Model(coords=coords) as model:
        priors = pmx.utils.prior.prior_from_idata(idata, var_names=var_names, **overrides)
        test_prior = pm.sample_prior_predictive(1)
    expected_names = {"trace_prior_"} | {cfg["name"] for cfg in param_cfg.values()}
    assert set(model.named_vars) == expected_names
1 change: 1 addition & 0 deletions pymc_experimental/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from pymc_experimental.utils import spline
from pymc_experimental.utils import prior
181 changes: 181 additions & 0 deletions pymc_experimental/utils/prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List
import aeppl.transforms
import arviz
import pymc as pm
import aesara.tensor as at
import numpy as np


class ParamCfg(TypedDict):
    """Normalized per-variable configuration used by ``prior_from_idata``."""

    # Name the recreated variable is registered under in the new model.
    name: str
    # Optional transform: ``forward`` is applied before flattening the
    # posterior, ``backward`` after sampling from the MvNormal approximation.
    transform: Optional[aeppl.transforms.RVTransform]
    # Optional dims forwarded to ``pm.Deterministic`` (str or tuple of names).
    dims: Optional[Union[str, Tuple[str]]]


class ShapeInfo(TypedDict):
    """Location and shape of one variable inside the flattened 2-D array."""

    # shape might not match slice due to a transform
    shape: Tuple[int]  # transformed shape
    # Column range of this variable within the concatenated feature axis.
    slice: slice


class VarInfo(TypedDict):
    """Pairs the flattening bookkeeping with the user-facing variable config."""

    # Shape/slice bookkeeping produced by ``_flatten``.
    sinfo: ShapeInfo
    # The normalized ParamCfg for this variable.
    vinfo: ParamCfg


class FlatInfo(TypedDict):
    """Posterior flattened to 2-D plus the metadata needed to unflatten it."""

    # (samples, features): all chains/draws stacked, variables concatenated.
    data: np.ndarray
    # One VarInfo per variable, in concatenation order.
    info: List[VarInfo]


def _arg_to_param_cfg(
    key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None
):
    """Normalize one user-supplied argument into a complete ``ParamCfg``.

    Supported shorthands:

    * ``None`` — keep the posterior name, no transform, no dims;
    * ``tuple`` — interpreted as ``dims``;
    * ``str`` — interpreted as the new variable ``name``;
    * ``RVTransform`` — interpreted as the ``transform``;
    * mapping — partial config; missing keys are filled with defaults.

    Parameters
    ----------
    key : str
        Variable name as it appears in the posterior.
    value : ParamCfg, RVTransform, str, tuple, or None

    Returns
    -------
    ParamCfg
        Dict with ``name``, ``transform`` and ``dims`` keys all present.
    """
    if value is None:
        cfg = ParamCfg(name=key, transform=None, dims=None)
    # ``isinstance`` against ``typing.Tuple`` is deprecated (and an error on
    # newer Pythons for subscripted forms); the builtin ``tuple`` is correct.
    elif isinstance(value, tuple):
        cfg = ParamCfg(name=key, transform=None, dims=value)
    elif isinstance(value, str):
        cfg = ParamCfg(name=value, transform=None, dims=None)
    elif isinstance(value, aeppl.transforms.RVTransform):
        cfg = ParamCfg(name=key, transform=value, dims=None)
    else:
        # Assume a (possibly partial) mapping; copy so ``setdefault`` does
        # not mutate the caller's dict.
        cfg = value.copy()
        cfg.setdefault("name", key)
        cfg.setdefault("transform", None)
        cfg.setdefault("dims", None)
    return cfg


def _parse_args(
    var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, ParamCfg]:
    """Expand plain names and keyword shorthands into full per-variable configs.

    ``var_names`` entries get default configs; keyword entries are normalized
    through ``_arg_to_param_cfg`` and take precedence on name collisions.
    """
    parsed = {name: _arg_to_param_cfg(name) for name in var_names}
    for key, val in kwargs.items():
        parsed[key] = _arg_to_param_cfg(key, val)
    return parsed


def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
    """Flatten selected posterior variables into one 2-D (samples, features) array.

    For every variable, all chains and draws are stacked into a single sample
    dimension, the optional ``transform.forward`` is applied, and the result is
    reshaped to 2-D. The per-variable column slice and post-transform shape are
    recorded so the flat array can be split apart again later.

    Parameters
    ----------
    idata: arviz.InferenceData
        Inference data with a posterior group.
    kwargs: ParamCfg
        One normalized config per variable to extract.

    Returns
    -------
    FlatInfo
        ``data`` (samples, features) and ``info`` (one VarInfo per variable).
    """
    posterior = idata.posterior
    vars = list()
    info = list()
    begin = 0  # running column offset into the concatenated feature axis
    for key, cfg in kwargs.items():
        data = (
            posterior[key]
            # combine all draws from all chains
            .stack(__sample__=["chain", "draw"])
            # move sample dim to the first position
            # no matter where it was before
            .transpose("__sample__", ...)
            # we need numpy data for all the rest functionality
            .values
        )
        # omitting __sample__
        # we need shape in the untransformed space
        if cfg["transform"] is not None:
            # some transforms need original shape
            data = cfg["transform"].forward(data).eval()
        shape = data.shape[1:]
        # now we can get rid of shape
        data = data.reshape(data.shape[0], -1)
        end = begin + data.shape[1]
        vars.append(data)
        sinfo = dict(shape=shape, slice=slice(begin, end))
        info.append(dict(sinfo=sinfo, vinfo=cfg))
        begin = end
    return dict(data=np.concatenate(vars, axis=-1), info=info)


def _mean_chol(flat_array: np.ndarray):
mean = flat_array.mean(0)
cov = np.cov(flat_array, rowvar=False)
chol = np.linalg.cholesky(cov)
return mean, chol


def _mvn_prior_from_flat_info(name, flat_info: FlatInfo):
    """Create correlated priors from flattened posterior draws.

    Builds one standard-normal base vector of length ``features`` and maps it
    through ``mean + chol @ base`` (a non-centered MvNormal parameterization),
    then slices/reshapes the result back into the individual variables,
    applying the backward transform where configured.

    Must be called inside a ``pm.Model`` context — ``pm.Normal`` and
    ``pm.Deterministic`` register variables on the current model.

    Parameters
    ----------
    name : str
        Name for the base normal vector (e.g. ``"trace_prior_"``).
    flat_info : FlatInfo
        Flattened posterior produced by ``_flatten``.

    Returns
    -------
    Dict[str, at.TensorVariable]
        Mapping of configured variable names to the created deterministics.
    """
    mean, chol = _mean_chol(flat_info["data"])
    # Standard-normal base; the mean argument only fixes the vector's shape.
    base_dist = pm.Normal(name, np.zeros_like(mean))
    # Correlated draw in the transformed (flattened) space.
    interim = mean + chol @ base_dist
    result = dict()
    for var_info in flat_info["info"]:
        sinfo = var_info["sinfo"]
        vinfo = var_info["vinfo"]
        # Recover this variable's block and original (transformed-space) shape.
        var = interim[sinfo["slice"]].reshape(sinfo["shape"])
        if vinfo["transform"] is not None:
            # Map back to the constrained space the user's model expects.
            var = vinfo["transform"].backward(var)
        var = pm.Deterministic(vinfo["name"], var, dims=vinfo["dims"])
        result[vinfo["name"]] = var
    return result


def prior_from_idata(
    idata: arviz.InferenceData,
    name="trace_prior_",
    *,
    var_names: Sequence[str],
    **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, at.TensorVariable]:
    """
    Create a prior from posterior using MvNormal approximation.

    The approximation uses MvNormal distribution.
    Keep in mind that this function will only work well for unimodal
    posteriors and will fail when complicated interactions happen.

    Moreover, if a retrieved variable is constrained, you
    should specify a transform for the variable, e.g.
    ``pymc.distributions.transforms.log`` for standard
    deviation posterior.

    Must be called inside a ``pm.Model`` context, since the created
    variables are registered on the current model.

    Parameters
    ----------
    idata: arviz.InferenceData
        Inference data with posterior group
    name: str
        Name for the base MvNormal variable added to the model
        (default ``"trace_prior_"``)
    var_names: Sequence[str]
        names of variables to take as is from the posterior
    kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
        names of variables with additional configuration, see more in Examples

    Returns
    -------
    Dict[str, at.TensorVariable]
        mapping of (new) variable names to the created prior variables

    Examples
    --------
    >>> import pymc as pm
    >>> import pymc.distributions.transforms as transforms
    >>> import numpy as np
    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model1:
    ...     a = pm.Normal("a")
    ...     b = pm.Normal("b", dims="test")
    ...     c = pm.HalfNormal("c")
    ...     d = pm.Normal("d")
    ...     e = pm.Normal("e")
    ...     f = pm.Dirichlet("f", np.ones(3), dims="options")
    ...     trace = pm.sample(progressbar=False)

    You can reuse the posterior in the new model.

    >>> with pm.Model(coords=dict(test=range(4), options=range(3))) as model2:
    ...     priors = prior_from_idata(
    ...         trace,                  # the old trace (posterior)
    ...         var_names=["a", "d"],   # take variables as is
    ...
    ...         e="new_e",              # assign new name "new_e" for a variable
    ...                                 # similar to dict(name="new_e")
    ...
    ...         b=("test", ),           # set a dim to "test"
    ...                                 # similar to dict(dims=("test", ))
    ...
    ...         c=transforms.log,       # apply log transform to a positive variable
    ...                                 # similar to dict(transform=transforms.log)
    ...
    ...         # set a name, assign a dim and apply simplex transform
    ...         f=dict(name="new_f", dims="options", transform=transforms.simplex)
    ...     )
    ...     trace1 = pm.sample_prior_predictive(100)
    """
    param_cfg = _parse_args(var_names=var_names, **kwargs)
    flat_info = _flatten(idata, **param_cfg)
    return _mvn_prior_from_flat_info(name, flat_info)