Adds tests and mode for dirichlet multinomial distribution #5225
Codecov Report
@@ Coverage Diff @@
## main #5225 +/- ##
==========================================
- Coverage 81.44% 80.47% -0.98%
==========================================
Files 81 82 +1
Lines 14204 14160 -44
==========================================
- Hits 11569 11395 -174
- Misses 2635 2765 +130
/pre-commit-run
pymc/distributions/multivariate.py
Outdated
p = a / at.sum(a, axis=-1, keepdims=True)
if a.ndim > 1:
n = at.shape_padright(n)
if (a.ndim == 1) & (n.ndim > 0):
- if (a.ndim == 1) & (n.ndim > 0):
+ if (a.ndim == 1) and (n.ndim > 0):
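For context on the suggestion above, here is a minimal sketch in plain Python of how `&` and `and` behave on integer `ndim` checks. The helper names are hypothetical and not part of PyMC. Both operators agree when each comparison is parenthesized, but `&` is a bitwise operator that binds tighter than comparisons and does not short-circuit, so it is easier to misuse.

```python
def check_bitwise(a_ndim, n_ndim):
    # Bitwise AND on two bools; both sides are always evaluated.
    return (a_ndim == 1) & (n_ndim > 0)


def check_logical(a_ndim, n_ndim):
    # Logical AND; the right side is skipped when the left side is False.
    return (a_ndim == 1) and (n_ndim > 0)


def check_unparenthesized(a_ndim, n_ndim):
    # Pitfall: parsed as a_ndim == (1 & n_ndim) > 0, a chained comparison.
    return a_ndim == 1 & n_ndim > 0


print(check_bitwise(1, 2), check_logical(1, 2), check_unparenthesized(1, 2))
```

With `a_ndim=1` and `n_ndim=2`, the two parenthesized versions agree while the unparenthesized one does not, which is one reason `and` is the more conventional choice for plain Python booleans.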
I am not sure the shape handling is completely correct with the V4 version of the DM. You may want to check how the rng_fn is handling broadcasting. It might suffice to use this line:
pymc/pymc/distributions/multivariate.py
Line 582 in 9d4691c
n, a = broadcast_params([n, a], cls.ndims_params)
This may also apply to the Multinomial distribution:
pymc/pymc/distributions/multivariate.py
Line 468 in 9d4691c
n, p = broadcast_params([n, p], cls.ndims_params)
Thanks, I wasn't aware of this function! I'm getting incompatible shapes when I try to use it, however. The code you refer to uses a constant ( pymc/pymc/distributions/multivariate.py Line 571 in 9d4691c
I think the issue might be that dimension 0 of pymc/pymc/distributions/multivariate.py Line 630 in 9d4691c
That sounds like an old constraint that is no longer needed in V4. This works just fine on the random side; we might just need to check that it also works on the logp:
y = pm.DirichletMultinomial.dist(n=np.arange(1, 1+2*4).reshape(2, 4), a=np.ones(3))
y.eval()
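As a rough illustration of the shapes involved in the example above, the following NumPy-only sketch works through the bookkeeping by hand. It is not the implementation of `broadcast_params`; it assumes that `n` has no core dimensions and `a` has one core dimension (the category axis), and all variable names other than `n` and `a` are made up for the illustration.

```python
import numpy as np

n = np.arange(1, 1 + 2 * 4).reshape(2, 4)  # batch shape (2, 4), no core dims
a = np.ones(3)                             # batch shape (), core shape (3,)

# Broadcast only the batch dimensions; the last axis of `a` is the core axis.
batch_shape = np.broadcast_shapes(n.shape, a.shape[:-1])
n_b = np.broadcast_to(n, batch_shape)
a_b = np.broadcast_to(a, batch_shape + a.shape[-1:])

# Under these assumptions, a single draw would have batch_shape + (K,).
expected_draw_shape = batch_shape + a.shape[-1:]
print(n_b.shape, a_b.shape, expected_draw_shape)
```

Under these assumptions, the logp would need to accept values of shape `(2, 4, 3)` for the parameters in the example.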
I am lifting the constraint of
@morganstrom in #5234 I removed the old constraint on the dimensionality of
@morganstrom Any update on this?
I'm sorry, I have been meaning to work on this but haven't had the time or energy. Perhaps somebody else can pick it up instead?
if (a.ndim == 1) & (n.ndim > 0):
n = at.shape_padright(n)
p = at.shape_padleft(p)
p = a / at.sum(a, axis=-1)
@ricardoV94 Here I'm trying to mirror the changes you made in #5234
np.array([1, 10]),
2,
np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]), # Dim: 2 x 4
I'm getting an error here: TypeError: Cannot convert Type TensorType(int64, matrix) (of Variable Elemwise{Cast{int64}}.0) into Type TensorType(int64, (False, True, False))
It seems moment is returning a 3d tensor?
You can try calling the get_moment method directly and evaluating the returned graph to make debugging easier.
When I do what you suggest (in a debug console), I get a 2d array - confusing!
get_moment(model["x"]).eval()
array([[1, 0, 0, 0],
[2, 3, 3, 2]])
Are you using the latest Aesara version? There was a bug with shapes/casting that was fixed recently.
Okay, so in that example the distribution shape is actually (2, 1, 4), but the returned moment is (2, 4).
pm.DirichletMultinomial.dist(
a=np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
n=np.array([[1], [10]]), # Dim: 2 x 1
).eval().shape # (2, 1, 4)
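To make the mismatch concrete, here is a NumPy-only sketch of a mode computation whose result follows the `(2, 1, 4)` shape of the draw. It is not the Aesara code from this PR. The rounding step and the choice to add the leftover count to the first category are assumptions, picked because they reproduce the expected values in the test case quoted earlier.

```python
import numpy as np

a = np.array([[26, 26, 26, 22]])  # Dim: 1 x 4
n = np.array([[1], [10]])         # Dim: 2 x 1

# Normalize concentrations into probabilities along the category axis.
p = a / a.sum(axis=-1, keepdims=True)      # (1, 4)

# Pad n on the right so it broadcasts against the category axis.
mode = np.round(n[..., None] * p)          # (2, 1, 1) * (1, 4) -> (2, 1, 4)

# Rounding can break the sum-to-n constraint; push the remainder into
# the first category (an assumption made for this illustration).
diff = n - mode.sum(axis=-1)               # (2, 1)
mode[..., 0] += diff
mode = mode.astype(int)

print(mode.shape)
print(mode)
```

Padding `n` with a trailing axis before multiplying is what keeps the batch dimensions of `n` and `a` separate from the category axis, so the result has the same shape as a draw.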
I pushed some changes that I think make the mode consistent with the RV shape. Does it look sensible to you?
This was a great edge case you stumbled upon. We could very well have missed it.
89bd865 to c78b5e2
@morganstrom I rebased from main, so make sure to pull from here before doing further work.
64673c5 to 248e1d8
af05670 to 8772c0d
8772c0d to e6c7f91
The code is almost identical to that for the multinomial distribution (PR #5201)
Thank you for opening a PR!
Depending on what your PR does, here are a few things you might want to address in the description: