Ordered transform incompatible with constrained space transforms (ZeroSum, Simplex, etc.) #6975
Comments
I looked into it, but my understanding of the ZSN and the Ordered transform is limited. The Ordered transform works by doing a set-subtensor on a diff, and I guess this is what clashes with how the ZSN is also using …
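The "set-subtensor on a diff" idea can be sketched in plain NumPy. This is a minimal sketch assuming the transform mirrors PyMC's `OrderedTransform`: forward keeps the first entry and takes the log of the successive differences, backward exponentiates those gaps and takes a cumulative sum. The function names `ordered_forward` and `ordered_backward` are hypothetical helpers, not PyMC API.

```python
import numpy as np


def ordered_forward(x):
    """Strictly increasing vector -> unconstrained vector (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    y[..., 0] = x[..., 0]                       # first entry is kept as-is
    y[..., 1:] = np.log(np.diff(x, axis=-1))    # log of the successive gaps
    return y


def ordered_backward(y):
    """Unconstrained vector -> strictly increasing vector (hypothetical helper)."""
    y = np.asarray(y, dtype=float)
    steps = np.empty_like(y)
    steps[..., 0] = y[..., 0]
    steps[..., 1:] = np.exp(y[..., 1:])         # exp forces every gap to be positive
    return np.cumsum(steps, axis=-1)


y = np.array([1.5, 2.5, -1.0])
x = ordered_backward(y)
print(x)                   # strictly increasing
print(ordered_forward(x))  # recovers y
```

The ordering is only guaranteed for the direct output of `ordered_backward`; any further map applied afterwards (such as a zero-sum or simplex backward step) can undo it.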
You need to chain the OrderedTransform with the ZSN's default transform, as with any distribution that has a default transform.
Related to #5674
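Okay, makes sense. Chaining the default ZSN transform with the ordered transform samples, but the draws are not ordered: the fraction of draws with zsn[0] < zsn[1] is only about 0.5.

```python
import pymc as pm

with pm.Model() as pmodel:
    rv = pm.ZeroSumNormal.dist(shape=2)
    # Chain the ZSN's default (zero-sum) transform with the ordered transform
    trafos = [
        pm.distributions.transforms._default_transform(rv.owner.op, rv),
        pm.distributions.transforms.ordered,
    ]
    trafo_chain = pm.distributions.transforms.Chain(trafos)
    pmodel.register_rv(rv, "zsn", transform=trafo_chain)
    idata = pm.sample()

# Fraction of draws in which the first entry is smaller than the second
(idata.posterior["zsn"].sel(zsn_dim_0=0) < idata.posterior["zsn"].sel(zsn_dim_0=1)).mean(("chain", "draw"))
# array(0.50466667)
```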
As implemented, the OrderedTransform is incompatible with the ZeroSum transform. In the … But even with more entries, ordering these n-1 latent entries doesn't guarantee that the zero-sum output comes out ordered. For example, if the latent values have a positive sum after the ordering, the missing entry (which is always the last one) has to be smaller than at least one of the previous entries to balance them out.

```python
import pymc as pm

# Two entries: the latent vector has a single element
with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(2,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5]}))  # array([ 1.06066017, -1.06066017])

# Three entries: the ordered latent values do not give an ordered output
with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(3,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5, 2.5]}))  # [-1.70843849 10.47405547 -8.76561698]
```

The other way around doesn't work either: the ordered transform distorts the values (there is an `exp` in the backward step and a `log` in the forward step), so the entries no longer sum to zero after the chained backward transform.
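Both failure modes can be reproduced without PyMC. This is a minimal NumPy sketch assuming the zero-sum backward step matches PyMC's `extend_axis` helper for a single axis (append a fill value, then subtract a normalizing shift) and the ordered backward step is a cumulative sum of the first entry and the exponentiated remaining entries. The helper names are hypothetical. With `Chain([ZeroSumTransform, ordered])` the ordered backward step runs first.

```python
import numpy as np


def ordered_backward(y):
    # Hypothetical helper: first entry, then cumulative sum of exp(gaps)
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))


def zerosum_backward(y):
    # Hypothetical helper, assumed to match PyMC's extend_axis on one axis:
    # maps n-1 unconstrained values to n values that sum to zero
    n = y.size + 1
    s = y.sum()
    norm = s / (np.sqrt(n) + n)
    fill = norm - s / np.sqrt(n)
    return np.concatenate([y, [fill]]) - norm


# Ordered first, then zero-sum: sums to zero, but not increasing
print(zerosum_backward(ordered_backward(np.array([1.5]))))       # ~[ 1.0607, -1.0607]
print(zerosum_backward(ordered_backward(np.array([1.5, 2.5]))))  # ~[-1.7084, 10.4741, -8.7656]

# Zero-sum first, then ordered: increasing, but the sum is far from zero
z = ordered_backward(zerosum_backward(np.array([1.5, 2.5])))
print(z, z.sum())
```

Neither composition satisfies both constraints at once, which matches the outputs reported for the PyMC models.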
This limitation is not specific to the ZeroSumTransform; it also applies, for example, to the SimplexTransform:

```python
import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.simplex,
        pm.distributions.transforms.ordered,
    ])
    pm.Dirichlet("x", a=[1, 1, 1], shape=(3,), transform=transform)

x_latent_value = m.value_vars[-1]
x_value = m.unobserved_value_vars[-1]
print(x_value.eval({x_latent_value: [.5, .75]}))  # [0.10714617 0.88997555 0.00287828]
```

Again, the other order (simplex first, then ordering in the backward step) would no longer produce a valid simplex because of the `exp` operation. The OrderedTransform seems to be useful only in unconstrained spaces.
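The simplex case can be sketched the same way. This minimal NumPy sketch assumes PyMC's `SimplexTransform` backward step appends `-sum(y)` to the latent vector and applies a softmax, and that the ordered backward step is a cumulative sum of the first entry and the exponentiated remaining entries; the helper names are hypothetical. For the latent vector `[0.5, 0.75]`, ordering the logits still yields an unordered probability vector, while the opposite order yields an increasing vector that no longer sums to one.

```python
import numpy as np


def ordered_backward(y):
    # Hypothetical helper: first entry, then cumulative sum of exp(gaps)
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))


def simplex_backward(y):
    # Hypothetical helper, assumed to match PyMC's SimplexTransform:
    # append -sum(y) as the last logit, then apply a numerically stable softmax
    logits = np.append(y, -y.sum())
    e = np.exp(logits - logits.max())
    return e / e.sum()


# Chain([simplex, ordered]): only the logits are ordered
p = simplex_backward(ordered_backward(np.array([0.5, 0.75])))
print(p, p.sum())  # ~[0.1071, 0.8900, 0.0029]: a valid simplex, but not ordered

# Opposite order: increasing, but no longer on the simplex
q = ordered_backward(simplex_backward(np.array([0.5, 0.75])))
print(q, q.sum())
```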
Thank you for investigating! Would it help to drop the …
No, it is not about numerical stability. It's a way to achieve ordering in a continuous (differentiable) way. You could probably come up with a continuous constrained ordering, but you would need to find a different expression and make sure you have the correct Jacobian for it.
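To make the Jacobian requirement concrete: a transform's log-determinant term must match its backward map exactly. This minimal NumPy sketch assumes the ordered backward map is a cumulative sum of the first entry and the exponentiated remaining entries; its Jacobian is then lower triangular with diagonal `(1, exp(y_1), ..., exp(y_{n-1}))`, so the log-determinant is `sum(y[1:])`. The sketch checks that analytic term against a finite-difference Jacobian. The helper names are hypothetical, and any alternative constrained ordering would need the same kind of check.

```python
import numpy as np


def ordered_backward(y):
    # Hypothetical helper: first entry, then cumulative sum of exp(gaps)
    return np.cumsum(np.concatenate([y[:1], np.exp(y[1:])]))


def ordered_log_jac_det(y):
    # dx_i/dy_j = 1 for j == 0, exp(y_j) for 1 <= j <= i, 0 for j > i,
    # so the determinant is the product of exp(y_1), ..., exp(y_{n-1})
    return y[1:].sum()


def numerical_log_jac_det(f, y, eps=1e-6):
    # Central finite differences: one Jacobian column per input coordinate
    n = y.size
    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (f(y + step) - f(y - step)) / (2 * eps)
    return np.linalg.slogdet(jac)[1]


y = np.array([0.3, -1.2, 0.8, 0.1])
print(ordered_log_jac_det(y), numerical_log_jac_det(ordered_backward, y))
```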
Describe the issue:
When applying the ordered transform to the ZSN, the model's logp is -inf.
Reproducible code example:
Error message:
PyMC version information:
5.9.1 at 419af06
Context for the issue:
No response