|
37 | 37 |
|
38 | 38 | from pathlib import Path
|
39 | 39 |
|
40 |
| -import pytensor |
41 |
| - |
42 | 40 | from pytensor import tensor as pt
|
43 |
| -from pytensor.graph.fg import FunctionGraph |
44 |
| -from pytensor.graph.op import compute_test_value |
45 | 41 | from pytensor.graph.rewriting.basic import node_rewriter
|
46 | 42 | from pytensor.tensor import TensorVariable
|
47 |
| -from pytensor.tensor.basic import Alloc, Join, MakeVector |
| 43 | +from pytensor.tensor.basic import Join, MakeVector |
48 | 44 | from pytensor.tensor.elemwise import DimShuffle
|
49 | 45 | from pytensor.tensor.random.op import RandomVariable
|
50 | 46 | from pytensor.tensor.random.rewriting import (
|
51 | 47 | local_dimshuffle_rv_lift,
|
52 |
| - local_rv_size_lift, |
53 | 48 | )
|
54 | 49 |
|
55 | 50 | from pymc.logprob.abstract import MeasurableOp, _logprob, _logprob_helper
|
|
62 | 57 | from pymc.pytensorf import constant_fold
|
63 | 58 |
|
64 | 59 |
|
65 |
| -@node_rewriter([Alloc]) |
66 |
| -def naive_bcast_rv_lift(fgraph: FunctionGraph, node): |
67 |
| - """Lift an ``Alloc`` through a ``RandomVariable`` ``Op``. |
68 |
| -
|
69 |
| - XXX: This implementation simply broadcasts the ``RandomVariable``'s |
70 |
| - parameters, which won't always work (e.g. multivariate distributions). |
71 |
| -
|
72 |
| - TODO: Instead, it should use ``RandomVariable.ndim_supp``--and the like--to |
73 |
| - determine which dimensions of each parameter need to be broadcasted. |
74 |
| - Also, this doesn't need to remove ``size`` to perform the lifting, like it |
75 |
| - currently does. |
76 |
| - """ |
77 |
| - |
78 |
| - if not ( |
79 |
| - isinstance(node.op, Alloc) |
80 |
| - and node.inputs[0].owner |
81 |
| - and isinstance(node.inputs[0].owner.op, RandomVariable) |
82 |
| - ): |
83 |
| - return None # pragma: no cover |
84 |
| - |
85 |
| - bcast_shape = node.inputs[1:] |
86 |
| - |
87 |
| - rv_var = node.inputs[0] |
88 |
| - rv_node = rv_var.owner |
89 |
| - |
90 |
| - if hasattr(fgraph, "dont_touch_vars") and rv_var in fgraph.dont_touch_vars: |
91 |
| - return None # pragma: no cover |
92 |
| - |
93 |
| - # Do not replace RV if it is associated with a value variable |
94 |
| - rv_map_feature: PreserveRVMappings | None = getattr(fgraph, "preserve_rv_mappings", None) |
95 |
| - if rv_map_feature is not None and rv_var in rv_map_feature.rv_values: |
96 |
| - return None |
97 |
| - |
98 |
| - if not bcast_shape: |
99 |
| - # The `Alloc` is broadcasting a scalar to a scalar (i.e. doing nothing) |
100 |
| - assert rv_var.ndim == 0 |
101 |
| - return [rv_var] |
102 |
| - |
103 |
| - size_lift_res = local_rv_size_lift.transform(fgraph, rv_node) |
104 |
| - if size_lift_res is None: |
105 |
| - lifted_node = rv_node |
106 |
| - else: |
107 |
| - _, lifted_rv = size_lift_res |
108 |
| - lifted_node = lifted_rv.owner |
109 |
| - |
110 |
| - rng, size, *dist_params = lifted_node.inputs |
111 |
| - |
112 |
| - new_dist_params = [ |
113 |
| - pt.broadcast_to( |
114 |
| - param, |
115 |
| - pt.broadcast_shape(tuple(param.shape), tuple(bcast_shape), arrays_are_shapes=True), |
116 |
| - ) |
117 |
| - for param in dist_params |
118 |
| - ] |
119 |
| - bcasted_node = lifted_node.op.make_node(rng, size, *new_dist_params) |
120 |
| - |
121 |
| - if pytensor.config.compute_test_value != "off": |
122 |
| - compute_test_value(bcasted_node) |
123 |
| - |
124 |
| - return [bcasted_node.outputs[1]] |
125 |
| - |
126 |
| - |
127 | 60 | class MeasurableMakeVector(MeasurableOp, MakeVector):
|
128 | 61 | """A placeholder used to specify a log-likelihood for a cumsum sub-graph."""
|
129 | 62 |
|
|
0 commit comments