Misc RandomVariable improvements #79


Merged
merged 6 commits into from
Dec 4, 2022

Conversation

ricardoV94
Member

This PR fixes a couple of edge cases related to RandomVariable static shape inference, and simplifies and extends the local_dimshuffle_rv_lift rewrite.

Allowing the rewrite to apply in more cases increases the range of graphs for which we can infer the logprob in PyMC: https://github.com/pymc-devs/pymc/blob/a0d6ba079eac2f044ed40cc5747f1079d99f9f16/pymc/logprob/tensor.py#L288-L290

Closes #60 and makes progress on #49.
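For intuition, the equivalence the lift rewrite exploits can be illustrated with NumPy (a sketch only; the actual rewrite operates on symbolic PyTensor graphs): dimshuffling the output of an elemwise RV is equivalent, in shape and distribution, to dimshuffling its parameters before drawing.

```python
import numpy as np

rng = np.random.default_rng(0)
loc = np.arange(6.0).reshape(2, 3)
scale = np.full((2, 3), 0.5)

# DimShuffle (here a transpose) applied to the RV's output...
out = rng.normal(loc, scale).T        # shape (3, 2)

# ...versus the "lifted" form: the same dimshuffle applied to the parameters
lifted = rng.normal(loc.T, scale.T)   # shape (3, 2)

assert out.shape == lifted.shape == (3, 2)
```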

@ricardoV94 ricardoV94 added bug Something isn't working enhancement New feature or request graph rewriting random variables labels Dec 2, 2022
This allows PyTensor to infer more broadcastable patterns, by placing the casting inside the MakeVector Op
@ricardoV94 ricardoV94 force-pushed the fix_rv_ds_lift branch 2 times, most recently from a18e46a to 128531f Compare December 2, 2022 15:39
@michaelosthege
Member

@ricardoV94 I unsubscribed - ping me when CI is green

@ricardoV94
Member Author

ricardoV94 commented Dec 2, 2022

@ricardoV94 I unsubscribed - ping me when CI is green

99.2% confident it will be green this run @michaelosthege

@codecov-commenter

codecov-commenter commented Dec 2, 2022

Codecov Report

Merging #79 (e59ddb7) into main (491f93e) will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##             main      #79      +/-   ##
==========================================
+ Coverage   74.22%   74.26%   +0.03%     
==========================================
  Files         174      175       +1     
  Lines       48734    48929     +195     
  Branches    10367    10395      +28     
==========================================
+ Hits        36175    36335     +160     
- Misses      10272    10291      +19     
- Partials     2287     2303      +16     
Impacted Files Coverage Δ
pytensor/tensor/elemwise.py 88.08% <100.00%> (+0.01%) ⬆️
pytensor/tensor/random/basic.py 99.03% <100.00%> (+<0.01%) ⬆️
pytensor/tensor/random/op.py 97.46% <100.00%> (+0.06%) ⬆️
pytensor/tensor/random/rewriting.py 93.54% <100.00%> (-0.62%) ⬇️
pytensor/tensor/random/utils.py 100.00% <100.00%> (ø)
pytensor/link/numba/dispatch/extra_ops.py 92.24% <0.00%> (-5.77%) ⬇️
pytensor/link/numba/dispatch/basic.py 90.06% <0.00%> (-2.62%) ⬇️
pytensor/sparse/sandbox/sp.py 73.48% <0.00%> (ø)
pytensor/link/numba/dispatch/nlinalg.py 100.00% <0.00%> (ø)
pytensor/link/numba/dispatch/cython_support.py 86.95% <0.00%> (ø)
... and 2 more

@@ -319,8 +319,23 @@ def make_node(self, rng, size, dtype, *dist_params):
"The type of rng should be an instance of either RandomGeneratorType or RandomStateType"
)

# Fail early when size is incompatible with parameters
size_len = get_vector_length(size)
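The early-failure check added here can be sketched in NumPy (a hypothetical, simplified version; `check_size_compatible`, and the use of concrete arrays rather than symbolic variables, are illustration-only assumptions): `size` must have at least as many dimensions as the batch shape obtained by broadcasting the parameters' non-core dimensions together.

```python
import numpy as np

def check_size_compatible(size, dist_params, ndims_params):
    """Sketch: fail early when `size` is incompatible with the parameters.

    ndims_params gives each parameter's number of core (event) dims;
    everything to the left of the core dims is a batch dim.
    """
    batch_shapes = [
        np.shape(p)[: np.ndim(p) - core]  # strip the core dims
        for p, core in zip(dist_params, ndims_params)
    ]
    batch_shape = np.broadcast_shapes(*batch_shapes)
    if len(size) < len(batch_shape):
        raise ValueError(
            f"size {tuple(size)} has fewer dimensions than the "
            f"parameters' implied batch shape {batch_shape}"
        )
```

For example, a normal RV with `loc` of shape (2, 3) and `scale` of shape (3,) implies a batch shape of (2, 3), so a one-dimensional `size` is rejected.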
Member Author
Perhaps I should move this to infer_shape. If for some reason an RV needs different batching semantics it can override it?

Member
can you come up with an example?
one of the zero-sum RVs maybe?

if not I'd say apply the YAGNI rule

Member Author

There was already a case with the ChoiceRV and PermutationRVs, which need special handling for the shape.

@michaelosthege left a comment (Member)

Can't say I understand everything, but I didn't notice anything implausible either.


@ricardoV94 ricardoV94 marked this pull request as draft December 3, 2022 13:39
* The rewrite no longer bails out when the dimshuffle affects both unique parameter dimensions and repeated parameter dimensions coming from the size argument. This requires:
  1) Adding broadcastable dimensions to the parameters, which should be "cost-free" and would need to be done in the `perform` method anyway.
  2) Extending size to incorporate implicit batch dimensions coming from the parameters. This requires computing the shape that results from broadcasting the parameters. It's unclear whether this is less performant, because the `perform` method can now simply broadcast each parameter to the size, instead of having to broadcast the parameters together.
* The rewrite now works with multivariate RVs.
* The rewrite bails out when dimensions are dropped by the DimShuffle. This case was not handled correctly by the previous rewrite.
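The expand-then-dimshuffle step in point 1 can be sketched in NumPy (illustrative only; the rewrite manipulates symbolic DimShuffle orders, and the `("x", 1, 0)` order below is just an example):

```python
import numpy as np

rng = np.random.default_rng(42)
loc = np.zeros((2, 3))
scale = np.ones((2, 3))

# Original graph: draw, then dimshuffle with new_order ("x", 1, 0),
# i.e. transpose the two dims and prepend a broadcastable one.
out = rng.normal(loc, scale).transpose(1, 0)[None]      # shape (1, 3, 2)

# Lifted graph: apply the same dimshuffle to each parameter instead.
# Adding the broadcastable dim to the parameters is "cost-free".
loc_ds = np.expand_dims(loc, 0).transpose(0, 2, 1)      # shape (1, 3, 2)
scale_ds = np.expand_dims(scale, 0).transpose(0, 2, 1)
lifted = rng.normal(loc_ds, scale_ds)

assert out.shape == lifted.shape == (1, 3, 2)
```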
@ricardoV94 ricardoV94 marked this pull request as ready for review December 4, 2022 11:01
@ricardoV94 ricardoV94 merged commit 2ebfbf1 into pymc-devs:main Dec 4, 2022
@ricardoV94 ricardoV94 deleted the fix_rv_ds_lift branch January 20, 2023 13:55
Development

Successfully merging this pull request may close these issues.

local_dimshuffle_rv_lift fails in some cases
3 participants