Skip to content

Do not include rvs in symbolic normalizing constant #7787

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented May 17, 2025

This is a more proper fix for the problem highlighted in #7778

The normalizing constant for MinibatchRVs included the graph of the shape of the RVs.

Even though the shape of the MinibatchRV can be derived without evaluating the draws, passing any graph with RVs to pytensorf.compile will automatically integrate the updates which requires evaluating the RV anyway. This PR makes sure we don't include the RVs only to get the symbolic normalizing constant.


📚 Documentation preview 📚: https://pymc--7787.org.readthedocs.build/en/7787/

@ricardoV94 ricardoV94 added maintenance VI Variational Inference labels May 17, 2025
Copy link

@Copilot Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull Request Overview

This PR ensures that random variables (RVs) are not included in the symbolic normalizing constant graph by folding shapes to constants, adds shape inference for minibatch RVs, includes a test for the new behavior, and fixes a small typo.

  • Use constant_fold to derive batch shapes instead of carrying RVs into symbolic_normalizing_constant
  • Implement infer_shape on MinibatchRandomVariable so shape propagation works correctly
  • Add a dedicated test (assert_no_rvs) to confirm no RVs appear in the symbolic normalizing constant
  • Correct a typo in the constant_fold comment

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 1 comment.

File Description
tests/variational/test_opvi.py Added test_symbolic_normalizing_constant_no_rvs with assert_no_rvs
pymc/variational/opvi.py Swapped direct .shape usage for constant_fold([...].shape) in scaling
pymc/variational/minibatch_rv.py Added infer_shape method to propagate shapes without evaluation
pymc/pytensorf.py Fixed typo in comment (constand_foldingconstant_folding)
Comments suppressed due to low confidence (3)

tests/variational/test_opvi.py:284

  • [nitpick] The test verifies no RVs are in the graph but doesn't assert that the symbolic normalizing constant still produces the expected scalar or tensor shape. Consider adding an assertion on the returned value or shape to guard against regressions.
def test_symbolic_normalizing_constant_no_rvs():

pymc/variational/opvi.py:1109

  • Calling constant_fold inside the list comprehension for each RV will repeatedly clone and rewrite the graph, which may be costly. Consider computing all shapes once (e.g., collect inputs, call constant_fold outside the loop) or caching results before the comprehension.
get_scaling(

pymc/variational/opvi.py:1279

  • This mirrored use of constant_fold in another list comprehension also risks redundant graph rewriting. Extract a helper or hoist the folding step to improve efficiency and reduce duplicated logic.
get_scaling(

Copy link

codecov bot commented May 17, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 92.84%. Comparing base (3a718f2) to head (0867cde).
Report is 2 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@           Coverage Diff           @@
##             main    #7787   +/-   ##
=======================================
  Coverage   92.84%   92.84%           
=======================================
  Files         107      107           
  Lines       18378    18380    +2     
=======================================
+ Hits        17063    17065    +2     
  Misses       1315     1315           
Files with missing lines Coverage Δ
pymc/pytensorf.py 89.76% <ø> (ø)
pymc/variational/minibatch_rv.py 100.00% <100.00%> (ø)
pymc/variational/opvi.py 86.75% <ø> (ø)
🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@ricardoV94 ricardoV94 force-pushed the no_rvs_symbolic_normalizing_constant branch from 2551adf to 0867cde Compare May 18, 2025 13:30
@ricardoV94 ricardoV94 merged commit 618634b into pymc-devs:main May 18, 2025
3 of 4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
maintenance VI Variational Inference
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants