Skip to content

Commit 97df9c3

Browse files
committed
Introduce valued variables in logprob IR
This avoids rewrites across conditioning points, that could break dependencies Also extend logprob derivation of scans with multiple valued output types
1 parent b06d6c3 commit 97df9c3

21 files changed

+680
-823
lines changed

pymc/logprob/abstract.py

Lines changed: 127 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from collections.abc import Sequence
4141
from functools import singledispatch
4242

43-
from pytensor.graph.op import Op
43+
from pytensor.graph import Apply, Op, Variable
4444
from pytensor.graph.utils import MetaType
4545
from pytensor.tensor import TensorVariable
4646
from pytensor.tensor.elemwise import Elemwise
@@ -165,3 +165,129 @@ def __init__(self, scalar_op, *args, **kwargs):
165165

166166
def __str__(self):
167167
return f"Measurable{super().__str__()}"
168+
169+
170+
class ValuedRV(Op):
171+
r"""Represents the association of a measurable variable and its value.
172+
173+
A `ValuedVariable` node represents the pair :math:`(Y, y)`, where `y` the value at which :math:`Y`'s density
174+
or probability mass function is evaluated.
175+
176+
The log-probability function takes such pairs as input, which makes these nodes in a graph an intermediate form
177+
that serves to construct a log-probability from a model graph.
178+
179+
180+
Notes
181+
-----
182+
The introduction of these operations achieves two goals:
183+
1. Identify the conditioning points between multiple, potentially interdependent measurable variables,
184+
and introduce the respective value variables in the IR graph.
185+
2. Prevent automatic rewrites across conditioning points
186+
187+
About point 2. In the current framework, a RV logp cannot depend on a transformation of the value variable
188+
of a second RV it depends on. While this is mathematically trivial, we don't have the machinery to achieve it.
189+
190+
The only case we do something like this is in the ad-hoc transform_value rewrite, but there we are
191+
told explicitly what value variables must be transformed before being used in the density of dependent RVs.
192+
193+
For example ,the following is not supported:
194+
195+
```python
196+
x_log = pt.random.normal()
197+
x = pt.exp(x_log)
198+
y = pt.random.normal(loc=x_log)
199+
200+
x_value = pt.scalar()
201+
y_value = pt.scalar()
202+
conditional_logprob({x: x_value, y: y_value})
203+
```
204+
205+
Our framework doesn't know that the density of y should depend on a (log) transform of x_value.
206+
207+
Importantly, we need to prevent this limitation from being introduced automatically by our IR rewrites.
208+
For example given the following:
209+
210+
```python
211+
a_base = pm.Normal.dist()
212+
a = a_base * 5
213+
b = pm.Normal.dist(a * 8)
214+
215+
a_value = scalar()
216+
b_value = scalar()
217+
conditional_logp({a: a_value, b: b_value})
218+
```
219+
220+
We do not want `b` to be rewritten as `pm.Normal.dist(a_base * 40)`, as it would then be disconnected from the
221+
valued `a` associated with `pm.Normal.dist(a_base * 5). By introducing `ValuedRV` nodes the graph looks like:
222+
223+
```python
224+
a_base = pm.Normal.dist()
225+
a = valued_rv(a_base * 5, a_value)
226+
b = valued_rv(a * 8, b_value)
227+
```
228+
229+
Since, PyTensor doesn't know what to do with `ValuedRV` nodes, there is no risk of rewriting across them
230+
and breaking the dependency of `b` on `a`. The new nodes isolate the graphs between conditioning points.
231+
"""
232+
233+
def make_node(self, rv, value):
234+
assert isinstance(rv, Variable)
235+
assert isinstance(value, Variable)
236+
return Apply(self, [rv, value], [rv.type(name=rv.name)])
237+
238+
def perform(self, node, inputs, out):
239+
raise NotImplementedError("ValuedVar should not be present in the final graph!")
240+
241+
def infer_shape(self, fgraph, node, input_shapes):
242+
return [input_shapes[0]]
243+
244+
245+
valued_rv = ValuedRV()
246+
247+
248+
class PromisedValuedRV(Op):
249+
r"""Marks a variable as being promised a valued variable that will only be assigned by the logprob method.
250+
251+
Some measurable RVs like Join/MakeVector can combine multiple, potentially interdependent, RVs into a single
252+
composite valued node. Only in the logp function is this value split and sent to each component,
253+
but we still want to achieve the same goals that ValuedRVs achieve during the IR rewrites.
254+
255+
Here is an example analogous to the one described in the docstrings of ValuedRV:
256+
257+
```python
258+
a_base = pt.random.normal()
259+
a = a_base * 5
260+
b = pt.random.normal(a * 8)
261+
ab = pt.stack([a, b])
262+
ab_value = pt.vector(shape=(2,))
263+
264+
logp(ab, ab_value)
265+
```
266+
267+
The density of `ab[2]` (that is `b`) depends on `ab_value[1]` and `ab_value[0] * 8`, but this is not apparent
268+
in the IR representation because the values of `a` and `b` are merged together, and will only be split by the logp
269+
function (see why next). For the time being we introduce a PromisedValue to isolate the graphs of a and b, and
270+
freezing the dependency of `b` on `a` (not `a_base`).
271+
272+
Now why use a new Op and not just ValuedRV? Just for convenience! In the end we still want a function from
273+
`ab_value` to `stack([logp(a), logp(b | a)])`, and if we split the values ahead of time we wouldn't know how to
274+
stack them later (or even know that we were supposed to).
275+
276+
One final point, while this achieves the same goal as introducing ValuedRVs, it already constitutes a form of inference
277+
(knowing how/when to measure Join/MakeVectors), so we have to do it as an IR rewrite. However, we have to do it
278+
before any other rewrites, so you'll see that the related rewrites are registered in `early_measurable_ir_rewrites_db`.
279+
280+
"""
281+
282+
def make_node(self, rv):
283+
assert isinstance(rv, Variable)
284+
return Apply(self, [rv], [rv.type(name=rv.name)])
285+
286+
def perform(self, node, inputs, out):
287+
raise NotImplementedError("PromisedValuedRV should not be present in the final graph!")
288+
289+
def infer_shape(self, fgraph, node, input_shapes):
290+
return [input_shapes[0]]
291+
292+
293+
promised_valued_rv = PromisedValuedRV()

