@@ -255,6 +255,24 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
255
255
# Raise errors and warnings immediately
256
256
self .clone ()._marginalize (user_warnings = True )
257
257
258
+ def _to_transformed (self ):
259
+ "Create a function from the untransformed space to the transformed space"
260
+ transformed_rvs = []
261
+ transformed_names = []
262
+
263
+ for rv in self .free_RVs :
264
+ transform = self .rvs_to_transforms .get (rv )
265
+ if transform is None :
266
+ transformed_rvs .append (rv )
267
+ transformed_names .append (rv .name )
268
+ else :
269
+ transformed_rv = transform .forward (rv , * rv .owner .inputs )
270
+ transformed_rvs .append (transformed_rv )
271
+ transformed_names .append (self .rvs_to_values [rv ].name )
272
+
273
+ fn = self .compile_fn (inputs = self .free_RVs , outs = transformed_rvs )
274
+ return fn , transformed_names
275
+
258
276
def unmarginalize (self , rvs_to_unmarginalize ):
259
277
for rv in rvs_to_unmarginalize :
260
278
self .marginalized_rvs .remove (rv )
@@ -263,12 +281,14 @@ def unmarginalize(self, rvs_to_unmarginalize):
263
281
def recover_marginals (
264
282
self , idata , var_names = None , include_samples = False , extend_inferencedata = True
265
283
):
266
- """Computes log-likelihoods of marginalized variables conditioned on parameters
267
- of the model given InferenceData with posterior group
284
+ """Computes unnormalized posterior probabilities of marginalized variables
285
+ conditioned on parameters of the model given InferenceData with posterior group
268
286
269
287
When there are multiple marginalized variables, each marginalized variable is
270
288
conditioned on both the parameters and the other variables still marginalized
271
289
290
+ All log-probabilities are within the transformed space
291
+
272
292
Parameters
273
293
----------
274
294
idata : InferenceData
@@ -298,18 +318,20 @@ def recover_marginals(
298
318
299
319
sample_dims = ("chain" , "draw" )
300
320
posterior_pts , stacked_dims = dataset_to_point_list (posterior_values , sample_dims )
321
+
322
+ # Handle Transforms
323
+ transform_fn , transform_names = self ._to_transformed ()
324
+
325
+ def transform_input (inputs ):
326
+ return dict (zip (transform_names , transform_fn (inputs )))
327
+
328
+ posterior_pts = [transform_input (vs ) for vs in posterior_pts ]
329
+
301
330
rv_dict = {}
302
331
rv_dims_dict = {}
303
332
304
- # Disable all transforms
305
- model = self .clone ()
306
- model .rvs_to_transforms = {k : None for k in model .rvs_to_transforms }
307
- for rv in model .free_RVs :
308
- model .rvs_to_values [rv ].name = rv .name
309
- var_names = [model .vars_to_clone [rv ] for rv in var_names ]
310
-
311
333
for rv in var_names :
312
- m = model .clone ()
334
+ m = self .clone ()
313
335
rv = m .vars_to_clone [rv ]
314
336
m .unmarginalize ([rv ])
315
337
joint_logp = m .logp ()
@@ -331,6 +353,7 @@ def recover_marginals(
331
353
other_values = [v for v in m .value_vars if v is not marginalized_value ]
332
354
333
355
# TODO: Handle constants
356
+
334
357
joint_logps = vectorize_graph (
335
358
joint_logp ,
336
359
replace = {marginalized_value : rv_domain_tensor },
0 commit comments