HalfNormal in JAX failing due to implicit downcasting of constant 0d TensorVariable to float #373
Comments
We're calling Probably has to do with how in JAX we don't respect the scalar vs. 0d array distinction (does JAX even allow it?).
Something like this also showed up in #372.
The solution is to either revert 4235ccc or tweak the JAX
The source of the problem is actually silly, we have a
The compiled graph looks like
Add [id A] <Scalar(float64, shape=())> 3
├─ Abs [id B] <Scalar(float64, shape=())> 2
│ └─ normal_rv{0, (0, 0), floatX, False}.1 [id C] <Scalar(float64, shape=())> 1
│ ├─ RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7F582FC4A820>) [id D] <RandomGeneratorType>
│ ├─ [] [id E] <Vector(int64, shape=(0,))>
│ ├─ 11 [id F] <Scalar(int64, shape=())>
│ ├─ Second [id G] <Scalar(int8, shape=())> 0
│ │ ├─ 0 [id H] <Scalar(int8, shape=())>
│ │ └─ 0 [id I] <Scalar(int8, shape=())>
│ └─ 1 [id J] <Scalar(int8, shape=())>
└─ 0 [id H] <Scalar(int8, shape=())>
I'm not spinning up a PR immediately because we should probably first discuss whether we want to implicitly downcast constant 0d arrays to Python floats/integers in JAX, or keep types consistent. We could always tweak the specific Op dispatch functions to handle Constant TensorScalarVariable in a special way. This touches on a more general question of how to handle scalars in our graphs, which also applies to other backends. See #107 and #349.
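To make the trade-off concrete, here is a minimal NumPy-only sketch of what is lost when a constant 0d array is unwrapped into a plain Python scalar. It is an illustration, not the actual dispatch code: `describe` and `needs_array` are hypothetical helpers, and the assumption that downstream code reads array attributes such as `.dtype` is mine, not something confirmed in this thread.

```python
import numpy as np

# A constant 0d array, like the int8 zero that appears in the compiled graph.
zero_0d = np.asarray(0, dtype="int8")

# Hypothetical downcast: unwrap the 0d array into a plain Python scalar.
zero_py = zero_0d.item()


def describe(x):
    """Return (type name, dtype, shape); missing attributes become None."""
    return type(x).__name__, getattr(x, "dtype", None), getattr(x, "shape", None)


def needs_array(x):
    """Hypothetical downstream code that assumes an array-like input."""
    return x.dtype  # raises AttributeError for a plain Python int or float


print(describe(zero_0d))  # the 0d array keeps its dtype and its () shape
print(describe(zero_py))  # the Python int carries neither
```

Keeping the constant as a 0d array preserves the dtype that the graph was type-checked with, while the downcast version silently drops it, so any Op implementation that relies on that information has to special-case Python scalars.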
mode="JAX"
Describe the issue:
You can't forward sample from a half-normal distribution in JAX mode
Reproducible code example:
Error message:
PyMC version information:
Context for the issue:
No response