Skip to content

Commit 4d0360c

Browse files
committed
Fix nested replacement of useless IR conversions
1 parent 8c93bb5 commit 4d0360c

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

pymc/logprob/rewriting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,7 @@ def construct_ir_fgraph(
347347

348348
if rv_remapper.measurable_conversions:
349349
# Undo un-valued measurable IR rewrites
350-
new_to_old = tuple((v, k) for k, v in rv_remapper.measurable_conversions.items())
350+
new_to_old = tuple((v, k) for k, v in reversed(rv_remapper.measurable_conversions.items()))
351351
fgraph.replace_all(new_to_old)
352352

353353
return fgraph, rv_values, memo

tests/logprob/test_composite_logprob.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,9 +37,11 @@
3737
import numpy as np
3838
import pytensor
3939
import pytensor.tensor as pt
40+
import pytest
4041
import scipy.stats as st
4142

4243
from pymc import logp
44+
from pymc.logprob.abstract import MeasurableVariable
4345
from pymc.logprob.basic import factorized_joint_logprob
4446
from pymc.logprob.censoring import MeasurableClip
4547
from pymc.logprob.rewriting import construct_ir_fgraph
@@ -121,10 +123,13 @@ def test_nested_scalar_mixtures():
121123
assert np.isclose(logp_fn(0, 0, 1, 50), st.norm.logpdf(150) + np.log(0.5) * 3)
122124

123125

124-
def test_unvalued_ir_reversion():
126+
@pytest.mark.parametrize("nested", (False, True))
127+
def test_unvalued_ir_reversion(nested):
125128
"""Make sure that un-valued IR rewrites are reverted."""
126129
x_rv = pt.random.normal()
127130
y_rv = pt.clip(x_rv, 0, 1)
131+
if nested:
132+
y_rv = y_rv + 5
128133
z_rv = pt.random.normal(y_rv, 1, name="z")
129134
z_vv = z_rv.clone()
130135

@@ -134,14 +139,10 @@ def test_unvalued_ir_reversion():
134139

135140
z_fgraph, _, memo = construct_ir_fgraph(rv_values)
136141

137-
assert memo[y_rv] in z_fgraph.preserve_rv_mappings.measurable_conversions
138-
139-
measurable_y_rv = z_fgraph.preserve_rv_mappings.measurable_conversions[memo[y_rv]]
140-
assert isinstance(measurable_y_rv.owner.op, MeasurableClip)
141-
142-
# `construct_ir_fgraph` should've reverted the un-valued measurable IR
143-
# change
144-
assert measurable_y_rv not in z_fgraph
142+
assert len(z_fgraph.preserve_rv_mappings.measurable_conversions) == 1 + nested
143+
assert (
144+
sum(isinstance(node.op, MeasurableVariable) for node in z_fgraph.apply_nodes) == 2
145+
) # Just the 2 rvs
145146

146147

147148
def test_shifted_cumsum():

0 commit comments

Comments
 (0)