Skip to content

Adds tests and mode for dirichlet multinomial distribution #5225

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

morganstrom
Copy link
Contributor

The code is almost identical to that for the multinomial distribution (PR #5201)

Thank your for opening a PR!

Depending on what your PR does, here are a few things you might want to address in the description:

@codecov
Copy link

codecov bot commented Nov 25, 2021

Codecov Report

Merging #5225 (8772c0d) into main (24f9bd4) will decrease coverage by 0.97%.
The diff coverage is 100.00%.

❗ Current head 8772c0d differs from pull request most recent head e6c7f91. Consider uploading reports for the commit e6c7f91 to get more accurate results

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5225      +/-   ##
==========================================
- Coverage   81.44%   80.47%   -0.98%     
==========================================
  Files          81       82       +1     
  Lines       14204    14160      -44     
==========================================
- Hits        11569    11395     -174     
- Misses       2635     2765     +130     
Impacted Files Coverage Δ
pymc/distributions/multivariate.py 76.34% <100.00%> (-15.19%) ⬇️
pymc/distributions/transforms.py 92.55% <0.00%> (-7.45%) ⬇️
pymc/distributions/discrete.py 98.79% <0.00%> (-0.97%) ⬇️
pymc/sampling_jax.py 97.45% <0.00%> (-0.85%) ⬇️
pymc/distributions/continuous.py 96.93% <0.00%> (-0.30%) ⬇️
pymc/sampling.py 86.54% <0.00%> (-0.23%) ⬇️
pymc/math.py 69.90% <0.00%> (-0.15%) ⬇️
pymc/tuning/starting.py 92.56% <0.00%> (-0.13%) ⬇️
pymc/model.py 85.93% <0.00%> (-0.04%) ⬇️
... and 10 more

@michaelosthege michaelosthege added this to the v4.0.0-beta1 (vNext) milestone Nov 27, 2021
@ricardoV94
Copy link
Member

/pre-commit-run

@ricardoV94 ricardoV94 mentioned this pull request Nov 29, 2021
51 tasks
p = a / at.sum(a, axis=-1, keepdims=True)
if a.ndim > 1:
n = at.shape_padright(n)
if (a.ndim == 1) & (n.ndim > 0):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (a.ndim == 1) & (n.ndim > 0):
if (a.ndim == 1) and (n.ndim > 0):

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure the shape handling is completely correct with the V4 version of the DM. You may want to check how the rng_fn is handling broadcasting. It might suffice to do this line:

n, a = broadcast_params([n, a], cls.ndims_params)

This may also apply to the Multinomial distribution:

n, p = broadcast_params([n, p], cls.ndims_params)

@morganstrom
Copy link
Contributor Author

I am not sure the shape handling is completely correct with the V4 version of the DM. You may want to check how the rng_fn is handling broadcasting. It might suffice to do this line:

n, a = broadcast_params([n, a], cls.ndims_params)

This may also apply to the Multinomial distribution:

n, p = broadcast_params([n, p], cls.ndims_params)

Thanks, I wasn't aware of this function! I'm getting incompatible shapes when I try to use it however.. The code you refer to uses a constant (cls.ndims_params) that is set to [0, 1], which doesn't seem to work here..

ndims_params = [0, 1]

I think the issue might be that dimension 0 of n (in case of an array) is required to match dimension 1 of a - this is why we're using at.shape_padright for n and at.shape_padleft for a :

Total counts in each replicate. If n is an array its shape must be (N,)

@ricardoV94
Copy link
Member

I think the issue might be that dimension 0 of n (in case of an array) is required to match dimension 1 of a - this is why we're using at.shape_padright for n and at.shape_padleft for a :

Total counts in each replicate. If n is an array its shape must be (N,)

That sounds like an old constraint that is no longer needed in V4. This works just fine on the random side, we might just need to check it also works on the logp:

y = pm.DirichletMultinomial.dist(n=np.arange(1, 1+2*4).reshape(2, 4), a=np.ones(3))
y.eval()
array([[[1, 0, 0],
        [0, 2, 0],
        [0, 0, 3],
        [3, 0, 1]],
       [[1, 2, 2],
        [0, 5, 1],
        [0, 3, 4],
        [2, 0, 6]]])

@ricardoV94
Copy link
Member

I think the issue might be that dimension 0 of n (in case of an array) is required to match dimension 1 of a - this is why we're using at.shape_padright for n and at.shape_padleft for a :

Total counts in each replicate. If n is an array its shape must be (N,)

That sounds like an old constraint that is no longer needed in V4. This works just fine on the random side, we might just need to check it also works on the logp:

y = pm.DirichletMultinomial.dist(n=np.arange(1, 1+2*4).reshape(2, 4), a=np.ones(3))
y.eval()
array([[[1, 0, 0],
        [0, 2, 0],
        [0, 0, 3],
        [3, 0, 1]],
       [[1, 2, 2],
        [0, 5, 1],
        [0, 3, 4],
        [2, 0, 6]]])

@ricardoV94 ricardoV94 closed this Dec 2, 2021
@ricardoV94 ricardoV94 reopened this Dec 2, 2021
@ricardoV94
Copy link
Member

ricardoV94 commented Dec 2, 2021

I am lifting the constraint of n in #5234

@ricardoV94
Copy link
Member

@morganstrom in #5234 I removed the old constraint on the dimensionality of n and p. The get_moment should share the same flexibility, and we have to do the same for the Multinoimal. Do you want to incorporate those updates in this PR?

@twiecki
Copy link
Member

twiecki commented Jan 7, 2022

@morganstrom Any update on this?

@ricardoV94 ricardoV94 modified the milestones: v4.0.0b2, v4.0.0b3 Jan 7, 2022
@morganstrom
Copy link
Contributor Author

morganstrom commented Jan 8, 2022 via email

if (a.ndim == 1) & (n.ndim > 0):
n = at.shape_padright(n)
p = at.shape_padleft(p)
p = a / at.sum(a, axis=-1)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 Here I'm trying to mirror the changes you made in #5234

np.array([1, 10]),
2,
np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]), # Dim: 2 x 4
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm getting an error here: TypeError: Cannot convert Type TensorType(int64, matrix) (of Variable Elemwise{Cast{int64}}.0) into Type TensorType(int64, (False, True, False))

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems moment is returning a 3d tensor?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can try to call get_moment method directly and evaluate the returned graph, in order to facilitate debugging.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When I do what you suggest (in a debug console), I get a 2d array - confusing!

