Initialize a prior from a fitted posterior #56
I don't see the prior being set. Is this still a draft? I assume you will use interpolation / KDE?
Yes, this is a draft. The snippet adds the API so we can discuss it. The API is permissive about argument types and flexible, since the intention is easy to infer. I plan to approximate the prior using an MvNormal in the transformed space. The prior is set inside the function, so you only get back the dictionary with the final result.
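To make the "MvNormal in the transformed space" idea concrete, here is a minimal NumPy-only sketch. It is not the code in this PR: the synthetic `posterior` array, the choice of a log transform, and all variable names are assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for posterior draws (assumption): 4000 samples of two
# strictly positive parameters.
posterior = rng.lognormal(mean=[0.0, 1.0], sigma=[0.3, 0.5], size=(4000, 2))

# Forward transform to the unconstrained space (log, for positive values).
z = np.log(posterior)

# Moments of the multivariate normal approximation.
mu = z.mean(axis=0)
cov = np.cov(z, rowvar=False)

# Draw from the approximation and map back to the constrained space.
prior_draws = np.exp(rng.multivariate_normal(mu, cov, size=4000))
```

Because the Gaussian is fitted after the transform, every back-transformed draw respects the original constraint (here, positivity), and correlations between parameters are carried by `cov`.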
ceb4d06 to 8d047fc
pymc_experimental/utils/prior.py
Outdated
for key, cfg in kwargs.items():
    data = posterior[key].values
    # omitting chain, draw
    shape = data.shape[2:]
There is no guarantee that the chain and draw dimensions will always come first; perfectly valid xarray operations can change the dimension order. In xarray only the dimension name is relevant.
A quick change to the code to take this into account would be:
sample_dims = ["chain", "draw"]
for ...
    batch_dims = [dim for dim in posterior[key].dims if dim not in sample_dims]
    data = posterior[key].stack(__sample__=sample_dims, __batch__=batch_dims)
    end = begin + len(data["__batch__"])
I suspect it might even be possible to simplify this further using https://docs.xarray.dev/en/latest/generated/xarray.Dataset.to_stacked_array.html#xarray.Dataset.to_stacked_array plus a where to check the start and end positions of each variable. I can take a look towards the end of July if that would still be helpful by then.
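A small self-contained sketch of the name-based approach, assuming xarray and NumPy are installed. It is a variant of the suggestion above rather than the exact snippet: it stacks only the sample dimensions and moves them to the front with `transpose`, and the helper name `flatten_samples` and the toy arrays are hypothetical.

```python
import numpy as np
import xarray as xr

SAMPLE_DIMS = ["chain", "draw"]


def flatten_samples(da: xr.DataArray):
    """Return (samples x flattened batch) values and the batch shape."""
    # Combine chain and draw by name, wherever they sit in da.dims,
    # then move the combined dimension to the first position.
    stacked = da.stack(__sample__=SAMPLE_DIMS).transpose("__sample__", ...)
    values = stacked.values
    batch_shape = values.shape[1:]
    return values.reshape(values.shape[0], -1), batch_shape


rng = np.random.default_rng(0)
a = xr.DataArray(rng.normal(size=(2, 5, 3)), dims=("chain", "draw", "test"))

# Same data, but with the dimension order changed by a valid xarray operation.
a_permuted = a.transpose("test", "draw", "chain")

flat, batch_shape = flatten_samples(a)
flat_permuted, _ = flatten_samples(a_permuted)
```

With the permuted array, `a_permuted.values.shape[2:]` would report the chain dimension as the batch shape, whereas the name-based version gives the same flattened result for both orderings.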
Ready for review.
About the API: the most common case will be transferring a single parameter from posterior to prior, so I wonder what that looks like, to make sure we are not handling the general case well at the expense of the common case.
with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, var_names=["a"])["a"]
and if you rename
with pm.Model(coords=dict(test=range(3))) as model:
    b = pmx.utils.prior.prior_from_idata(trace, a="b")["b"]
with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=("test",))["a"]
or
with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=dict(dims=("test",)))["a"]
with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=dict(dims=("test",), transform=transforms.simplex))["a"]
If we do not need coords
with pm.Model(coords=dict(test=range(3))) as model:
    a = pmx.utils.prior.prior_from_idata(trace, a=transforms.simplex)["a"]
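One way such permissive arguments could be normalized internally is sketched below in plain Python. This is an illustration of the idea only: `normalize_param_cfg` is a hypothetical helper, and the rule "anything that is not a string, sequence, or dict is a transform" is an assumption, not necessarily what the PR implements.

```python
from typing import Any


def normalize_param_cfg(key: str, value: Any = None) -> dict:
    """Turn one permissive argument into dict(name=..., dims=..., transform=...)."""
    cfg = dict(name=key, dims=None, transform=None)
    if value is None:
        # var_names=["a"] style: keep the name, no dims, no transform
        return cfg
    if isinstance(value, str):
        # a="b": rename the variable
        cfg["name"] = value
    elif isinstance(value, (tuple, list)):
        # a=("test",): assign dims
        cfg["dims"] = tuple(value)
    elif isinstance(value, dict):
        # a=dict(name=..., dims=..., transform=...): explicit configuration
        unknown = set(value) - set(cfg)
        if unknown:
            raise KeyError(f"unknown config keys for {key!r}: {sorted(unknown)}")
        cfg.update(value)
        if isinstance(cfg["dims"], str):
            cfg["dims"] = (cfg["dims"],)
    else:
        # Assumption: any other object is treated as a transform
        cfg["transform"] = value
    return cfg
```

Each call form in the examples above then reduces to the same three-key dictionary, which keeps the common single-parameter case short while the dict form covers the general case.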
... # set a name, assign a coord and apply simplex transform
... f=dict(name="new_f", dims="options", transform=transforms.simplex)
... )
... trace1 = pm.sample_prior_predictive(100)
It might be worth adding a note, or even the code, to use plot_pair to compare the obtained posterior with the generated prior. Even with an MvNormal and transforms, there may be cases where the posterior is not recovered correctly, and it will generally fail if the wrong transform is used. I'm not sure how aware users are of default transforms; I'd think the vast majority have no idea a transform is happening when they use half distributions, for example.
Note: regarding auto-use of default transforms. I think that arviz-devs/arviz#2056 plus a key code to map the strings in the attributes to common transforms will generally fix this issue.
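As a numeric stand-in for the visual check, the following NumPy sketch shows what a missing transform does to a positive-valued parameter. The half-normal-like `posterior` sample and the helper `normal_prior_draws` are assumptions for illustration; they are not part of the PR or of ArviZ.

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in for posterior draws of a scale parameter (strictly positive).
posterior = np.abs(rng.normal(0.0, 1.0, size=5000))


def normal_prior_draws(samples, forward, backward, rng, size):
    """Fit a normal in the transformed space and sample back from it."""
    z = forward(samples)
    return backward(rng.normal(z.mean(), z.std(), size=size))


# No transform: the fitted normal leaks mass below zero.
no_transform = normal_prior_draws(posterior, lambda x: x, lambda z: z, rng, 5000)

# Log transform: every draw stays positive.
log_transform = normal_prior_draws(posterior, np.log, np.exp, rng, 5000)

neg_frac = (no_transform < 0).mean()

# Compare a few quantiles to see how closely the shape is recovered.
qs = [0.05, 0.5, 0.95]
print("posterior:", np.quantile(posterior, qs))
print("log-space fit:", np.quantile(log_transform, qs))
```

Even with the right support, the log-space fit will not match a half-normal exactly, which is the kind of mismatch a pair plot of posterior against prior draws would make visible.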
I've added a brief explanation about this.
Co-authored-by: Oriol Abril-Pla <[email protected]>
@ferrine I like the API.
Time to merge then. Thanks for the reviews.
If you want to do knowledge transfer in a smart way, this is how you do this