pymc/logprob/basic.py

Lines changed: 56 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -36,22 +36,17 @@
3636

3737
import warnings
3838

39-
from collections import deque
4039
from collections.abc import Sequence
4140
from typing import TypeAlias
4241

4342
import numpy as np
4443
import pytensor.tensor as pt
4544

46-
from pytensor import config
4745
from pytensor.graph.basic import (
4846
Constant,
4947
Variable,
5048
ancestors,
51-
graph_inputs,
52-
io_toposort,
5349
)
54-
from pytensor.graph.op import compute_test_value
5550
from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
5651
from pytensor.tensor.variable import TensorVariable
5752

@@ -65,7 +60,7 @@
6560
from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
6661
from pymc.logprob.transform_value import TransformValuesRewrite
6762
from pymc.logprob.transforms import Transform
68-
from pymc.logprob.utils import rvs_in_graph
63+
from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph
6964
from pymc.pytensorf import replace_vars_in_graphs
7065

7166
TensorLike: TypeAlias = Variable | float | np.ndarray
@@ -210,8 +205,9 @@ def normal_logp(value, mu, sigma):
210205
try:
211206
return _logprob_helper(rv, value, **kwargs)
212207
except NotImplementedError:
213-
fgraph, _, _ = construct_ir_fgraph({rv: value})
214-
[(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
208+
fgraph = construct_ir_fgraph({rv: value})
209+
[ir_valued_var] = fgraph.outputs
210+
[ir_rv, ir_value] = ir_valued_var.owner.inputs
215211
expr = _logprob_helper(ir_rv, ir_value, **kwargs)
216212
cleanup_ir([expr])
217213
if warn_rvs:
@@ -308,9 +304,10 @@ def normal_logcdf(value, mu, sigma):
308304
return _logcdf_helper(rv, value, **kwargs)
309305
except NotImplementedError:
310306
# Try to rewrite rv
311-
fgraph, _, _ = construct_ir_fgraph({rv: value})
312-
[ir_rv] = fgraph.outputs
313-
expr = _logcdf_helper(ir_rv, value, **kwargs)
307+
fgraph = construct_ir_fgraph({rv: value})
308+
[ir_valued_rv] = fgraph.outputs
309+
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
310+
expr = _logcdf_helper(ir_rv, ir_value, **kwargs)
314311
cleanup_ir([expr])
315312
if warn_rvs:
316313
_warn_rvs_in_inferred_graph(expr)
@@ -390,9 +387,10 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
390387
return _icdf_helper(rv, value, **kwargs)
391388
except NotImplementedError:
392389
# Try to rewrite rv
393-
fgraph, _, _ = construct_ir_fgraph({rv: value})
394-
[ir_rv] = fgraph.outputs
395-
expr = _icdf_helper(ir_rv, value, **kwargs)
390+
fgraph = construct_ir_fgraph({rv: value})
391+
[ir_valued_rv] = fgraph.outputs
392+
[ir_rv, ir_value] = ir_valued_rv.owner.inputs
393+
expr = _icdf_helper(ir_rv, ir_value, **kwargs)
396394
cleanup_ir([expr])
397395
if warn_rvs:
398396
_warn_rvs_in_inferred_graph(expr)
@@ -476,111 +474,96 @@ def conditional_logp(
476474
"""
477475
warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)
478476

479-
fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
477+
fgraph = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter)
480478

481479
if extra_rewrites is not None:
482480
extra_rewrites.rewrite(fgraph)
483481

484-
rv_remapper = fgraph.preserve_rv_mappings
485-
486-
# This is the updated random-to-value-vars map with the lifted/rewritten
487-
# variables. The rewrites are supposed to produce new
488-
# `MeasurableOp`s whose variables are amenable to `_logprob`.
489-
updated_rv_values = rv_remapper.rv_values
490-
491-
# Some rewrites also transform the original value variables. This is the
492-
# updated map from the new value variables to the original ones, which
493-
# we want to use as the keys in the final dictionary output
494-
original_values = rv_remapper.original_values
495-
496-
# When a `_logprob` has been produced for a `MeasurableOp` node, all
497-
# other references to it need to be replaced with its value-variable all
498-
# throughout the `_logprob`-produced graphs. The following `dict`
499-
# cumulatively maintains remappings for all the variables/nodes that needed
500-
# to be recreated after replacing `MeasurableOp` variables with their
501-
# value-variables. Since these replacements work in topological order, all
502-
# the necessary value-variable replacements should be present for each
503-
# node.
504-
replacements = updated_rv_values.copy()
482+
# Walk the graph from its inputs to its outputs and construct the
483+
# log-probability
484+
replacements = {}
505485

506486
# To avoid cloning the value variables (or ancestors of value variables),
507487
# we map them to themselves in the `replacements` `dict`
508488
# (i.e. entries already existing in `replacements` aren't cloned)
509489
replacements.update(
510-
{
511-
v: v
512-
for v in ancestors(rv_values.values())
513-
if (not isinstance(v, Constant) and v not in replacements)
514-
}
490+
{v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)}
515491
)
516492

517493
# Walk the graph from its inputs to its outputs and construct the
518494
# log-probability
519-
q = deque(fgraph.toposort())
520-
logprob_vars = {}
521-
522-
while q:
523-
node = q.popleft()
495+
values_to_logprobs = {}
496+
original_values = tuple(rv_values.values())
524497

498+
# TODO: This seems too convoluted, can we just replace all RVs by their values,
499+
# except for the fgraph outputs (for which we want to call _logprob on)?
500+
for node in fgraph.toposort():
525501
if not isinstance(node.op, MeasurableOp):
526502
continue
527503

528-
q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values]
504+
valued_nodes = get_related_valued_nodes(node, fgraph)
529505

530-
if not q_values:
506+
if not valued_nodes:
531507
continue
532508

509+
node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
510+
node_values = [valued_var.inputs[1] for valued_var in valued_nodes]
511+
node_output_idxs = [
512+
fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes
513+
]
514+
533515
# Replace `RandomVariable`s in the inputs with value variables.
516+
# Also, store the results in the `replacements` map for the nodes that follow.
517+
for node_rv, node_value in zip(node_rvs, node_values):
518+
replacements[node_rv] = node_value
519+
534520
remapped_vars = replace_vars_in_graphs(
535-
graphs=q_values + list(node.inputs),
521+
graphs=node_values + list(node.inputs),
536522
replacements=replacements,
537523
)
538-
q_values = remapped_vars[: len(q_values)]
539-
q_rv_inputs = remapped_vars[len(q_values) :]
524+
node_values = remapped_vars[: len(node_values)]
525+
node_inputs = remapped_vars[len(node_values) :]
540526

541-
q_logprob_vars = _logprob(
527+
node_logprobs = _logprob(
542528
node.op,
543-
q_values,
544-
*q_rv_inputs,
529+
node_values,
530+
*node_inputs,
545531
**kwargs,
546532
)
547533

548-
if not isinstance(q_logprob_vars, list | tuple):
549-
q_logprob_vars = [q_logprob_vars]
534+
if not isinstance(node_logprobs, list | tuple):
535+
node_logprobs = [node_logprobs]
550536

551-
for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars):
552-
q_value_var = original_values[q_value_var]
537+
for node_output_idx, node_value, node_logprob in zip(
538+
node_output_idxs, node_values, node_logprobs
539+
):
540+
original_value = original_values[node_output_idx]
553541

554-
if q_value_var.name:
555-
q_logprob_var.name = f"{q_value_var.name}_logprob"
542+
if original_value.name:
543+
node_logprob.name = f"{original_value.name}_logprob"
556544

557-
if q_value_var in logprob_vars:
545+
if original_value in values_to_logprobs:
558546
raise ValueError(
559-
f"More than one logprob term was assigned to the value var {q_value_var}"
547+
f"More than one logprob term was assigned to the value var {original_value}"
560548
)
561549

562-
logprob_vars[q_value_var] = q_logprob_var
563-
564-
# Recompute test values for the changes introduced by the replacements above.
565-
if config.compute_test_value != "off":
566-
for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars):
567-
compute_test_value(node)
550+
values_to_logprobs[original_value] = node_logprob
568551

569-
missing_value_terms = set(original_values.values()) - set(logprob_vars.keys())
552+
missing_value_terms = set(original_values) - set(values_to_logprobs)
570553
if missing_value_terms:
571554
raise RuntimeError(
572555
f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
573556
)
574557

575-
logprob_expressions = list(logprob_vars.values())
576-
cleanup_ir(logprob_expressions)
558+
logprobs = list(values_to_logprobs.values())
559+
cleanup_ir(logprobs)
577560

578561
if warn_rvs:
579-
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions)
562+
rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs)
580563
if rvs_in_logp_expressions:
581564
warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
582565

583-
return logprob_vars
566+
return values_to_logprobs
584567

585568

586569
def transformed_conditional_logp(

0 commit comments

Comments
 (0)