Skip to content

Initialize a prior from a fitted posterior #56

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Jul 6, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions pymc_experimental/tests/test_prior_from_trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import pymc_experimental as pmx
from pymc.distributions import transforms
import pytest
import arviz as az
import numpy as np


@pytest.mark.parametrize(
"case",
[
(("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
(("a", None), dict(name="a", transform=None, dims=None)),
(("a", transforms.log), dict(name="a", transform=transforms.log, dims=None)),
(
("a", dict(transform=transforms.log)),
dict(name="a", transform=transforms.log, dims=None),
),
(("a", dict(name="b")), dict(name="b", transform=None, dims=None)),
(("a", dict(name="b", dims="test")), dict(name="b", transform=None, dims="test")),
(("a", ("test",)), dict(name="a", transform=None, dims=("test",))),
],
)
def test_parsing_arguments(case):
inp, out = case
test = pmx.utils.prior._arg_to_param_cfg(*inp)
assert test == out


@pytest.fixture
def coords():
return dict(test=range(3), simplex=range(4))


@pytest.fixture
def param_cfg():
return dict(
a=pmx.utils.prior._arg_to_param_cfg("d"),
b=pmx.utils.prior._arg_to_param_cfg("b", dict(transform=transforms.log, dims=("test",))),
c=pmx.utils.prior._arg_to_param_cfg(
"c", dict(transform=transforms.simplex, dims=("simplex",))
),
)


@pytest.fixture
def transformed_data(param_cfg, coords):
vars = dict()
for k, cfg in param_cfg.items():
if cfg["dims"] is not None:
extra_dims = [len(coords[d]) for d in cfg["dims"]]
else:
extra_dims = []
orig = np.random.randn(4, 100, *extra_dims)
vars[k] = orig
return vars


@pytest.fixture
def idata(transformed_data, param_cfg):
vars = dict()
for k, orig in transformed_data.items():
cfg = param_cfg[k]
if cfg["transform"] is not None:
var = cfg["transform"].backward(orig).eval()
else:
var = orig
assert not np.isnan(var).any()
vars[k] = var
return az.convert_to_inference_data(vars)


def test_idata_for_tests(idata, param_cfg):
assert set(idata.posterior.keys()) == set(param_cfg)
assert len(idata.posterior.coords["chain"]) == 4
assert len(idata.posterior.coords["draw"]) == 100


def test_args_compose():
cfg = pmx.utils.prior._parse_args(
var_names=["a"],
b=("test",),
c=transforms.log,
d="e",
f=dict(dims="test"),
g=dict(name="h", dims="test", transform=transforms.log),
)
assert cfg == dict(
a=dict(name="a", dims=None, transform=None),
b=dict(name="b", dims=("test",), transform=None),
c=dict(name="c", dims=None, transform=transforms.log),
d=dict(name="e", dims=None, transform=None),
f=dict(name="f", dims="test", transform=None),
g=dict(name="h", dims="test", transform=transforms.log),
)


def test_transform_idata(transformed_data, idata, param_cfg):
flat_info = pmx.utils.prior._flatten(idata, **param_cfg)
expected_shape = 0
for v in transformed_data.values():
expected_shape += int(np.prod(v.shape[2:]))
assert flat_info["data"].shape[1] == expected_shape
1 change: 1 addition & 0 deletions pymc_experimental/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from pymc_experimental.utils import spline
from pymc_experimental.utils import prior
76 changes: 76 additions & 0 deletions pymc_experimental/utils/prior.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import TypedDict, Optional, Union, Tuple, Sequence, Dict, List
import aeppl.transforms
import arviz
import numpy as np


class ParamCfg(TypedDict):
name: str
transform: Optional[aeppl.transforms.RVTransform]
dims: Optional[Union[str, Tuple[str]]]


class ShapeInfo(TypedDict):
# shape might not match slice due to a transform
shape: Tuple[int]
slice: slice


class VarInfo(TypedDict):
sinfo: ShapeInfo
vinfo: ParamCfg


class FlatInfo(TypedDict):
data: np.ndarray
info: List[VarInfo]


def _arg_to_param_cfg(
key, value: Optional[Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]] = None
):
if value is None:
cfg = ParamCfg(name=key, transform=None, dims=None)
elif isinstance(value, Tuple):
cfg = ParamCfg(name=key, transform=None, dims=value)
elif isinstance(value, str):
cfg = ParamCfg(name=value, transform=None, dims=None)
elif isinstance(value, aeppl.transforms.RVTransform):
cfg = ParamCfg(name=key, transform=value, dims=None)
else:
cfg = value.copy()
cfg.setdefault("name", key)
cfg.setdefault("transform", None)
cfg.setdefault("dims", None)
return cfg


def _parse_args(
var_names: Sequence[str], **kwargs: Union[ParamCfg, aeppl.transforms.RVTransform, str, Tuple]
) -> Dict[str, ParamCfg]:
results = dict()
for var in var_names:
results[var] = _arg_to_param_cfg(var)
for key, val in kwargs.items():
results[key] = _arg_to_param_cfg(key, val)
return results


def _flatten(idata: arviz.InferenceData, **kwargs: ParamCfg) -> FlatInfo:
posterior = idata.posterior
vars = list()
info = list()
begin = 0
for key, cfg in kwargs.items():
data = posterior[key].values
# omitting chain, draw
shape = data.shape[2:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no guarantee the chain and draw dimensions will always be in the beginning, there are perfectly valid xarray operations that modify the dimension order. In xarray only the dimension name is relevant.

A quick change to the code to take this into account would be:

sample_dims = ["chain", "draw"]
for ...
    batch_dims = [dim for dim in posterior[key].dims if dim not in sample_dims]
    data = posterior[key].stack(__sample__=sample_dims, __batch__=batch_dims)
    end = begin + len(data["__batch__"])

I suspect it might even be possible to simplify this further using https://docs.xarray.dev/en/latest/generated/xarray.Dataset.to_stacked_array.html#xarray.Dataset.to_stacked_array plus a where to check the start and end positions of each variable. I can take a look towards the end of July if it were to still be helpful by then.

if cfg["transform"] is not None:
data = cfg["transform"].forward(data).eval()
data = data.reshape(*data.shape[:2], -1)
data = data.reshape(-1, data.shape[2])
end = begin + data.shape[1]
vars.append(data)
info.append(dict(shape=shape, slice=slice(begin, end)))
begin = end
return dict(data=np.concatenate(vars, axis=-1), infp=info)