from typing import Sequence, Tuple, Union

import numpy as np
+ import pymc
import pytensor.tensor as pt
+ from arviz import dict_to_dataset
from pymc import SymbolicRandomVariable
+ from pymc.backends.arviz import coords_and_dims_for_inferencedata
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
- from pymc.pytensorf import constant_fold, inputvars
+ from pymc.pytensorf import compile_pymc, constant_fold, inputvars
+ from pymc.util import _get_seeds_per_chain, dataset_to_point_list, treedict
from pytensor import Mode
from pytensor.compile import SharedVariable
from pytensor.compile.builders import OpFromGraph
- from pytensor.graph import Constant, FunctionGraph, ancestors, clone_replace
+ from pytensor.graph import (
+     Constant,
+     FunctionGraph,
+     ancestors,
+     clone_replace,
+     vectorize_graph,
+ )
from pytensor.scan import map as scan_map
from pytensor.tensor import TensorVariable
from pytensor.tensor.elemwise import Elemwise
+ from pytensor.tensor.shape import Shape
+ from pytensor.tensor.special import log_softmax

__all__ = ["MarginalModel"]

- from pytensor.tensor.shape import Shape
-

class MarginalModel(Model):
    """Subclass of PyMC Model that implements functionality for automatic
@@ -74,6 +84,7 @@ class MarginalModel(Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.marginalized_rvs = []
+         self._marginalized_named_vars_to_dims = treedict()

    def _delete_rv_mappings(self, rv: TensorVariable) -> None:
        """Remove all model mappings referring to rv
@@ -205,8 +216,9 @@ def clone(self):
        vars = self.basic_RVs + self.potentials + self.deterministics + self.marginalized_rvs
        cloned_vars = clone_replace(vars)
        vars_to_clone = {var: cloned_var for var, cloned_var in zip(vars, cloned_vars)}
+         m.vars_to_clone = vars_to_clone

-         m.named_vars = {name: vars_to_clone[var] for name, var in self.named_vars.items()}
+         m.named_vars = treedict({name: vars_to_clone[var] for name, var in self.named_vars.items()})
        m.named_vars_to_dims = self.named_vars_to_dims
        m.values_to_rvs = {i: vars_to_clone[rv] for i, rv in self.values_to_rvs.items()}
        m.rvs_to_values = {vars_to_clone[rv]: i for rv, i in self.rvs_to_values.items()}
@@ -220,11 +232,18 @@ def clone(self):
        m.deterministics = [vars_to_clone[det] for det in self.deterministics]

        m.marginalized_rvs = [vars_to_clone[rv] for rv in self.marginalized_rvs]
+         m._marginalized_named_vars_to_dims = self._marginalized_named_vars_to_dims
        return m

-     def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorVariable]]):
+     def marginalize(
+         self,
+         rvs_to_marginalize: Union[TensorVariable, str, Sequence[TensorVariable], Sequence[str]],
+     ):
        if not isinstance(rvs_to_marginalize, Sequence):
            rvs_to_marginalize = (rvs_to_marginalize,)
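+         # Accept variable names as well as variables; names are looked up in the model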
+         rvs_to_marginalize = [
+             self[var] if isinstance(var, str) else var for var in rvs_to_marginalize
+         ]

        supported_dists = (Bernoulli, Categorical, DiscreteUniform)
        for rv_to_marginalize in rvs_to_marginalize:
@@ -238,12 +257,233 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
                    f"Supported distributions include {supported_dists}"
                )

+             if rv_to_marginalize.name in self.named_vars_to_dims:
+                 dims = self.named_vars_to_dims[rv_to_marginalize.name]
+                 self._marginalized_named_vars_to_dims[rv_to_marginalize.name] = dims
+
            self._delete_rv_mappings(rv_to_marginalize)
            self.marginalized_rvs.append(rv_to_marginalize)

        # Raise errors and warnings immediately
        self.clone()._marginalize(user_warnings=True)

+     def _to_transformed(self):
+         "Create a function from the untransformed space to the transformed space"
+         transformed_rvs = []
+         transformed_names = []
+
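+         # Apply each value transform's forward mapping; RVs without a transform pass through unchanged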
+         for rv in self.free_RVs:
+             transform = self.rvs_to_transforms.get(rv)
+             if transform is None:
+                 transformed_rvs.append(rv)
+                 transformed_names.append(rv.name)
+             else:
+                 transformed_rv = transform.forward(rv, *rv.owner.inputs)
+                 transformed_rvs.append(transformed_rv)
+                 transformed_names.append(self.rvs_to_values[rv].name)
+
+         fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
+         return fn, transformed_names
+
+     def unmarginalize(self, rvs_to_unmarginalize):
+         for rv in rvs_to_unmarginalize:
+             self.marginalized_rvs.remove(rv)
+             if rv.name in self._marginalized_named_vars_to_dims:
+                 dims = self._marginalized_named_vars_to_dims.pop(rv.name)
+             else:
+                 dims = None
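+             # Re-register the RV as a regular free RV, restoring its original dims if any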
+             self.register_rv(rv, name=rv.name, dims=dims)
+
+     def recover_marginals(
+         self,
+         idata,
+         var_names=None,
+         return_samples=True,
+         extend_inferencedata=True,
+         random_seed=None,
+     ):
+ """Computes posterior log-probabilities and samples of marginalized variables
306
+ conditioned on parameters of the model given InferenceData with posterior group
307
+
308
+ When there are multiple marginalized variables, each marginalized variable is
309
+ conditioned on both the parameters and the other variables still marginalized
310
+
311
+ All log-probabilities are within the transformed space
312
+
313
+ Parameters
314
+ ----------
315
+ idata : InferenceData
316
+ InferenceData with posterior group
317
+ var_names : sequence of str, optional
318
+ List of variable names for which to compute posterior log-probabilities and samples. Defaults to all marginalized variables
319
+ return_samples : bool, default True
320
+ If True, also return samples of the marginalized variables
321
+ extend_inferencedata : bool, default True
322
+ Whether to extend the original InferenceData or return a new one
323
+ random_seed: int, array-like of int or SeedSequence, optional
324
+ Seed used to generating samples
325
+
326
+ Returns
327
+ -------
328
+ idata : InferenceData
329
+ InferenceData with where a lp_{varname} and {varname} for each marginalized variable in var_names added to the posterior group
330
+
331
+ .. code-block:: python
332
+
333
+ import pymc as pm
334
+ from pymc_experimental import MarginalModel
335
+
336
+ with MarginalModel() as m:
337
+ p = pm.Beta("p", 1, 1)
338
+ x = pm.Bernoulli("x", p=p, shape=(3,))
339
+ y = pm.Normal("y", pm.math.switch(x, -10, 10), observed=[10, 10, -10])
340
+
341
+ m.marginalize([x])
342
+
343
+ idata = pm.sample()
344
+ m.recover_marginals(idata, var_names=["x"])
345
+
346
+
347
+ """
+         if var_names is None:
+             var_names = [var.name for var in self.marginalized_rvs]
+
+         var_names = [var if isinstance(var, str) else var.name for var in var_names]
+         vars_to_recover = [v for v in self.marginalized_rvs if v.name in var_names]
+         missing_names = [name for name in var_names if name not in [v.name for v in vars_to_recover]]
+         if missing_names:
+             raise ValueError(f"Unrecognized var_names: {missing_names}")
+
+         if return_samples and random_seed is not None:
+             seeds = _get_seeds_per_chain(random_seed, len(vars_to_recover))
+         else:
+             seeds = [None] * len(vars_to_recover)
+
+         posterior = idata.posterior
+
+         # Remove Deterministics
+         posterior_values = posterior[
+             [rv.name for rv in self.free_RVs if rv not in self.marginalized_rvs]
+         ]
+
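+         # Flatten the posterior into a list of point dicts, one per (chain, draw) pair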
+         sample_dims = ("chain", "draw")
+         posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
+
+         # Handle Transforms
+         transform_fn, transform_names = self._to_transformed()
+
+         def transform_input(inputs):
+             return dict(zip(transform_names, transform_fn(inputs)))
+
+         posterior_pts = [transform_input(vs) for vs in posterior_pts]
+
+         rv_dict = {}
+         rv_dims = {}
+         for seed, rv in zip(seeds, vars_to_recover):
+             supported_dists = (Bernoulli, Categorical, DiscreteUniform)
+             if not isinstance(rv.owner.op, supported_dists):
+                 raise NotImplementedError(
+                     f"RV with distribution {rv.owner.op} cannot be recovered. "
+                     f"Supported distributions include {supported_dists}"
+                 )
+
+             m = self.clone()
+             rv = m.vars_to_clone[rv]
+             m.unmarginalize([rv])
+             dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
+             joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
+
+             marginalized_value = m.rvs_to_values[rv]
+             other_values = [v for v in m.value_vars if v is not marginalized_value]
+
+             # Handle batch dims for marginalized value and its dependent RVs
+             joint_logp = joint_logps[-1]
+             for dv in joint_logps[:-1]:
+                 dbcast = dv.type.broadcastable
+                 mbcast = marginalized_value.type.broadcastable
+                 mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
+                 values_axis_bcast = [
+                     i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
+                 ]
+                 joint_logp += dv.sum(values_axis_bcast)
+
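+             # Enumerate the finite domain of the marginalized RV along a new leading axis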
+             rv_shape = constant_fold(tuple(rv.shape))
+             rv_domain = get_domain_of_finite_discrete_rv(rv)
+             rv_domain_tensor = pt.moveaxis(
+                 pt.full(
+                     (*rv_shape, len(rv_domain)),
+                     rv_domain,
+                     dtype=rv.dtype,
+                 ),
+                 -1,
+                 0,
+             )
+
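+             # Evaluate the joint logp at every value in the domain by vectorizing over the leading axis, then move that axis last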
+             joint_logps = vectorize_graph(
+                 joint_logp,
+                 replace={marginalized_value: rv_domain_tensor},
+             )
+             joint_logps = pt.moveaxis(joint_logps, 0, -1)
+
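+             # log_softmax normalizes the joint logps into posterior log-probabilities over the domain;
+             # sampling uses a Categorical parameterized by the unnormalized logits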
+             rv_loglike_fn = None
+             joint_logps_norm = log_softmax(joint_logps, axis=-1)
+             if return_samples:
+                 sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
+                 if isinstance(rv.owner.op, DiscreteUniform):
+                     sample_rv_outs += rv_domain[0]
+
+                 rv_loglike_fn = compile_pymc(
+                     inputs=other_values,
+                     outputs=[joint_logps_norm, sample_rv_outs],
+                     on_unused_input="ignore",
+                     random_seed=seed,
+                 )
+             else:
+                 rv_loglike_fn = compile_pymc(
+                     inputs=other_values,
+                     outputs=joint_logps_norm,
+                     on_unused_input="ignore",
+                     random_seed=seed,
+                 )
+
+             logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
+
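+             # Reshape the flat per-draw results back to (chain, draw, ...) using the stacked sample dims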
+             logps = None
+             samples = None
+             if return_samples:
+                 logps, samples = zip(*logvs)
+                 logps = np.array(logps)
+                 samples = np.array(samples)
+                 rv_dict[rv.name] = samples.reshape(
+                     tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
+                 )
+             else:
+                 logps = np.array(logvs)
+
+             rv_dict["lp_" + rv.name] = logps.reshape(
+                 tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
+             )
+             if rv.name in m.named_vars_to_dims:
+                 rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
+                 rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
+
+         coords, dims = coords_and_dims_for_inferencedata(self)
+         dims.update(rv_dims)
+         rv_dataset = dict_to_dataset(
+             rv_dict,
+             library=pymc,
+             dims=dims,
+             coords=coords,
+             default_dims=list(sample_dims),
+             skip_event_dims=True,
+         )
+
+         if extend_inferencedata:
+             idata.posterior = idata.posterior.assign(rv_dataset)
+             return idata
+         else:
+             return rv_dataset
+

class MarginalRV(SymbolicRandomVariable):
    """Base class for Marginalized RVs"""
@@ -444,14 +684,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
    # PyMC does not allow RVs in the logp graph, even if we are just using the shape
    marginalized_rv_shape = constant_fold(tuple(marginalized_rv.shape))
    marginalized_rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
-     marginalized_rv_domain_tensor = pt.swapaxes(
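+     # moveaxis puts the domain on a new leading axis while preserving the order of the RV's own dims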
+     marginalized_rv_domain_tensor = pt.moveaxis(
        pt.full(
            (*marginalized_rv_shape, len(marginalized_rv_domain)),
            marginalized_rv_domain,
            dtype=marginalized_rv.dtype,
        ),
-         axis1=0,
-         axis2=-1,
+         -1,
+         0,
    )

    # Arbitrary cutoff to switch to Scan implementation to keep graph size under control