Misc RandomVariable improvements #79
Conversation
This allows PyTensor to infer more broadcastable patterns by placing the casting inside the MakeVector Op
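(For illustration, a rough sketch of the idea; the variables and graphs are hypothetical, not code from this commit:)

```python
import pytensor.tensor as pt

# Casting the scalars before MakeVector assembles them keeps the vector's
# static length and element dtype visible to shape rewrites:
x = pt.iscalar("x")
y = pt.iscalar("y")
cast_inside = pt.stack([pt.cast(x, "float64"), pt.cast(y, "float64")])

# ...whereas casting the assembled vector wraps MakeVector in an extra
# Elemwise Cast node, which previously hid the broadcastable pattern:
cast_outside = pt.cast(pt.stack([x, y]), "float64")
```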
Force-pushed from a18e46a to 128531f
@ricardoV94 I unsubscribed - ping me when CI is green
99.2% confident it will be green this run @michaelosthege
Codecov Report
Additional details and impacted files

@@            Coverage Diff             @@
##             main      #79      +/-   ##
==========================================
+ Coverage   74.22%   74.26%   +0.03%
==========================================
  Files         174      175       +1
  Lines       48734    48929     +195
  Branches    10367    10395      +28
==========================================
+ Hits        36175    36335     +160
- Misses      10272    10291      +19
- Partials     2287     2303      +16
pytensor/tensor/random/op.py (Outdated)

@@ -319,8 +319,23 @@ def make_node(self, rng, size, dtype, *dist_params):
                 "The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
             )

+        # Fail early when size is incompatible with parameters
+        size_len = get_vector_length(size)
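(As a hedged illustration of the kind of mismatch this check rejects, assuming the public `pytensor.tensor.random` API; the example is hypothetical:)

```python
import pytensor.tensor as pt
from pytensor.tensor.random import normal

# A (3, 2) batch of means cannot fit into a length-1 size, so with this
# change make_node can raise immediately instead of failing later in perform:
loc = pt.zeros((3, 2))
x = normal(loc, 1.0, size=(2,))  # expected to raise a ValueError early
```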
Perhaps I should move this to `infer_shape`. If for some reason an RV needs different batching semantics it can override it?
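(A minimal sketch of what such an override could look like, assuming the current `RandomVariable` Op interface; the class and its attributes are hypothetical:)

```python
from pytensor.tensor.random.op import RandomVariable

class MyRV(RandomVariable):
    # Illustrative attributes for a scalar RV with one scalar parameter
    name = "my_rv"
    ndim_supp = 0
    ndims_params = [0]
    dtype = "floatX"

    @classmethod
    def rng_fn(cls, rng, loc, size):
        return rng.normal(loc, size=size)

    def infer_shape(self, fgraph, node, input_shapes):
        # An RV with unusual batching semantics could compute its own
        # output shape here instead of the default broadcasting logic
        return super().infer_shape(fgraph, node, input_shapes)
```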
Can you come up with an example?
One of the zero-sum RVs maybe?
If not, I'd say apply the YAGNI rule.
There was already a case with the ChoiceRV and PermutationRV Ops, which need special handling for the shape.
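(For context, a hedged sketch of why those Ops are special, using the public `pytensor.tensor.random` functions:)

```python
import numpy as np
from pytensor.tensor.random import choice, permutation

# The support shape of these RVs depends on the *value* of `a`, not just
# on broadcasting the distribution parameters against `size`:
perm = permutation(np.arange(5))         # one permutation of length 5
draws = choice(np.arange(5), size=(3,))  # three draws from `a`
```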
Can't say I understand everything, but I noticed nothing implausible either.
… parameters dimensionality

* The rewrite no longer bails out when the DimShuffle affects both unique param dimensions and repeated param dimensions from the size argument. This requires:
  1) Adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway.
  2) Extending size to incorporate implicit batch dimensions coming from the parameters, which requires computing the shape resulting from broadcasting the parameters. It's unclear whether this is less performant, because the `perform` method can now simply broadcast each parameter to the size, instead of having to broadcast the parameters together.
* The rewrite now works with multivariate RVs.
* The rewrite bails out when dimensions are dropped by the DimShuffle. This case was not correctly handled by the previous rewrite.
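(A rough sketch of a graph the extended rewrite can now lift, with hypothetical shapes:)

```python
import pytensor.tensor as pt
from pytensor.tensor.random import normal

loc = pt.matrix("loc")       # say, shape (3, 2)
x = normal(loc)              # RV draw with batch shape (3, 2)
y = x.dimshuffle("x", 0, 1)  # add a leading broadcastable dimension

# After local_dimshuffle_rv_lift the graph looks roughly like
#   normal(loc.dimshuffle("x", 0, 1))
# i.e. the DimShuffle moves onto the parameters, keeping the
# RandomVariable at the top of the graph.
```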
Force-pushed from 128531f to e59ddb7
This PR fixes a couple of edge cases related to RandomVariable static shape inference, and simplifies and extends the `local_dimshuffle_rv_lift` rewrite. Allowing the rewrite to apply in more cases increases the range of graphs for which we can infer the logprob in PyMC: https://github.com/pymc-devs/pymc/blob/a0d6ba079eac2f044ed40cc5747f1079d99f9f16/pymc/logprob/tensor.py#L288-L290
Closes #60 and makes progress on #49.