Skip to content

Implemented logprob for SpecifyShape and CheckandRaise #6538

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 201 commits into from
Mar 29, 2023

Conversation

Dhruvanshu-Joshi
Copy link
Member

@Dhruvanshu-Joshi Dhruvanshu-Joshi commented Feb 20, 2023

What is this PR about?
This PR closes 6352.
I have implemented logprob for SpecifyShape and CheckandRaise and also included a rewrite that converts a SpecifyShape or CheckandRaise of a MeasurableVariable into a MeasurableSpecifyShape/Assert Op.
I have also modified the __init__ file to be consistent with the modifications.

Checklist

Major / Breaking Changes

  • no

New features

  • Created a file to implement logprob for SpecifyShape by transfering specify_shape from rv to value and included a rewrite that converts a SpecifyShape of a MeasurableVariable into a MeasurableSpecifyShape Op.
  • Created a file to implement logprob for CheckandRaise. If the assertion is true, the value variable will retain its value and if its false, a ValueError will be raised. Also have included a rewrite that converts a CheckandRaise of a MeasurableVariable into a MeasurableAssert Op.

Bugfixes

  • no

Documentation

  • Documentation has been updated consistent with modifications

Maintenance

  • no

@codecov
Copy link

codecov bot commented Feb 20, 2023

Codecov Report

Merging #6538 (54eb767) into main (f3ce16f) will decrease coverage by 4.96%.
The diff coverage is 67.30%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6538      +/-   ##
==========================================
- Coverage   92.02%   87.06%   -4.96%     
==========================================
  Files          93       94       +1     
  Lines       15752    15804      +52     
==========================================
- Hits        14495    13759     -736     
- Misses       1257     2045     +788     
Impacted Files Coverage Δ
pymc/logprob/checks.py 66.66% <66.66%> (ø)
pymc/logprob/__init__.py 100.00% <100.00%> (ø)

... and 12 files with indirect coverage changes

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great! Just some renaming suggestion.

More importantly, we need tests :)

from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db


class MeasurableAssert(CheckAndRaise):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use the same name as the original class everywhere, including in the other functions in this file

Suggested change
class MeasurableAssert(CheckAndRaise):
class MeasurableCheckAndRaise(CheckAndRaise):

And let's call this file raise_op like the pytensor one

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 I did try including test but I think I'll need some help. From what I understand, we'll just have to import the function logprob_specify_shape into the test file and use it on some op and inner_rv to specify its shape and generate the logprob. Lastly, we'll need a assert to verify if the log-likelihood for this and expected matches or not. Somehow, I am not able to implement this. I did try referring test_cumsum but got more confused. Can you provide some sample examples/resources to get me started in the right direction. Thanks!

@ricardoV94
Copy link
Member

ricardoV94 commented Feb 21, 2023

Usually we test a bit more high-level.

  1. Create the generative graph which uses the new Op that we support the logprob for now
  2. Obtain the logprob
  3. Test the logprob

Something like this for the SpecifyShape (untested code):

import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy import stats

from pymc.distributions import Dirichlet
from pymc.logprob.joint_logprob import factorized_joint_logprob

