11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
- import warnings
15
-
16
14
from typing import (
17
15
Callable ,
18
16
Dict ,
@@ -147,32 +145,6 @@ def dataframe_to_tensor_variable(df: pd.DataFrame, *args, **kwargs) -> TensorVar
147
145
return at .as_tensor_variable (df .to_numpy (), * args , ** kwargs )
148
146
149
147
150
- def extract_rv_and_value_vars (
151
- var : TensorVariable ,
152
- ) -> Tuple [TensorVariable , TensorVariable ]:
153
- """Return a random variable and it's observations or value variable, or ``None``.
154
-
155
- Parameters
156
- ==========
157
- var
158
- A variable corresponding to a ``RandomVariable``.
159
-
160
- Returns
161
- =======
162
- The first value in the tuple is the ``RandomVariable``, and the second is the
163
- measure/log-likelihood value variable that corresponds with the latter.
164
-
165
- """
166
- if not var .owner :
167
- return None , None
168
-
169
- if isinstance (var .owner .op , RandomVariable ):
170
- rv_value = getattr (var .tag , "observations" , getattr (var .tag , "value_var" , None ))
171
- return var , rv_value
172
-
173
- return None , None
174
-
175
-
176
148
def extract_obs_data (x : TensorVariable ) -> np .ndarray :
177
149
"""Extract data from observed symbolic variables.
178
150
@@ -200,20 +172,15 @@ def extract_obs_data(x: TensorVariable) -> np.ndarray:
200
172
201
173
def walk_model (
202
174
graphs : Iterable [TensorVariable ],
203
- walk_past_rvs : bool = False ,
204
175
stop_at_vars : Optional [Set [TensorVariable ]] = None ,
205
176
expand_fn : Callable [[TensorVariable ], Iterable [TensorVariable ]] = lambda var : [],
206
177
) -> Generator [TensorVariable , None , None ]:
207
178
"""Walk model graphs and yield their nodes.
208
179
209
- By default, these walks will not go past ``RandomVariable`` nodes.
210
-
211
180
Parameters
212
181
==========
213
182
graphs
214
183
The graphs to walk.
215
- walk_past_rvs
216
- If ``True``, the walk will not terminate at ``RandomVariable``s.
217
184
stop_at_vars
218
185
A list of variables at which the walk will terminate.
219
186
expand_fn
@@ -225,16 +192,12 @@ def walk_model(
225
192
def expand (var ):
226
193
new_vars = expand_fn (var )
227
194
228
- if (
229
- var .owner
230
- and (walk_past_rvs or not isinstance (var .owner .op , RandomVariable ))
231
- and (var not in stop_at_vars )
232
- ):
195
+ if var .owner and var not in stop_at_vars :
233
196
new_vars .extend (reversed (var .owner .inputs ))
234
197
235
198
return new_vars
236
199
237
- yield from walk (graphs , expand , False )
200
+ yield from walk (graphs , expand , bfs = False )
238
201
239
202
240
203
def replace_rvs_in_graphs (
@@ -263,7 +226,11 @@ def replace_rvs_in_graphs(
263
226
264
227
def expand_replace (var ):
265
228
new_nodes = []
266
- if var .owner and isinstance (var .owner .op , RandomVariable ):
229
+ if var .owner :
230
+ # Call replacement_fn to update replacements dict inplace and, optionally,
231
+ # specify new nodes that should also be walked for replacements. This
232
+ # includes `value` variables that are not simple input variables, and may
233
+ # contain other `random` variables in their graphs (e.g., IntervalTransform)
267
234
new_nodes .extend (replacement_fn (var , replacements ))
268
235
return new_nodes
269
236
@@ -290,10 +257,10 @@ def expand_replace(var):
290
257
291
258
def rvs_to_value_vars (
292
259
graphs : Iterable [TensorVariable ],
293
- apply_transforms : bool = False ,
260
+ apply_transforms : bool = True ,
294
261
initial_replacements : Optional [Dict [TensorVariable , TensorVariable ]] = None ,
295
262
** kwargs ,
296
- ) -> Tuple [ TensorVariable , Dict [ TensorVariable , TensorVariable ]] :
263
+ ) -> TensorVariable :
297
264
"""Clone and replace random variables in graphs with their value variables.
298
265
299
266
This will *not* recompute test values in the resulting graphs.
@@ -309,38 +276,30 @@ def rvs_to_value_vars(
309
276
310
277
"""
311
278
312
- # Avoid circular dependency
313
- from pymc .distributions import NoDistribution
314
-
315
- def transform_replacements (var , replacements ):
316
- rv_var , rv_value_var = extract_rv_and_value_vars (var )
317
-
318
- if rv_value_var is None :
319
- # If RandomVariable does not have a value_var and corresponds to
320
- # a NoDistribution, we allow further replacements in upstream graph
321
- if isinstance (rv_var .owner .op , NoDistribution ):
322
- return rv_var .owner .inputs
279
+ def populate_replacements (
280
+ random_var : TensorVariable , replacements : Dict [TensorVariable , TensorVariable ]
281
+ ) -> List [TensorVariable ]:
282
+ # Populate replacements dict with {rv: value} pairs indicating which graph
283
+ # RVs should be replaced by what value variables.
323
284
324
- else :
325
- warnings .warn (
326
- f"No value variable found for { rv_var } ; "
327
- "the random variable will not be replaced."
328
- )
329
- return []
285
+ value_var = getattr (
286
+ random_var .tag , "observations" , getattr (random_var .tag , "value_var" , None )
287
+ )
330
288
331
- transform = getattr (rv_value_var .tag , "transform" , None )
289
+ # No value variable to replace RV with
290
+ if value_var is None :
291
+ return []
332
292
333
- if transform is None or not apply_transforms :
334
- replacements [var ] = rv_value_var
335
- # In case the value variable is itself a graph, we walk it for
336
- # potential replacements
337
- return [rv_value_var ]
293
+ transform = getattr (value_var .tag , "transform" , None )
294
+ if transform is not None and apply_transforms :
295
+ # We want to replace uses of the RV by the back-transformation of its value
296
+ value_var = transform .backward (value_var , * random_var .owner .inputs )
338
297
339
- trans_rv_value = transform .backward (rv_value_var , * rv_var .owner .inputs )
340
- replacements [var ] = trans_rv_value
298
+ replacements [random_var ] = value_var
341
299
342
- # Walk the transformed variable and make replacements
343
- return [trans_rv_value ]
300
+ # Also walk the graph of the value variable to make any additional replacements
301
+ # if that is not a simple input variable
302
+ return [value_var ]
344
303
345
304
# Clone original graphs
346
305
inputs = [i for i in graph_inputs (graphs ) if not isinstance (i , Constant )]
@@ -352,7 +311,14 @@ def transform_replacements(var, replacements):
352
311
equiv .get (k , k ): equiv .get (v , v ) for k , v in initial_replacements .items ()
353
312
}
354
313
355
- return replace_rvs_in_graphs (graphs , transform_replacements , initial_replacements , ** kwargs )
314
+ graphs , _ = replace_rvs_in_graphs (
315
+ graphs ,
316
+ replacement_fn = populate_replacements ,
317
+ initial_replacements = initial_replacements ,
318
+ ** kwargs ,
319
+ )
320
+
321
+ return graphs
356
322
357
323
358
324
def inputvars (a ):
0 commit comments