@@ -255,6 +255,25 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
255
255
# Raise errors and warnings immediately
256
256
self .clone ()._marginalize (user_warnings = True )
257
257
258
+ def _transform_input (self , inputs ):
259
+ "Create a function from the untransformed space to the transformed space"
260
+ transformed_rvs = []
261
+ transformed_names = []
262
+
263
+ for rv in self .free_RVs :
264
+ transform = self .rvs_to_transforms .get (rv )
265
+ if transform is None :
266
+ transformed_rvs .append (rv )
267
+ transformed_names .append (rv .name )
268
+ else :
269
+ transformed_rv = transform .forward (rv , * rv .owner .inputs )
270
+ transformed_rvs .append (transformed_rv )
271
+ transformed_names .append (self .rvs_to_values [rv ].name )
272
+
273
+ fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
274
+ rets = fn (inputs )
275
+ return dict (zip (transformed_names , rets ))
276
+
258
277
def unmarginalize (self , rvs_to_unmarginalize ):
259
278
for rv in rvs_to_unmarginalize :
260
279
self .marginalized_rvs .remove (rv )
@@ -324,7 +343,6 @@ def recover_marginals(
324
343
other_values = [v for v in m .value_vars if v is not marginalized_value ]
325
344
326
345
# TODO: Handle constants
327
- # TODO: Handle transformed variables
328
346
329
347
joint_logps = vectorize_graph (
330
348
joint_logp ,
@@ -346,7 +364,7 @@ def recover_marginals(
346
364
on_unused_input = "ignore" ,
347
365
)
348
366
349
- logvs = [rv_loglike_fn (** vs ) for vs in posterior_pts ]
367
+ logvs = [rv_loglike_fn (** self . _transform_input ( vs ) ) for vs in posterior_pts ]
350
368
351
369
if include_samples :
352
370
logps , samples = zip (* logvs )
0 commit comments