def test_specify_shape_logprob():
  # 1. Create graph using SpecifyShape
  # Use symbolic last dimension, so that SpecifyShape is not useless
  last_dim = pt.scalar(name="last_dim", dtype="int64") 
  x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim)),
  x_rv = pt.specify_shape(x_base, shape=(5, 3))
  x_rv.name = "x"

  # 2. Request logp
  x_vv = x_rv.clone()
  [x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()

  # 3. Test logp
  x_logp_fn = pytensor.function([last_dim, x_vv], x_logp)

  # 3.1 Test valid logp
  x_vv_test = stats.dirichlet(np.ones((3,))).rvs(size=(5,))
  np.testing.assert_array_almost_equal(
    x_logp_fn(last_dim=3, x=x_vv_test), 
    stats.dirichlet(np.ones((3,))).logpdf(x_vv_test),
  )
  
  # 3.2 Test shape error
  x_vv_test_invalid = stats.dirichlet(np.ones((1,))).rvs(size=(5,))
  with pytest.raises(ValueError, match=...):
    x_logp_fn(last_dim=1, x=x_vv_test_invalid)

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 I tried generating the tests referring the above code block. But the pt.specify_shape function is causing error when requesting the log-prob using factorized_joint_logprob. The problem is that the output in the node and the updated_rv from the map do not match. This is not a problem if specify_shape is not used. Can you please help me with this?

@ricardoV94
Copy link
Member

Can you include the code you tried in this Pull Request? Then I will be able to replicate and provide more direct advice

@Dhruvanshu-Joshi
Copy link
Member Author

Yeah sure! I was just experimenting and trying to understand the same code as above hence did not make any major changes.

import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy import stats
import pymc as pm

from pymc.distributions import Dirichlet
from pymc.logprob.joint_logprob import factorized_joint_logprob

def test_specify_shape_logprob():
  # 1. Create graph using SpecifyShape
  # Use symbolic last dimension, so that SpecifyShape is not useless
  last_dim = pt.scalar(name="last_dim", dtype="int64") 
  x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim))
  x_base.name = "x"
  x_rv = pt.specify_shape(x_base, shape=(5, 3))
  x_rv.name = "x"

  # 2. Request logp
  x_vv = x_rv.clone()
  [x_logp] = factorized_joint_logprob({x_rv : x_vv}).values()

  # 3. Test logp
  x_logp_fn = pytensor.function([x_vv], x_logp)

  # 3.1 Test valid logp
  x_vv_test = stats.dirichlet(np.ones((3,))).rvs(size=(5,))
  np.testing.assert_array_almost_equal(
    x_logp_fn(last_dim=3, x=x_vv_test), 
    stats.dirichlet(np.ones((3,))).logpdf(x_vv_test),
  )
  
  # 3.2 Test shape error
  x_vv_test_invalid = stats.dirichlet(np.ones((1,))).rvs(size=(5,))
  with pytest.raises(ValueError, match=...):
    x_logp_fn(last_dim=1, x=x_vv_test_invalid)
if __name__=="__main__":
    test_specify_shape_logprob()
    

@ricardoV94
Copy link
Member

ricardoV94 commented Mar 6, 2023

@Dhruvanshu-Joshi There were some issues in find_measurable_specify_shapes. I found them by going into the interactive debugger and seeing which lines failed. I pushed a commit that fixes it.

I also needed to tweak the test.

Let me know if the changes help to understand what was wrong, and if you want to proceed with testing (and possibly fixing remaining errors) on the measurable assert part.

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 I have made changes to the check_raise_assert.py in accordance with the commit . I am also in process of creating a test for it as done in the above commit. This is the skeleton of what I think must be done here:

import re
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from scipy import stats

from pymc.distributions import Dirichlet
from pymc.logprob.joint_logprob import factorized_joint_logprob
from tests.distributions.test_multivariate import dirichlet_logpdf


