-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Implemented logprob for SpecifyShape and CheckandRaise #6538
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov Report
Additional details and impacted files@@ Coverage Diff @@
## main #6538 +/- ##
==========================================
- Coverage 92.02% 87.06% -4.96%
==========================================
Files 93 94 +1
Lines 15752 15804 +52
==========================================
- Hits 14495 13759 -736
- Misses 1257 2045 +788
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Just some renaming suggestion.
More importantly, we need tests :)
pymc/logprob/check_raise_assert.py
Outdated
from pymc.logprob.rewriting import PreserveRVMappings, measurable_ir_rewrites_db | ||
|
||
|
||
class MeasurableAssert(CheckAndRaise): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use the same name as the original class everywhere, including in the other functions in this file
class MeasurableAssert(CheckAndRaise): | |
class MeasurableCheckAndRaise(CheckAndRaise): |
And let's call this file raise_op
like the pytensor one
Hey @ricardoV94 I did try including test but I think I'll need some help. From what I understand, we'll just have to import the function |
Usually we test a bit more high-level.
Something like this for the SpecifyShape (untested code): import numpy as np
import pytensor
import pytensor.tensor as pt
from scipy import stats
from pymc.distributions import Dirichlet
from pymc.logprob.joint_logprob import factorized_joint_logprob
def test_specify_shape_logprob():
# 1. Create graph using SpecifyShape
# Use symbolic last dimension, so that SpecifyShape is not useless
last_dim = pt.scalar(name="last_dim", dtype="int64")
x_base = Dirichlet.dist(pt.ones((last_dim,)), shape=(5, last_dim)),
x_rv = pt.specify_shape(x_base, shape=(5, 3))
x_rv.name = "x"
# 2. Request logp
x_vv = x_rv.clone()
[x_logp] = factorized_joint_logprob({x_rv: x_vv}).values()
# 3. Test logp
x_logp_fn = pytensor.function([last_dim, x_vv], x_logp)
# 3.1 Test valid logp
x_vv_test = stats.dirichlet(np.ones((3,))).rvs(size=(5,))
np.testing.assert_array_almost_equal(
x_logp_fn(last_dim=3, x=x_vv_test),
stats.dirichlet(np.ones((3,))).logpdf(x_vv_test),
)
# 3.2 Test shape error
x_vv_test_invalid = stats.dirichlet(np.ones((1,))).rvs(size=(5,))
with pytest.raises(ValueError, match=...):
x_logp_fn(last_dim=1, x=x_vv_test_invalid) |
Hey @ricardoV94 I tried generating the tests referring the above code block. But the |
Can you include the code you tried in this Pull Request? Then I will be able to replicate and provide more direct advice |
Yeah sure! I was just experimenting and trying to understand the same code as above hence did not make any major changes.
|
@Dhruvanshu-Joshi There were some issues in I also needed to tweak the test. Let me know if the changes help to understand what was wrong, and if you want to proceed with testing (and possibly fixing remaining errors) on the measurable assert part. |
Hey @ricardoV94 I have made changes to the
I'll need a little help in |
As a result I tried to run this in a google colab cell and I encounter an error in line |
@Dhruvanshu-Joshi The assert should be used in the random graph, not in the logp (just like the SpecifyShape): rv = at.random.normal()
assert_op = Assert("Test assert")
# Example: Add assert that rv must be positive
assert_rv = assert_op(rv > 0, rv)
assert_rv.name = "assert_rv"
assert_vv = assert_rv.clone()
assert_logp = factorized_joint_logp({assert_rv: assert_vv})[assert_vv]
# TODO: Check valid value is correct and doesn't raise
# Check invalid value
with pytest.raises(AssertionError, match="Test assert"):
assert_logp.eval({assert_vv: -5.0) |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
Hey @ricardoV94 . Added the tests and made the changes. Because the folder tests was move outside the |
67a4cad
to
74e5067
Compare
Hey @ricardoV94 . The previous commit had some issues related to the mypy tests which are now rectified. Also made some changes in the |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. Some things:
- You are adding two binary files accidentaly:
.coverage
andpymc/.model.py.swn
that must be removed from the PR - Let's merge the two functions in
pymc/logprob/checks.py
, and the tests intests/logprob/test_checks.py
I think that's more intuitive. - You need to add the new test file here:
pymc/.github/workflows/tests.yml
Lines 92 to 102 in 473c952
tests/logprob/test_abstract.py tests/logprob/test_censoring.py tests/logprob/test_composite_logprob.py tests/logprob/test_cumsum.py tests/logprob/test_joint_logprob.py tests/logprob/test_mixture.py tests/logprob/test_rewriting.py tests/logprob/test_scan.py tests/logprob/test_tensor.py tests/logprob/test_transforms.py tests/logprob/test_utils.py
If the tests pass, that should be it!
Hey @ricardoV94 . I have incorporated all your suggested changes in the latest commit. Hope this solves the issue. Also I have noticed that ever since tests was moved outside the pymc package, in every program which inherits any instance from tests, the import
has been replaced with
Although this does not cause any problem when running the test_suites using pytest, runnning them naively using
I understand this is not a big problem and the solution is a little lengthy. However if you are interested, I would like to explore more methods. |
I think the tests are supposed to be run from the root repository folder. Then you shouldn't have issue with the imports? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good. I left a couple of small suggestions below.
Also note that you seem to have picked/reverted a couple of changes that don't belong to this PR when you merged from the main branch (see the files changed tab). Those have to be cleaned up, before we can merge this PR
pymc/logprob/checks.py
Outdated
if not (isinstance(node.op, SpecifyShape)): | ||
return None # pragma: no cover |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This isn't needed. The decorator already makes sure the rewrite is only ever called for nodes with Op of the right kind
if not (isinstance(node.op, SpecifyShape)): | |
return None # pragma: no cover |
pymc/logprob/checks.py
Outdated
) | ||
|
||
|
||
class MeasurableAssert(CheckAndRaise): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
class MeasurableAssert(CheckAndRaise): | |
class MeasurableCheckAndRaiseCheckAndRaise): |
pymc/logprob/checks.py
Outdated
if not (isinstance(node.op, CheckAndRaise)): | ||
return None # pragma: no cover |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also not needed
if not (isinstance(node.op, CheckAndRaise)): | |
return None # pragma: no cover |
Even from the root repository if we run the simple command |
Hey @ricardoV94 I have included your suggestions in the latest commit. Sorry for the delay as I got caught up in exams and am currently working on my GSOC proposal. |
@Dhruvanshu-Joshi Your PR still shows changes from main that don't belong here. This can happen when you have lot's of incremental commits and try to merge main, github is not very helpful at showing what changes are actually yours or from main. It's usually easier to work if you keep your commit history clean. When doing incremental work just squash your related commits together and force-push. Ideally your final commits will look the same as if you had started from scratch knowing the final solution. I would also advise to rebase from main instead of merging when trying to sync your branch, but that's optional. https://stackoverflow.com/questions/71074242/github-old-commits-shows-up-in-new-pull-request |
6d5e5ac
to
86ca5cb
Compare
tests/logprob/test_checks.py
Outdated
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) | ||
parentdir = os.path.dirname(currentdir) | ||
sys.path.insert(0, parentdir) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be removed. For running the test locally, just make sure you are running from the root folder I think
currentdir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe()))) | |
parentdir = os.path.dirname(currentdir) | |
sys.path.insert(0, parentdir) |
@ricardoV94 I have included this change. Seems like this time only the commit with the actual change shows up by following the steps you provided. Thank you. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great!
Thanks @Dhruvanshu-Joshi ! |
What is this PR about?
This PR closes 6352.
I have implemented
logprob
forSpecifyShape
andCheckandRaise
and also included a rewrite that converts a SpecifyShape orCheckandRaise
of aMeasurableVariable
into aMeasurableSpecifyShape/Assert Op
.I have also modified the
__init__
file to be consistent with the modifications.Checklist
Major / Breaking Changes
New features
logprob
forSpecifyShape
by transferingspecify_shape
fromrv
to value and included a rewrite that converts aSpecifyShape
of aMeasurableVariable
into aMeasurableSpecifyShape Op
.logprob
forCheckandRaise
. If the assertion is true, the value variable will retain its value and if its false, aValueError
will be raised. Also have included a rewrite that converts aCheckandRaise
of aMeasurableVariable
into aMeasurableAssert Op
.Bugfixes
Documentation
Maintenance