get_moment(model["x"]).eval()
array([[1, 0, 0, 0],
       [2, 3, 3, 2]])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you using the latest Aesara version? There was a bug with shapes/casting, that was fixed recently.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, so in that example the distribution shape is actually (2, 1, 4), but the returned moment is (2, 4).

pm.DirichletMultinomial.dist(
    a=np.array([[26, 26, 26, 22]]),  # Dim: 1 x 4
    n=np.array([[1], [10]]),  # Dim: 2 x 1
).eval().shape  # (2, 1, 4)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I pushed some changes that I think make the mode consistent with the RV shape. Does it look sensible to you?

This was a great edge case you stumbled upon. We could very well have missed it

@ricardoV94 ricardoV94 force-pushed the dirichlet_multinomial_moments branch from 89bd865 to c78b5e2 Compare January 26, 2022 08:12
@ricardoV94
Copy link
Member

@morganstrom I rebased from main, so make sure to pull from here before doing further work

@ricardoV94 ricardoV94 force-pushed the dirichlet_multinomial_moments branch from 64673c5 to 248e1d8 Compare January 26, 2022 08:35
@ricardoV94 ricardoV94 force-pushed the dirichlet_multinomial_moments branch from af05670 to 8772c0d Compare January 26, 2022 10:13
@ricardoV94 ricardoV94 modified the milestones: v4.0.0b3, v4.0.0 Feb 7, 2022
@ricardoV94 ricardoV94 force-pushed the dirichlet_multinomial_moments branch from 8772c0d to e6c7f91 Compare February 22, 2022 08:49
@ricardoV94 ricardoV94 merged commit da2030e into pymc-devs:main Feb 22, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants