|
50 | 50 | from pytensor.scan.rewriting import scan_eqopt1, scan_eqopt2
|
51 | 51 | from pytensor.scan.utils import ScanArgs
|
52 | 52 | from pytensor.tensor.random.type import RandomType
|
| 53 | +from pytensor.tensor.rewriting.shape import ShapeFeature |
53 | 54 | from pytensor.tensor.subtensor import Subtensor, indices_from_subtensor
|
54 | 55 | from pytensor.tensor.var import TensorVariable
|
55 | 56 | from pytensor.updates import OrderedUpdates
|
56 | 57 |
|
57 |
| -from pymc.logprob.abstract import MeasurableVariable, _get_measurable_outputs, _logprob |
| 58 | +from pymc.logprob.abstract import ( |
| 59 | + MeasurableVariable, |
| 60 | + _get_measurable_outputs, |
| 61 | + _logprob, |
| 62 | + get_measurable_outputs, |
| 63 | +) |
58 | 64 | from pymc.logprob.joint_logprob import factorized_joint_logprob
|
59 | 65 | from pymc.logprob.rewriting import (
|
| 66 | + PreserveRVMappings, |
60 | 67 | inc_subtensor_ops,
|
61 | 68 | logprob_rewrites_db,
|
62 | 69 | measurable_ir_rewrites_db,
|
|
66 | 73 | class MeasurableScan(Scan):
|
67 | 74 | """A placeholder used to specify a log-likelihood for a scan sub-graph."""
|
68 | 75 |
|
| 76 | + def __str__(self): |
| 77 | + return f"Measurable({super().__str__()})" |
| 78 | + |
69 | 79 |
|
70 | 80 | MeasurableVariable.register(MeasurableScan)
|
71 | 81 |
|
@@ -359,6 +369,12 @@ def find_measurable_scans(fgraph, node):
|
359 | 369 | )
|
360 | 370 | for n in local_fgraph_topo:
|
361 | 371 | if isinstance(n.op, MeasurableVariable):
|
| 372 | + measurable_outputs = get_measurable_outputs(n.op, n) |
| 373 | + # This variable's source of measure is used by another inner node, |
| 374 | + # So we don't need it to be an output! |
| 375 | + if not measurable_outputs: |
| 376 | + continue |
| 377 | + |
362 | 378 | non_output_node_clients = [
|
363 | 379 | c for c in clients[n] if c not in curr_scanargs.inner_outputs
|
364 | 380 | ]
|
@@ -494,6 +510,10 @@ def add_opts_to_inner_graphs(fgraph, node):
|
494 | 510 | clone=True,
|
495 | 511 | copy_inputs=False,
|
496 | 512 | copy_orphans=False,
|
| 513 | + features=[ |
| 514 | + ShapeFeature(), |
| 515 | + PreserveRVMappings({}), |
| 516 | + ], |
497 | 517 | )
|
498 | 518 |
|
499 | 519 | logprob_rewrites_db.query(RewriteDatabaseQuery(include=["basic"])).rewrite(inner_fgraph)
|
|
0 commit comments