1
1
import warnings
2
2
from typing import Sequence , Tuple , Union
3
3
4
- from arviz import dict_to_dataset
5
4
import numpy as np
6
- import pytensor .tensor as pt
7
5
import pymc
6
+ import pytensor .tensor as pt
7
+ from arviz import dict_to_dataset
8
8
from pymc import SymbolicRandomVariable
9
9
from pymc .backends .arviz import coords_and_dims_for_inferencedata
10
10
from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
13
13
from pymc .logprob .basic import conditional_logp
14
14
from pymc .logprob .transforms import IntervalTransform
15
15
from pymc .model import Model
16
- from pymc .pytensorf import constant_fold , compile_pymc , inputvars
16
+ from pymc .pytensorf import compile_pymc , constant_fold , inputvars
17
17
from pymc .util import dataset_to_point_list , treedict
18
18
from pytensor import Mode
19
19
from pytensor .compile import SharedVariable
@@ -248,7 +248,9 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
248
248
# Raise errors and warnings immediately
249
249
self .clone ()._marginalize (user_warnings = True )
250
250
251
- def unmunmarginalize (self , idata , var_names = None , include_samples = False , extend_inferencedata = True ):
251
+ def unmunmarginalize (
252
+ self , idata , var_names = None , include_samples = False , extend_inferencedata = True
253
+ ):
252
254
"""Computes log-likelihoods of marginalized variables conditioned on parameters
253
255
of the model given InferenceData with posterior group
254
256
@@ -276,7 +278,9 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
276
278
posterior = idata .posterior
277
279
278
280
# Remove Deterministics
279
- posterior_values = posterior [[rv .name for rv in mm .free_RVs if rv not in self .marginalized_rvs ]]
281
+ posterior_values = posterior [
282
+ [rv .name for rv in mm .free_RVs if rv not in self .marginalized_rvs ]
283
+ ]
280
284
281
285
sample_dims = ("chain" , "draw" )
282
286
posterior_pts , stacked_dims = dataset_to_point_list (posterior_values , sample_dims )
@@ -306,10 +310,11 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
306
310
307
311
# TODO: Handle constants
308
312
# TODO: Handle transformed variables
309
- joint_logp_op = OpFromGraph ([marginalized_value ] + other_values , [joint_logp ], inline = True )
313
+ joint_logp_op = OpFromGraph (
314
+ [marginalized_value ] + other_values , [joint_logp ], inline = True
315
+ )
310
316
joint_logps = [
311
- joint_logp_op (rv_domain_tensor [i ], * other_values )
312
- for i in range (len (rv_domain ))
317
+ joint_logp_op (rv_domain_tensor [i ], * other_values ) for i in range (len (rv_domain ))
313
318
]
314
319
315
320
rv_loglike_fn = None
@@ -333,20 +338,17 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
333
338
logps , samples = zip (* logvs )
334
339
logps = np .array (logps )
335
340
rv_dict [rv .name ] = np .reshape (
336
- samples ,
337
- tuple (len (coord ) for coord in stacked_dims .values ())
341
+ samples , tuple (len (coord ) for coord in stacked_dims .values ())
338
342
)
339
343
rv_dims_dict [rv .name ] = sample_dims
340
344
rv_dict ["lp_" + rv .name ] = np .reshape (
341
- logps ,
342
- tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :]
345
+ logps , tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :]
343
346
)
344
347
rv_dims_dict ["lp_" + rv .name ] = sample_dims + ("lp_" + rv .name + "_dims" ,)
345
348
else :
346
349
logps = np .array (logvs )
347
350
rv_dict ["lp_" + rv .name ] = np .reshape (
348
- logps ,
349
- tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :]
351
+ logps , tuple (len (coord ) for coord in stacked_dims .values ()) + logps .shape [1 :]
350
352
)
351
353
rv_dims_dict ["lp_" + rv .name ] = sample_dims + ("lp_" + rv .name + "_dims" ,)
352
354
@@ -361,7 +363,7 @@ def unmunmarginalize(self, idata, var_names=None, include_samples=False, extend_
361
363
)
362
364
363
365
if extend_inferencedata :
364
- rv_dict = {k :(rv_dims_dict [k ], v ) for (k ,v ) in rv_dict .items ()}
366
+ rv_dict = {k : (rv_dims_dict [k ], v ) for (k , v ) in rv_dict .items ()}
365
367
idata = idata .posterior .assign (** rv_dict )
366
368
return idata
367
369
else :
0 commit comments