
Initialize a prior from a fitted posterior #56


Merged 23 commits into main on Jul 6, 2022
Conversation

ferrine
Member

@ferrine ferrine commented Jun 29, 2022

If you want to do knowledge transfer in a smart way, this is how you do it:

from pymc.distributions import transforms

with model1:
    trace = pm.sample()
    # trace.posterior.keys() ~ ["a", "b", "c", "d", "f", "g"]
    # a - vector
    # b - matrix
    # c - positive
    # d, f, g - some other variable we do not care about


with pm.Model(coords=dict(test=range(3))) as model:
    priors = pmx.utils.prior.prior_from_idata(
        trace, 
        var_names=["a"],
        b=("test", "test"),
        c=transforms.log, 
        d="e", 
        f=dict(dims="test"),
        g=dict(name="h", dims="test", transform=transforms.log)
    )
    # 0. do nothing special to 'a' and other items in "var_names"
    # 1. 'b' has coords ("test", "test")
    # 2. transform 'c' to logspace
    # 3. rename 'd' to 'e'
    # 4. say 'f' has coords 'test'
    # 5. do everything mentioned with 'g'
    # priors will be a dictionary with all the priors, variables are available by final name keys

@ricardoV94
Member

I don't see the prior being set. Is this still a draft? I assume you will use interpolation / kde?

@ferrine
Member Author

ferrine commented Jun 29, 2022

Yes, this is a draft. In the snippet, I've added the API to discuss. The API is permissive in the types it accepts and flexible, since it is easy to figure out the intention. I plan to approximate the prior using an MvNormal in the transformed space.

The prior is set inside the function, so you only get the dictionary with the final result
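The plan above (fit an MvNormal in the transformed space) can be sketched with plain NumPy, assuming a positive variable whose transform is log; all names here are illustrative, not the actual implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

# stand-in for flattened posterior samples of a positive variable "c"
c_samples = rng.lognormal(mean=1.0, sigma=0.5, size=4000)

# move to log space (the transform for positive variables) and fit a normal
z = np.log(c_samples)
mu = np.atleast_1d(z.mean(axis=0))
cov = np.atleast_2d(np.cov(z))

# draw from the fitted normal and map back through exp to get prior draws
z_new = rng.multivariate_normal(mu, cov, size=4000)
c_prior = np.exp(z_new[:, 0])
```

With more than one variable, the same idea applies after concatenating all transformed variables into one vector, so the MvNormal also captures posterior correlations between them.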

@twiecki twiecki marked this pull request as draft June 29, 2022 18:36
@ferrine ferrine force-pushed the prior-from-posterior branch from ceb4d06 to 8d047fc on June 30, 2022 12:07
for key, cfg in kwargs.items():
data = posterior[key].values
# omitting chain, draw
shape = data.shape[2:]
Member

There is no guarantee the chain and draw dimensions will always be in the beginning, there are perfectly valid xarray operations that modify the dimension order. In xarray only the dimension name is relevant.

A quick change to the code to take this into account would be:

sample_dims = ["chain", "draw"]
for ...
    batch_dims = [dim for dim in posterior[key].dims if dim not in sample_dims]
    data = posterior[key].stack(__sample__=sample_dims, __batch__=batch_dims)
    end = begin + len(data["__batch__"])

I suspect it might even be possible to simplify this further using https://docs.xarray.dev/en/latest/generated/xarray.Dataset.to_stacked_array.html#xarray.Dataset.to_stacked_array plus a where to check the start and end positions of each variable. I can take a look towards the end of July if it were to still be helpful by then.
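The dimension-name-based stacking suggested above can be demonstrated with a small self-contained xarray example (the variable name and sizes are made up):

```python
import numpy as np
import xarray as xr

# a toy posterior variable with chain/draw plus one batch dimension
b = xr.DataArray(np.zeros((4, 100, 3)), dims=("chain", "draw", "dim0"))

# an equally valid view of the same data with a different dimension order
b_transposed = b.transpose("dim0", "draw", "chain")

# stack by dimension name, so the original axis order is irrelevant
sample_dims = ["chain", "draw"]
batch_dims = [d for d in b_transposed.dims if d not in sample_dims]
stacked = b_transposed.stack(__sample__=sample_dims, __batch__=batch_dims)
# stacked has sizes (__sample__=400, __batch__=3) regardless of input order
```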

@ferrine ferrine marked this pull request as ready for review July 1, 2022 18:37
@ferrine
Member Author

ferrine commented Jul 1, 2022

ready for review

@ferrine ferrine requested a review from OriolAbril July 3, 2022 06:44
@twiecki
Member

twiecki commented Jul 4, 2022

About the API: the most common case will be to transfer just a single parameter from posterior to prior, so I wonder what that API looks like, to make sure that we're not handling the general case well while handling the common case poorly.

@ferrine
Member Author

ferrine commented Jul 4, 2022

About the API: the most common case will be to transfer just a single parameter from posterior to prior, so I wonder what that API looks like, to make sure that we're not handling the general case well while handling the common case poorly.

  • A scalar variable without transform

with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, var_names=["a"])["a"]

and if you rename

with pm.Model(coords=dict(test=range(3))) as model:
    b = pmx.utils.prior.prior_from_idata(trace, a="b")["b"]

  • A vector variable without transform, adding dims

with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=("test",))["a"]

or

with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=dict(dims=("test",)))["a"]

  • A vector variable with transform ("what's left from Dirichlet")

with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=dict(dims=("test",), transform=transforms.simplex))["a"]

If we do not need coords

with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=transforms.simplex)["a"]

... # set a name, assign a coord and apply simplex transform
... f=dict(name="new_f", dims="options", transform=transforms.simplex)
... )
... trace1 = pm.sample_prior_predictive(100)
Member

@OriolAbril OriolAbril Jul 5, 2022

Might be worth adding a note, or even the code, to use plot_pair to compare the obtained posterior to the generated prior. Even with an MvNormal and transforms, there might be cases where the posterior is not recovered correctly, and it will generally fail if the wrong transform is used. I'm not sure how aware users are of default transforms; I'd think the vast majority have no idea a transform is happening when they use half distributions, for example.

Note: regarding auto-use of default transforms. I think that arviz-devs/arviz#2056 plus a key code to map the strings in the attributes to common transforms will generally fix this issue.
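To illustrate why the wrong transform fails, here is a NumPy-only sketch (a numeric stand-in for the plot_pair comparison suggested above): a normal fitted to a skewed positive posterior without the log transform puts noticeable mass on impossible negative values, while fitting in log space keeps all draws valid.

```python
import numpy as np

rng = np.random.default_rng(1)
posterior = rng.lognormal(0.0, 1.0, size=5000)  # skewed, strictly positive

# fit a normal with no transform (the "wrong transform" case)
naive = rng.normal(posterior.mean(), posterior.std(), size=5000)

# fit in log space (the appropriate transform for a positive variable)
z = np.log(posterior)
good = np.exp(rng.normal(z.mean(), z.std(), size=5000))

# the naive prior assigns real probability mass to negative values,
# which the original posterior never produces
frac_negative = (naive < 0).mean()
```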

Member Author

I've added a brief explanation about this.

@twiecki
Member

twiecki commented Jul 5, 2022

@ferrine I like the API.

@ferrine ferrine merged commit dea6bc9 into main Jul 6, 2022
@ferrine ferrine deleted the prior-from-posterior branch July 6, 2022 05:59
@ferrine
Member Author

ferrine commented Jul 6, 2022

Time to merge then. Thanks for the reviews!
