@@ -90,14 +90,16 @@ def model_free_rv(rv, value, transform, *dims):
90
90
91
91
92
92
def toposort_replace (
93
- fgraph : FunctionGraph , replacements : Sequence [Tuple [Variable , Variable ]]
93
+ fgraph : FunctionGraph , replacements : Sequence [Tuple [Variable , Variable ]], reverse : bool = False
94
94
) -> None :
95
95
"""Replace multiple variables in topological order."""
96
96
toposort = fgraph .toposort ()
97
97
sorted_replacements = sorted (
98
- replacements , key = lambda pair : toposort .index (pair [0 ].owner ) if pair [0 ].owner else - 1
98
+ replacements ,
99
+ key = lambda pair : toposort .index (pair [0 ].owner ) if pair [0 ].owner else - 1 ,
100
+ reverse = reverse ,
99
101
)
100
- fgraph .replace_all (tuple ( sorted_replacements ) , import_missing = True )
102
+ fgraph .replace_all (sorted_replacements , import_missing = True )
101
103
102
104
103
105
@node_rewriter ([Elemwise ])
@@ -109,11 +111,20 @@ def local_remove_identity(fgraph, node):
109
111
remove_identity_rewrite = out2in (local_remove_identity )
110
112
111
113
112
- def fgraph_from_model (model : Model ) -> Tuple [FunctionGraph , Dict [Variable , Variable ]]:
114
+ def fgraph_from_model (
115
+ model : Model , inlined_views = False
116
+ ) -> Tuple [FunctionGraph , Dict [Variable , Variable ]]:
113
117
"""Convert Model to FunctionGraph.
114
118
115
119
See: model_from_fgraph
116
120
121
+ Parameters
122
+ ----------
123
+ model: PyMC model
124
+ inlined_views: bool, default False
125
+ Whether "view" variables (Deterministics and Data) should be inlined among RVs in the fgraph,
126
+ or show up as separate branches.
127
+
117
128
Returns
118
129
-------
119
130
fgraph: FunctionGraph
@@ -138,19 +149,36 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
138
149
free_rvs = model .free_RVs
139
150
observed_rvs = model .observed_RVs
140
151
potentials = model .potentials
152
+ named_vars = model .named_vars .values ()
141
153
# We copy Deterministics (Identity Op) so that they don't show in between "main" variables
142
154
# We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
143
155
old_deterministics = model .deterministics
144
- deterministics = [det .copy (det .name ) for det in old_deterministics ]
145
- # Other variables that are in model.named_vars but are not any of the categories above
156
+ deterministics = [det if inlined_views else det .copy (det .name ) for det in old_deterministics ]
157
+ # Value variables (we also have to decide whether to inline named ones)
158
+ old_value_vars = list (rvs_to_values .values ())
159
+ unnamed_value_vars = [val for val in old_value_vars if val not in named_vars ]
160
+ named_value_vars = [
161
+ val if inlined_views else val .copy (val .name ) for val in old_value_vars if val in named_vars
162
+ ]
163
+ value_vars = old_value_vars .copy ()
164
+ if inlined_views :
165
+ # In this case we want to use the named_value_vars as the value_vars in RVs
166
+ for named_val in named_value_vars :
167
+ idx = value_vars .index (named_val )
168
+ value_vars [idx ] = named_val
169
+ # Other variables that are in named_vars but are not any of the categories above
146
170
# E.g., MutableData, ConstantData, _dim_lengths
147
171
# We use the same trick as deterministics!
148
- accounted_for = free_rvs + observed_rvs + potentials + old_deterministics
149
- old_other_named_vars = [var for var in model .named_vars .values () if var not in accounted_for ]
150
- other_named_vars = [var .copy (var .name ) for var in old_other_named_vars ]
151
- value_vars = [val for val in rvs_to_values .values () if val not in old_other_named_vars ]
172
+ accounted_for = set (free_rvs + observed_rvs + potentials + old_deterministics + old_value_vars )
173
+ other_named_vars = [
174
+ var if inlined_views else var .copy (var .name )
175
+ for var in named_vars
176
+ if var not in accounted_for
177
+ ]
152
178
153
- model_vars = rvs + potentials + deterministics + other_named_vars + value_vars
179
+ model_vars = (
180
+ rvs + potentials + deterministics + other_named_vars + named_value_vars + unnamed_value_vars
181
+ )
154
182
155
183
memo = {}
156
184
@@ -176,13 +204,13 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
176
204
177
205
# Introduce dummy `ModelVar` Ops
178
206
free_rvs_to_transforms = {memo [k ]: tr for k , tr in rvs_to_transforms .items ()}
179
- free_rvs_to_values = {memo [k ]: memo [v ] for k , v in rvs_to_values . items ( ) if k in free_rvs }
207
+ free_rvs_to_values = {memo [k ]: memo [v ] for k , v in zip ( rvs , value_vars ) if k in free_rvs }
180
208
observed_rvs_to_values = {
181
- memo [k ]: memo [v ] for k , v in rvs_to_values . items ( ) if k in observed_rvs
209
+ memo [k ]: memo [v ] for k , v in zip ( rvs , value_vars ) if k in observed_rvs
182
210
}
183
211
potentials = [memo [k ] for k in potentials ]
184
212
deterministics = [memo [k ] for k in deterministics ]
185
- other_named_vars = [memo [k ] for k in other_named_vars ]
213
+ named_vars = [memo [k ] for k in other_named_vars + named_value_vars ]
186
214
187
215
vars = fgraph .outputs
188
216
new_vars = []
@@ -198,31 +226,31 @@ def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Varia
198
226
new_var = model_potential (var , * dims )
199
227
elif var in deterministics :
200
228
new_var = model_deterministic (var , * dims )
201
- elif var in other_named_vars :
229
+ elif var in named_vars :
202
230
new_var = model_named (var , * dims )
203
231
else :
204
- # Value variables
232
+ # Unnamed value variables
205
233
new_var = var
206
234
new_vars .append (new_var )
207
235
208
236
replacements = tuple (zip (vars , new_vars ))
209
- toposort_replace (fgraph , replacements )
237
+ toposort_replace (fgraph , replacements , reverse = True )
210
238
211
239
# Reference model vars in memo
212
240
inverse_memo = {v : k for k , v in memo .items ()}
213
241
for var , model_var in replacements :
214
- if isinstance (
215
- model_var .owner is not None and model_var .owner .op , (ModelDeterministic , ModelNamed )
242
+ if not inlined_views and (
243
+ model_var .owner and isinstance ( model_var .owner .op , (ModelDeterministic , ModelNamed ) )
216
244
):
217
245
# Ignore extra identity that will be removed at the end
218
246
var = var .owner .inputs [0 ]
219
247
original_var = inverse_memo [var ]
220
248
memo [original_var ] = model_var
221
249
222
- # Remove value variable as outputs, now that they are graph inputs
223
- first_value_idx = len (fgraph .outputs ) - len (value_vars )
224
- for _ in value_vars :
225
- fgraph .remove_output (first_value_idx )
250
+ # Remove the last outputs corresponding to unnamed value variables , now that they are graph inputs
251
+ first_idx_to_remove = len (fgraph .outputs ) - len (unnamed_value_vars )
252
+ for _ in unnamed_value_vars :
253
+ fgraph .remove_output (first_idx_to_remove )
226
254
227
255
# Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
228
256
remove_identity_rewrite .apply (fgraph )
0 commit comments