@@ -255,25 +255,6 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
255
255
# Raise errors and warnings immediately
256
256
self .clone ()._marginalize (user_warnings = True )
257
257
258
- def _transform_input (self , inputs ):
259
- "Create a function from the untransformed space to the transformed space"
260
- transformed_rvs = []
261
- transformed_names = []
262
-
263
- for rv in self .free_RVs :
264
- transform = self .rvs_to_transforms .get (rv )
265
- if transform is None :
266
- transformed_rvs .append (rv )
267
- transformed_names .append (rv .name )
268
- else :
269
- transformed_rv = transform .forward (rv , * rv .owner .inputs )
270
- transformed_rvs .append (transformed_rv )
271
- transformed_names .append (self .rvs_to_values [rv ].name )
272
-
273
- fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
274
- rets = fn (inputs )
275
- return dict (zip (transformed_names , rets ))
276
-
277
258
def unmarginalize (self , rvs_to_unmarginalize ):
278
259
for rv in rvs_to_unmarginalize :
279
260
self .marginalized_rvs .remove (rv )
@@ -320,8 +301,15 @@ def recover_marginals(
320
301
rv_dict = {}
321
302
rv_dims_dict = {}
322
303
304
+ # Disable all transforms
305
+ model = self .clone ()
306
+ model .rvs_to_transforms = {k : None for k in model .rvs_to_transforms }
307
+ for rv in model .free_RVs :
308
+ model .rvs_to_values [rv ].name = rv .name
309
+ var_names = [model .vars_to_clone [rv ] for rv in var_names ]
310
+
323
311
for rv in var_names :
324
- m = self .clone ()
312
+ m = model .clone ()
325
313
rv = m .vars_to_clone [rv ]
326
314
m .unmarginalize ([rv ])
327
315
joint_logp = m .logp ()
@@ -343,7 +331,6 @@ def recover_marginals(
343
331
other_values = [v for v in m .value_vars if v is not marginalized_value ]
344
332
345
333
# TODO: Handle constants
346
-
347
334
joint_logps = vectorize_graph (
348
335
joint_logp ,
349
336
replace = {marginalized_value : rv_domain_tensor },
@@ -364,7 +351,7 @@ def recover_marginals(
364
351
on_unused_input = "ignore" ,
365
352
)
366
353
367
- logvs = [rv_loglike_fn (** self . _transform_input ( vs ) ) for vs in posterior_pts ]
354
+ logvs = [rv_loglike_fn (** vs ) for vs in posterior_pts ]
368
355
369
356
if include_samples :
370
357
logps , samples = zip (* logvs )
0 commit comments