Skip to content

Commit 49aacf4

Browse files
committed
Allow Scan logprob inference of non-pure RandomVariable outputs
Most of the IR logprob rewrites require a PreserveRVMappings feature in the fgraph. The rewrite responsible to introduce IR in the inner graph of Scan was not adding this feature. In addition, find_measurable_scans, was bailing out when there were MesurableVariable nodes that were not outputs, even if these were being used by downstream nodes as the source of measurability.
1 parent e76bba9 commit 49aacf4

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

pymc/logprob/scan.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,13 +50,20 @@
5050
from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2
5151
from pytensor.scan.utils import ScanArgs
5252
from pytensor.tensor.random.type import RandomType
53+
from pytensor.tensor.rewriting.shape import ShapeFeature
5354
from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
5455
from pytensor.tensor.var import TensorVariable
5556
from pytensor.updates import OrderedUpdates
5657

57-
from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob
58+
from pymc.logprob.abstract import (
59+
MeasurableVariable,
60+
_get_measurable_outputs,
61+
_logprob,
62+
get_measurable_outputs,
63+
)
5864
from pymc.logprob.joint_logprob import factorized_joint_logprob
5965
from pymc.logprob.rewriting import (
66+
PreserveRVMappings,
6067
inc_subtensor_ops,
6168
logprob_rewrites_db,
6269
measurable_ir_rewrites_db,
@@ -66,6 +73,9 @@
6673
class MeasurableScan(Scan):
6774
"""A placeholder used to specify a log-likelihood for a scan sub-graph."""
6875

76+
def __str__(self):
77+
return f"Measurable({super().__str__()})"
78+
6979

7080
MeasurableVariable.register(MeasurableScan)
7181

@@ -359,6 +369,12 @@ def find_measurable_scans(fgraph, node):
359369
)
360370
for n in local_fgraph_topo:
361371
if isinstance(n.op, MeasurableVariable):
372+
measurable_outputs = get_measurable_outputs(n.op, n)
373+
# This variable's source of measure is used by another inner node,
374+
# So we don't need it to be an output!
375+
if not measurable_outputs:
376+
continue
377+
362378
non_output_node_clients = [
363379
c for c in clients[n] if c not in curr_scanargs.inner_outputs
364380
]
@@ -494,6 +510,10 @@ def add_opts_to_inner_graphs(fgraph, node):
494510
clone=True,
495511
copy_inputs=False,
496512
copy_orphans=False,
513+
features=[
514+
ShapeFeature(),
515+
PreserveRVMappings({}),
516+
],
497517
)
498518

499519
logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph)

tests/logprob/test_scan.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,10 @@
4242
from pytensor import Mode
4343
from pytensor.raise_op import assert_op
4444
from pytensor.scan.utils import ScanArgs
45+
from scipy import stats
4546

4647
from pymc.logprob.abstract import logprob
47-
from pymc.logprob.joint_logprob import factorized_joint_logprob
48+
from pymc.logprob.joint_logprob import factorized_joint_logprob, logp
4849
from pymc.logprob.scan import (
4950
construct_scan,
5051
convert_outer_out_to_in,
@@ -458,3 +459,22 @@ def test_mode_is_kept(remove_asserts):
458459
else:
459460
with pytest.raises(AssertionError):
460461
x_logp(x=x_test_val)
462+
463+
464+
def test_scan_non_pure_rv_output():
465+
grw, _ = pytensor.scan(
466+
fn=lambda xtm1: at.random.normal() + xtm1,
467+
outputs_info=[at.zeros(())],
468+
n_steps=10,
469+
name="grw",
470+
)
471+
472+
grw_vv = grw.clone()
473+
grw_logp = logp(grw, grw_vv)
474+
assert_no_rvs(grw_logp)
475+
476+
grw_vv_test = np.arange(10) + 1
477+
np.testing.assert_array_almost_equal(
478+
grw_logp.eval({grw_vv: grw_vv_test}),
479+
stats.norm.logpdf(np.ones(10)),
480+
)

0 commit comments

Comments
 (0)