37
37
import numpy as np
38
38
import pytensor
39
39
import pytensor .tensor as pt
40
+ import pytest
40
41
import scipy .stats as st
41
42
42
43
from pymc import logp
44
+ from pymc .logprob .abstract import MeasurableVariable
43
45
from pymc .logprob .basic import factorized_joint_logprob
44
46
from pymc .logprob .censoring import MeasurableClip
45
47
from pymc .logprob .rewriting import construct_ir_fgraph
@@ -121,10 +123,13 @@ def test_nested_scalar_mixtures():
121
123
assert np .isclose (logp_fn (0 , 0 , 1 , 50 ), st .norm .logpdf (150 ) + np .log (0.5 ) * 3 )
122
124
123
125
124
- def test_unvalued_ir_reversion ():
126
+ @pytest .mark .parametrize ("nested" , (False , True ))
127
+ def test_unvalued_ir_reversion (nested ):
125
128
"""Make sure that un-valued IR rewrites are reverted."""
126
129
x_rv = pt .random .normal ()
127
130
y_rv = pt .clip (x_rv , 0 , 1 )
131
+ if nested :
132
+ y_rv = y_rv + 5
128
133
z_rv = pt .random .normal (y_rv , 1 , name = "z" )
129
134
z_vv = z_rv .clone ()
130
135
@@ -134,14 +139,10 @@ def test_unvalued_ir_reversion():
134
139
135
140
z_fgraph , _ , memo = construct_ir_fgraph (rv_values )
136
141
137
- assert memo [y_rv ] in z_fgraph .preserve_rv_mappings .measurable_conversions
138
-
139
- measurable_y_rv = z_fgraph .preserve_rv_mappings .measurable_conversions [memo [y_rv ]]
140
- assert isinstance (measurable_y_rv .owner .op , MeasurableClip )
141
-
142
- # `construct_ir_fgraph` should've reverted the un-valued measurable IR
143
- # change
144
- assert measurable_y_rv not in z_fgraph
142
+ assert len (z_fgraph .preserve_rv_mappings .measurable_conversions ) == 1 + nested
143
+ assert (
144
+ sum (isinstance (node .op , MeasurableVariable ) for node in z_fgraph .apply_nodes ) == 2
145
+ ) # Just the 2 rvs
145
146
146
147
147
148
def test_shifted_cumsum ():
0 commit comments