def test_check_raise_assert():
    # 1. Create graph using SpecifyShape
    # Use symbolic last dimension, so that SpecifyShape is not useless
    last_dim = pt.scalar(name="last_dim", dtype="int64")
    x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim))
    x_base.name = "x"
    assert_op = Assert("This assert failed")
    x_rv = x_base.clone()
    x_rv.name = "x"

    # 2. Request logp
    x_vv = x_rv.clone()
    [x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()

    # 3. Test logp
    x_logp_fn =  pytensor.function([x_logp], assert_op(x_logp, x_logp.size < 2)

I'll need a little help in 3.1 Test valid logp and 3.2 Test shape error part of the code .

@Dhruvanshu-Joshi
Copy link
Member Author

As a result I tried to run this in a google colab cell and I encounter an error in line [x_logp] = factorized_joint_logprob({x_rv: x_vv}).values() stating The logprob terms of the following value variables could not be derived: {x}

@ricardoV94
Copy link
Member

ricardoV94 commented Mar 9, 2023

@Dhruvanshu-Joshi The assert should be used in the random graph, not in the logp (just like the SpecifyShape):

rv = at.random.normal()
assert_op = Assert("Test assert")
# Example: Add assert that rv must be positive
assert_rv = assert_op(rv > 0, rv)
assert_rv.name = "assert_rv"

assert_vv = assert_rv.clone()
assert_logp = factorized_joint_logp({assert_rv:  assert_vv})[assert_vv]

# TODO: Check valid value is correct and doesn't raise

# Check invalid value
with pytest.raises(AssertionError, match="Test assert"):
  assert_logp.eval({assert_vv: -5.0)

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 . Added the tests and made the changes. Because the folder tests was move outside the pymc package we import dirichlet_logpdf which is defined in tests.distributions.test_multivariate using the fact that tests.logprob in which our tests are defined and tests.distributions lie in the same parent directory.

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 . The previous commit had some issues related to the mypy tests which are now rectified. Also made some changes in the find_measurable_asserts and have included the tests too. Please review.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Some things:

  1. You are adding two binary files accidentaly: .coverage and pymc/.model.py.swn that must be removed from the PR
  2. Let's merge the two functions in pymc/logprob/checks.py, and the tests in tests/logprob/test_checks.py I think that's more intuitive.
  3. You need to add the new test file here:
    tests/logprob/test_abstract.py
    tests/logprob/test_censoring.py
    tests/logprob/test_composite_logprob.py
    tests/logprob/test_cumsum.py
    tests/logprob/test_joint_logprob.py
    tests/logprob/test_mixture.py
    tests/logprob/test_rewriting.py
    tests/logprob/test_scan.py
    tests/logprob/test_tensor.py
    tests/logprob/test_transforms.py
    tests/logprob/test_utils.py

If the tests pass, that should be it!

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 . I have incorporated all your suggested changes in the latest commit. Hope this solves the issue.

Also I have noticed that ever since tests was moved outside the pymc package, in every program which inherits any instance from tests, the import from pymc.tests import xyz has been simply replaced withfrom tests import xyz. Eg in tests\logprob\test_composite_logprob.py```:

from pymc.logprob.censoring import MeasurableClip
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.testing import assert_no_rvs
from pymc.tests.logprob.utils import joint_logprob

has been replaced with

from pymc.logprob.censoring import MeasurableClip
from pymc.logprob.rewriting import construct_ir_fgraph
from pymc.testing import assert_no_rvs
from tests.logprob.utils import joint_logprob

Although this does not cause any problem when running the test_suites using pytest, runnning them naively using python test_composite_logprob.py causes import issues
As a solution I have used

currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
from distributions.test_multivariate import dirichlet_logpdf

I understand this is not a big problem and the solution is a little lengthy. However if you are interested, I would like to explore more methods.
What are your views on including a related solution everywhere where tests is called?

@ricardoV94
Copy link
Member

I think the tests are supposed to be run from the root repository folder. Then you shouldn't have issue with the imports?

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. I left a couple of small suggestions below.

Also note that you seem to have picked/reverted a couple of changes that don't belong to this PR when you merged from the main branch (see the files changed tab). Those have to be cleaned up, before we can merge this PR

Comment on lines 69 to 70
if not (isinstance(node.op, SpecifyShape)):
return None # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This isn't needed. The decorator already makes sure the rewrite is only ever called for nodes with Op of the right kind

Suggested change
if not (isinstance(node.op, SpecifyShape)):
return None # pragma: no cover

)


class MeasurableAssert(CheckAndRaise):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class MeasurableAssert(CheckAndRaise):
class MeasurableCheckAndRaiseCheckAndRaise):

Comment on lines 127 to 128
if not (isinstance(node.op, CheckAndRaise)):
return None # pragma: no cover
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also not needed

Suggested change
if not (isinstance(node.op, CheckAndRaise)):
return None # pragma: no cover

@Dhruvanshu-Joshi
Copy link
Member Author

I think the tests are supposed to be run from the root repository folder. Then you shouldn't have issue with the imports?

Even from the root repository if we run the simple command python tests/test_model.py we face import errors. Running pytest -v tests/test_model.py gives no error and works smoothly.

@Dhruvanshu-Joshi
Copy link
Member Author

Hey @ricardoV94 I have included your suggestions in the latest commit. Sorry for the delay as I got caught up in exams and am currently working on my GSOC proposal.
About file changes from other PRs, I have cross-checked and all these changes have been merged into main already and none seem to revert any changes.

@ricardoV94
Copy link
Member

ricardoV94 commented Mar 29, 2023

@Dhruvanshu-Joshi Your PR still shows changes from main that don't belong here. This can happen when you have lot's of incremental commits and try to merge main, github is not very helpful at showing what changes are actually yours or from main.

It's usually easier to work if you keep your commit history clean. When doing incremental work just squash your related commits together and force-push. Ideally your final commits will look the same as if you had started from scratch knowing the final solution.

I would also advise to rebase from main instead of merging when trying to sync your branch, but that's optional.

https://stackoverflow.com/questions/71074242/github-old-commits-shows-up-in-new-pull-request

Comment on lines 53 to 55
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be removed. For running the test locally, just make sure you are running from the root folder I think

Suggested change
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
parentdir = os.path.dirname(currentdir)
sys.path.insert(0, parentdir)

@Dhruvanshu-Joshi
Copy link
Member Author

@ricardoV94 I have included this change. Seems like this time only the commit with the actual change shows up by following the steps you provided. Thank you.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great!

@ricardoV94 ricardoV94 merged commit ae9fcac into pymc-devs:main Mar 29, 2023
@ricardoV94
Copy link
Member

Thanks @Dhruvanshu-Joshi !

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement logprob for SpecifyShape and CheckAndRaise
2 participants