|
36 | 36 |
|
37 | 37 | import warnings
|
38 | 38 |
|
39 |
| -from collections import deque |
40 | 39 | from collections.abc import Sequence
|
41 | 40 | from typing import TypeAlias
|
42 | 41 |
|
43 | 42 | import numpy as np
|
44 | 43 | import pytensor.tensor as pt
|
45 | 44 |
|
46 |
| -from pytensor import config |
47 | 45 | from pytensor.graph.basic import (
|
48 | 46 | Constant,
|
49 | 47 | Variable,
|
50 | 48 | ancestors,
|
51 |
| - graph_inputs, |
52 |
| - io_toposort, |
53 | 49 | )
|
54 |
| -from pytensor.graph.op import compute_test_value |
55 | 50 | from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
|
56 | 51 | from pytensor.tensor.variable import TensorVariable
|
57 | 52 |
|
|
65 | 60 | from pymc.logprob.rewriting import cleanup_ir, construct_ir_fgraph
|
66 | 61 | from pymc.logprob.transform_value import TransformValuesRewrite
|
67 | 62 | from pymc.logprob.transforms import Transform
|
68 |
| -from pymc.logprob.utils import rvs_in_graph |
| 63 | +from pymc.logprob.utils import get_related_valued_nodes, rvs_in_graph |
69 | 64 | from pymc.pytensorf import replace_vars_in_graphs
|
70 | 65 |
|
71 | 66 | TensorLike: TypeAlias = Variable | float | np.ndarray
|
@@ -210,8 +205,9 @@ def normal_logp(value, mu, sigma):
|
210 | 205 | try:
|
211 | 206 | return _logprob_helper(rv, value, **kwargs)
|
212 | 207 | except NotImplementedError:
|
213 |
| - fgraph, _, _ = construct_ir_fgraph({rv: value}) |
214 |
| - [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items() |
| 208 | + fgraph = construct_ir_fgraph({rv: value}) |
| 209 | + [ir_valued_var] = fgraph.outputs |
| 210 | + [ir_rv, ir_value] = ir_valued_var.owner.inputs |
215 | 211 | expr = _logprob_helper(ir_rv, ir_value, **kwargs)
|
216 | 212 | cleanup_ir([expr])
|
217 | 213 | if warn_rvs:
|
@@ -308,9 +304,10 @@ def normal_logcdf(value, mu, sigma):
|
308 | 304 | return _logcdf_helper(rv, value, **kwargs)
|
309 | 305 | except NotImplementedError:
|
310 | 306 | # Try to rewrite rv
|
311 |
| - fgraph, _, _ = construct_ir_fgraph({rv: value}) |
312 |
| - [ir_rv] = fgraph.outputs |
313 |
| - expr = _logcdf_helper(ir_rv, value, **kwargs) |
| 307 | + fgraph = construct_ir_fgraph({rv: value}) |
| 308 | + [ir_valued_rv] = fgraph.outputs |
| 309 | + [ir_rv, ir_value] = ir_valued_rv.owner.inputs |
| 310 | + expr = _logcdf_helper(ir_rv, ir_value, **kwargs) |
314 | 311 | cleanup_ir([expr])
|
315 | 312 | if warn_rvs:
|
316 | 313 | _warn_rvs_in_inferred_graph(expr)
|
@@ -390,9 +387,10 @@ def icdf(rv: TensorVariable, value: TensorLike, warn_rvs=None, **kwargs) -> Tens
|
390 | 387 | return _icdf_helper(rv, value, **kwargs)
|
391 | 388 | except NotImplementedError:
|
392 | 389 | # Try to rewrite rv
|
393 |
| - fgraph, _, _ = construct_ir_fgraph({rv: value}) |
394 |
| - [ir_rv] = fgraph.outputs |
395 |
| - expr = _icdf_helper(ir_rv, value, **kwargs) |
| 390 | + fgraph = construct_ir_fgraph({rv: value}) |
| 391 | + [ir_valued_rv] = fgraph.outputs |
| 392 | + [ir_rv, ir_value] = ir_valued_rv.owner.inputs |
| 393 | + expr = _icdf_helper(ir_rv, ir_value, **kwargs) |
396 | 394 | cleanup_ir([expr])
|
397 | 395 | if warn_rvs:
|
398 | 396 | _warn_rvs_in_inferred_graph(expr)
|
@@ -476,111 +474,96 @@ def conditional_logp(
|
476 | 474 | """
|
477 | 475 | warn_rvs, kwargs = _deprecate_warn_missing_rvs(warn_rvs, kwargs)
|
478 | 476 |
|
479 |
| - fgraph, rv_values, _ = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) |
| 477 | + fgraph = construct_ir_fgraph(rv_values, ir_rewriter=ir_rewriter) |
480 | 478 |
|
481 | 479 | if extra_rewrites is not None:
|
482 | 480 | extra_rewrites.rewrite(fgraph)
|
483 | 481 |
|
484 |
| - rv_remapper = fgraph.preserve_rv_mappings |
485 |
| - |
486 |
| - # This is the updated random-to-value-vars map with the lifted/rewritten |
487 |
| - # variables. The rewrites are supposed to produce new |
488 |
| - # `MeasurableOp`s whose variables are amenable to `_logprob`. |
489 |
| - updated_rv_values = rv_remapper.rv_values |
490 |
| - |
491 |
| - # Some rewrites also transform the original value variables. This is the |
492 |
| - # updated map from the new value variables to the original ones, which |
493 |
| - # we want to use as the keys in the final dictionary output |
494 |
| - original_values = rv_remapper.original_values |
495 |
| - |
496 |
| - # When a `_logprob` has been produced for a `MeasurableOp` node, all |
497 |
| - # other references to it need to be replaced with its value-variable all |
498 |
| - # throughout the `_logprob`-produced graphs. The following `dict` |
499 |
| - # cumulatively maintains remappings for all the variables/nodes that needed |
500 |
| - # to be recreated after replacing `MeasurableOp` variables with their |
501 |
| - # value-variables. Since these replacements work in topological order, all |
502 |
| - # the necessary value-variable replacements should be present for each |
503 |
| - # node. |
504 |
| - replacements = updated_rv_values.copy() |
| 482 | + # Cumulative map from variables to the value variables that replace them. |
| 483 | + # Each valued node adds its entries so that the nodes that follow can use them. |
| 484 | + replacements = {} |
505 | 485 |
|
506 | 486 | # To avoid cloning the value variables (or ancestors of value variables),
|
507 | 487 | # we map them to themselves in the `replacements` `dict`
|
508 | 488 | # (i.e. entries already existing in `replacements` aren't cloned)
|
509 | 489 | replacements.update(
|
510 |
| - { |
511 |
| - v: v |
512 |
| - for v in ancestors(rv_values.values()) |
513 |
| - if (not isinstance(v, Constant) and v not in replacements) |
514 |
| - } |
| 490 | + {v: v for v in ancestors(rv_values.values()) if not isinstance(v, Constant)} |
515 | 491 | )
|
516 | 492 |
|
517 | 493 | # Walk the graph from its inputs to its outputs and construct the
|
518 | 494 | # log-probability
|
519 |
| - q = deque(fgraph.toposort()) |
520 |
| - logprob_vars = {} |
521 |
| - |
522 |
| - while q: |
523 |
| - node = q.popleft() |
| 495 | + values_to_logprobs = {} |
| 496 | + original_values = tuple(rv_values.values()) |
524 | 497 |
|
| 498 | + # TODO: This seems too convoluted. Can we just replace all RVs by their values, |
| 499 | + # except for the fgraph outputs (on which we want to call _logprob)? |
| 500 | + for node in fgraph.toposort(): |
525 | 501 | if not isinstance(node.op, MeasurableOp):
|
526 | 502 | continue
|
527 | 503 |
|
528 |
| - q_values = [replacements[q_rv] for q_rv in node.outputs if q_rv in updated_rv_values] |
| 504 | + valued_nodes = get_related_valued_nodes(node, fgraph) |
529 | 505 |
|
530 |
| - if not q_values: |
| 506 | + if not valued_nodes: |
531 | 507 | continue
|
532 | 508 |
|
| 509 | + node_rvs = [valued_var.inputs[0] for valued_var in valued_nodes] |
| 510 | + node_values = [valued_var.inputs[1] for valued_var in valued_nodes] |
| 511 | + node_output_idxs = [ |
| 512 | + fgraph.outputs.index(valued_var.outputs[0]) for valued_var in valued_nodes |
| 513 | + ] |
| 514 | + |
533 | 515 | # Replace `RandomVariable`s in the inputs with value variables.
|
| 516 | + # Also, store the results in the `replacements` map for the nodes that follow. |
| 517 | + for node_rv, node_value in zip(node_rvs, node_values): |
| 518 | + replacements[node_rv] = node_value |
| 519 | + |
534 | 520 | remapped_vars = replace_vars_in_graphs(
|
535 |
| - graphs=q_values + list(node.inputs), |
| 521 | + graphs=node_values + list(node.inputs), |
536 | 522 | replacements=replacements,
|
537 | 523 | )
|
538 |
| - q_values = remapped_vars[: len(q_values)] |
539 |
| - q_rv_inputs = remapped_vars[len(q_values) :] |
| 524 | + node_values = remapped_vars[: len(node_values)] |
| 525 | + node_inputs = remapped_vars[len(node_values) :] |
540 | 526 |
|
541 |
| - q_logprob_vars = _logprob( |
| 527 | + node_logprobs = _logprob( |
542 | 528 | node.op,
|
543 |
| - q_values, |
544 |
| - *q_rv_inputs, |
| 529 | + node_values, |
| 530 | + *node_inputs, |
545 | 531 | **kwargs,
|
546 | 532 | )
|
547 | 533 |
|
548 |
| - if not isinstance(q_logprob_vars, list | tuple): |
549 |
| - q_logprob_vars = [q_logprob_vars] |
| 534 | + if not isinstance(node_logprobs, list | tuple): |
| 535 | + node_logprobs = [node_logprobs] |
550 | 536 |
|
551 |
| - for q_value_var, q_logprob_var in zip(q_values, q_logprob_vars): |
552 |
| - q_value_var = original_values[q_value_var] |
| 537 | + for node_output_idx, node_value, node_logprob in zip( |
| 538 | + node_output_idxs, node_values, node_logprobs |
| 539 | + ): |
| 540 | + original_value = original_values[node_output_idx] |
553 | 541 |
|
554 |
| - if q_value_var.name: |
555 |
| - q_logprob_var.name = f"{q_value_var.name}_logprob" |
| 542 | + if original_value.name: |
| 543 | + node_logprob.name = f"{original_value.name}_logprob" |
556 | 544 |
|
557 |
| - if q_value_var in logprob_vars: |
| 545 | + if original_value in values_to_logprobs: |
558 | 546 | raise ValueError(
|
559 |
| - f"More than one logprob term was assigned to the value var {q_value_var}" |
| 547 | + f"More than one logprob term was assigned to the value var {original_value}" |
560 | 548 | )
|
561 | 549 |
|
562 |
| - logprob_vars[q_value_var] = q_logprob_var |
563 |
| - |
564 |
| - # Recompute test values for the changes introduced by the replacements above. |
565 |
| - if config.compute_test_value != "off": |
566 |
| - for node in io_toposort(graph_inputs(q_logprob_vars), q_logprob_vars): |
567 |
| - compute_test_value(node) |
| 550 | + values_to_logprobs[original_value] = node_logprob |
568 | 551 |
|
569 |
| - missing_value_terms = set(original_values.values()) - set(logprob_vars.keys()) |
| 552 | + missing_value_terms = set(original_values) - set(values_to_logprobs) |
570 | 553 | if missing_value_terms:
|
571 | 554 | raise RuntimeError(
|
572 | 555 | f"The logprob terms of the following value variables could not be derived: {missing_value_terms}"
|
573 | 556 | )
|
574 | 557 |
|
575 |
| - logprob_expressions = list(logprob_vars.values()) |
576 |
| - cleanup_ir(logprob_expressions) |
| 558 | + logprobs = list(values_to_logprobs.values()) |
| 559 | + cleanup_ir(logprobs) |
577 | 560 |
|
578 | 561 | if warn_rvs:
|
579 |
| - rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprob_expressions) |
| 562 | + rvs_in_logp_expressions = _find_unallowed_rvs_in_graph(logprobs) |
580 | 563 | if rvs_in_logp_expressions:
|
581 | 564 | warnings.warn(RVS_IN_JOINT_LOGP_GRAPH_MSG % rvs_in_logp_expressions, UserWarning)
|
582 | 565 |
|
583 |
| - return logprob_vars |
| 566 | + return values_to_logprobs |
584 | 567 |
|
585 | 568 |
|
586 | 569 | def transformed_conditional_logp(
|
|
0 commit comments