Ordered transform incompatible with constrained space transforms (ZeroSum, Simplex, etc.) #6975


Open
michaelosthege opened this issue Oct 29, 2023 · 9 comments

@michaelosthege
Member

michaelosthege commented Oct 29, 2023

Describe the issue:

When the ordered transform is applied to a ZeroSumNormal (ZSN), the model ends up with a -inf logp.

Reproduceable code example:

import pymc as pm

with pm.Model() as pmodel:
    pm.ZeroSumNormal("zsn", shape=2, transform=pm.distributions.transforms.ordered)
pmodel.debug()

Error message:

RuntimeWarning: divide by zero encountered in log
  variables = ufunc(*ufunc_args, **ufunc_kwargs)
point={'zsn_ordered__': array([  0., -inf])}

The variable zsn has the following parameters:
0: normal_rv{0, (0, 0), floatX, False}.1 [id A] <Vector(float64, shape=(2,))>
 ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x1E355016500>) [id B] <RandomGeneratorType>
 ├─ [2] [id C] <Vector(int64, shape=(1,))>
 ├─ 11 [id D] <Scalar(int64, shape=())>
 ├─ 0 [id E] <Scalar(int8, shape=())>
 └─ 1.0 [id F] <Scalar(float64, shape=())>
1: 1.0 [id F] <Scalar(float64, shape=())>
2: [2] [id G] <Vector(int32, shape=(1,))>
The parameters evaluate to:
0: [1.40383536 0.612628  ]
1: 1.0
2: [2]
Some of the values of variable zsn are associated with a non-finite logp:
 value = [  0. -inf] -> logp = -inf

PyMC version information:

5.9.1 at 419af06

Context for the issue:

No response

@michaelosthege
Member Author

I looked into it, but my understanding of the ZSN and the Ordered transform is limited.

The Ordered transform works by doing a set_subtensor on a diff, and I guess this clashes with how the ZSN also uses set_subtensor in its logp?
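For reference, a minimal numpy sketch of a standard ordered transform (PyMC's actual implementation works on pytensor graphs and uses set_subtensor, but the log/exp structure discussed below is the same): the forward step stores the log of the differences, and the backward step rebuilds a sorted vector via a cumulative sum of exponentials.

import numpy as np

def ordered_forward(y):
    # constrained (sorted) -> unconstrained: keep the first entry,
    # store the log of the positive differences for the rest
    x = np.empty_like(y)
    x[..., 0] = y[..., 0]
    x[..., 1:] = np.log(np.diff(y, axis=-1))
    return x

def ordered_backward(x):
    # unconstrained -> constrained (sorted): cumulative sum of
    # [x[0], exp(x[1]), exp(x[2]), ...], which is non-decreasing by construction
    return np.cumsum(np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1), axis=-1)

y = np.array([-1.0, 0.5, 2.0])
print(ordered_forward(y))                    # roughly [-1.  0.405  0.405]
print(ordered_backward(ordered_forward(y)))  # round-trips back to [-1.  0.5  2. ]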

@ricardoV94
Member

ricardoV94 commented Oct 29, 2023

You need to chain the OrderedTransform with the default transform ZSN has. Same thing as with any distribution that has a default transform.

@ricardoV94 ricardoV94 added question and removed bug labels Oct 29, 2023
@ricardoV94
Member

Related to #5674

@michaelosthege
Member Author

Okay, makes sense.
I tried this, but it doesn't work, so obviously I'm doing something wrong:

with pm.Model() as pmodel:
    rv = pm.ZeroSumNormal.dist(shape=2)
    trafos = [
        pm.distributions.transforms._default_transform(rv.owner.op, rv),
        pm.distributions.transforms.ordered,
    ]
    trafo_chain = pm.distributions.transforms.Chain(trafos)
    pmodel.register_rv(rv, "zsn", transform=trafo_chain)
    idata = pm.sample()  # sample so the ordering can be checked on the draws

# Fraction of draws with the first entry below the second; ≈ 0.5 means
# the ordering is not actually being enforced.
(idata.posterior["zsn"].sel(zsn_dim_0=0) < idata.posterior["zsn"].sel(zsn_dim_0=1)).mean({"chain", "draw"})
# array(0.50466667)

@ricardoV94

This comment was marked as outdated.

@ricardoV94
Member

ricardoV94 commented Oct 30, 2023

As implemented, the OrderedTransform is incompatible with the ZeroSum transform.

In the shape=(2,) case it doesn't do anything, because there is nothing to order in the latent space (the zero-sum transform works on an n-1 dimensional space). So if the sampler proposes a positive value b, the final variable will be [b, -b] after the backward transform.

But even with more entries, ordering those n-1 latent entries doesn't mean the output of the zero-sum transform will come out ordered (e.g., if the latent variables sum to a value > 0 after the ordering, the missing entry, which is always the last, has to be smaller than at least one previous entry to balance it out).

import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(2,), transform=transform)
    
zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5]}))  # array([ 1.06066017, -1.06066017])

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.ZeroSumTransform(zerosum_axes=(-1,)),
        pm.distributions.transforms.ordered,
    ])
    pm.ZeroSumNormal("zsn", shape=(3,), transform=transform)

zsn_latent_value = m.value_vars[-1]
zsn_value = m.unobserved_value_vars[-1]
print(zsn_value.eval({zsn_latent_value: [1.5, 2.5]}))  # [-1.70843849 10.47405547 -8.76561698]

The other way around doesn't work either, because the ordering transform distorts the values (there is an exp in the backward step and a log in the forward step), so multiple entries will no longer sum to zero after the chained backward transform.
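To make that concrete, here is a small numpy check (reusing the ordered_backward helper sketched earlier, not PyMC's transform classes): applying the exp/cumsum backward step to a vector that sums to zero produces a vector that does not.

import numpy as np

def ordered_backward(x):
    # cumulative sum of [x[0], exp(x[1]), exp(x[2]), ...]
    return np.cumsum(np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1), axis=-1)

z = np.array([-1.0, -0.5, 1.5])   # sums to zero
print(z.sum())                     # 0.0
print(ordered_backward(z))         # roughly [-1.    -0.393  4.088]
print(ordered_backward(z).sum())   # ≈ 2.69, the zero-sum constraint is lost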

@ricardoV94 ricardoV94 changed the title BUG: Ordered transform breaks on ZeroSumNormal Ordered transform incompatible with ZeroSumTransform Oct 30, 2023
@ricardoV94
Member

ricardoV94 commented Oct 30, 2023

This limitation is not specific to the ZeroSumTransform; it also applies, for example, to the SimplexTransform:

import pymc as pm

with pm.Model() as m:
    transform = pm.distributions.transforms.Chain([
        pm.distributions.transforms.simplex,
        pm.distributions.transforms.ordered,
    ])
    pm.Dirichlet("x", a=[1, 1, 1], shape=(3,), transform=transform)

x_latent_value = m.value_vars[-1]
x_value = m.unobserved_value_vars[-1]
print(x_value.eval({x_latent_value: [.5, -.75]}))  # [0.10714617 0.88997555 0.00287828]

Again the other order (simplex first, then ordering in the backward step) wouldn't provide a valid simplex anymore, because of the exp operation.
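A quick numpy check of that point (again with the ordered_backward sketch from above, not PyMC's transform code): applying the exp/cumsum backward step to a valid simplex point yields values that no longer sum to one.

import numpy as np

def ordered_backward(x):
    return np.cumsum(np.concatenate([x[..., :1], np.exp(x[..., 1:])], axis=-1), axis=-1)

p = np.array([0.2, 0.3, 0.5])     # a valid simplex point
print(ordered_backward(p))         # roughly [0.2  1.55  3.2]
print(ordered_backward(p).sum())   # ≈ 4.95, not a simplex anymore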

The OrderedTransform seems to only be useful in unconstrained spaces.

@ricardoV94 ricardoV94 changed the title Ordered transform incompatible with ZeroSumTransform Ordered transform incompatible with constrained space transforms (ZeroSum, Simplex, etc.) Oct 30, 2023
@michaelosthege
Member Author

Thank you for investigating!

Would it help to drop the log and exp from the Ordered transform? I experimented with this, and IIUC they were just added for numeric stability, which might already be handled by graph rewrites?

@ricardoV94
Member

ricardoV94 commented Oct 30, 2023

> Thank you for investigating!
>
> Would it help to drop the log and exp from the Ordered transform? I experimented with this, and IIUC they were just added for numeric stability, which might already be handled by graph rewrites?

No, it's not about numerical stability; it's a way to achieve the ordering in a continuous way. You could probably come up with a continuous constrained ordering, but you would need to find a different expression and make sure you have the correct Jacobian for it.
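To make the Jacobian point concrete with the existing log/exp parametrization (a numpy illustration, not a proposal for a new transform): the backward map's Jacobian is lower triangular with diagonal [1, exp(x[1]), ..., exp(x[n-1])], so its log-determinant is simply sum(x[1:]); any replacement expression would need its own correct term like this.

import numpy as np

def ordered_backward(x):
    return np.cumsum(np.concatenate([x[:1], np.exp(x[1:])]), axis=-1)

def ordered_log_jac_det(x):
    # log|det J| of the backward map: the Jacobian is lower triangular
    # with diagonal [1, exp(x[1]), ..., exp(x[n-1])]
    return x[1:].sum()

x = np.array([0.3, -1.2, 0.7])
eps = 1e-6
# finite-difference Jacobian, one column per input coordinate
J = np.stack(
    [(ordered_backward(x + eps * e) - ordered_backward(x - eps * e)) / (2 * eps) for e in np.eye(len(x))],
    axis=-1,
)
print(ordered_log_jac_det(x), np.linalg.slogdet(J)[1])  # both ≈ -0.5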
