     vectorize_graph,
 )
 from pytensor.scan import map as scan_map
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.special import log_softmax
@@ -381,41 +381,36 @@ def transform_input(inputs):
 
         rv_dict = {}
         rv_dims = {}
-        for seed, rv in zip(seeds, vars_to_recover):
+        for seed, marginalized_rv in zip(seeds, vars_to_recover):
             supported_dists = (Bernoulli, Categorical, DiscreteUniform)
-            if not isinstance(rv.owner.op, supported_dists):
+            if not isinstance(marginalized_rv.owner.op, supported_dists):
                 raise NotImplementedError(
-                    f"RV with distribution {rv.owner.op} cannot be recovered. "
+                    f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
                     f"Supported distribution include {supported_dists}"
                 )
 
             m = self.clone()
-            rv = m.vars_to_clone[rv]
-            m.unmarginalize([rv])
-            dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
-            joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
+            marginalized_rv = m.vars_to_clone[marginalized_rv]
+            m.unmarginalize([marginalized_rv])
+            dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
+            joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False)
 
-            marginalized_value = m.rvs_to_values[rv]
+            marginalized_value = m.rvs_to_values[marginalized_rv]
             other_values = [v for v in m.value_vars if v is not marginalized_value]
 
             # Handle batch dims for marginalized value and its dependent RVs
-            joint_logp = joint_logps[-1]
-            for dv in joint_logps[:-1]:
-                dbcast = dv.type.broadcastable
-                mbcast = marginalized_value.type.broadcastable
-                mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
-                values_axis_bcast = [
-                    i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
-                ]
-                joint_logp += dv.sum(values_axis_bcast)
+            marginalized_logp, *dependent_logps = joint_logps
+            joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+                marginalized_rv.type, dependent_logps
+            )
 
-            rv_shape = constant_fold(tuple(rv.shape))
-            rv_domain = get_domain_of_finite_discrete_rv(rv)
+            rv_shape = constant_fold(tuple(marginalized_rv.shape))
+            rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
             rv_domain_tensor = pt.moveaxis(
                 pt.full(
                     (*rv_shape, len(rv_domain)),
                     rv_domain,
-                    dtype=rv.dtype,
+                    dtype=marginalized_rv.dtype,
                 ),
                 -1,
                 0,
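
The pt.full/pt.moveaxis combination above materializes every value the marginalized RV can take, with the domain on the leading axis. A minimal NumPy sketch of the same construction, assuming a marginalized RV of shape (3,) with domain (0, 1):

import numpy as np

rv_shape, rv_domain = (3,), (0, 1)
full = np.full((*rv_shape, len(rv_domain)), rv_domain)  # shape (3, 2); each row is [0, 1]
rv_domain_tensor = np.moveaxis(full, -1, 0)  # shape (2, 3)
# rv_domain_tensor[i] is the whole RV filled with the i-th domain value
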
@@ -431,7 +426,7 @@ def transform_input(inputs):
             joint_logps_norm = log_softmax(joint_logps, axis=-1)
             if return_samples:
                 sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
-                if isinstance(rv.owner.op, DiscreteUniform):
+                if isinstance(marginalized_rv.owner.op, DiscreteUniform):
                     sample_rv_outs += rv_domain[0]
 
             rv_loglike_fn = compile_pymc(
@@ -456,18 +451,20 @@ def transform_input(inputs):
                 logps, samples = zip(*logvs)
                 logps = np.array(logps)
                 samples = np.array(samples)
-                rv_dict[rv.name] = samples.reshape(
+                rv_dict[marginalized_rv.name] = samples.reshape(
                     tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
                 )
             else:
                 logps = np.array(logvs)
 
-            rv_dict["lp_" + rv.name] = logps.reshape(
+            rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
                 tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
             )
-            if rv.name in m.named_vars_to_dims:
-                rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
-                rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
+            if marginalized_rv.name in m.named_vars_to_dims:
+                rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
+                rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
+                    "lp_" + marginalized_rv.name + "_dim"
+                ]
 
         coords, dims = coords_and_dims_for_inferencedata(self)
         dims.update(rv_dims)
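
log_softmax normalizes the joint logps across the domain axis, so exponentiating yields the posterior probability of each candidate value of the marginalized RV. A toy example using SciPy's equivalent function (the numbers are made up):

import numpy as np
from scipy.special import log_softmax

joint_logps = np.array([-3.2, -1.1])  # one joint logp per domain value
posterior = np.exp(log_softmax(joint_logps))
print(posterior)  # approx. [0.109, 0.891]; sums to 1
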
@@ -647,6 +644,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
+def _add_reduce_batch_dependent_logps(
+    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
+):
+    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
+
+    mbcast = marginalized_type.broadcastable
+    reduced_logps = []
+    for dependent_logp in dependent_logps:
+        dbcast = dependent_logp.type.broadcastable
+        dim_diff = len(dbcast) - len(mbcast)
+        mbcast_aligned = (True,) * dim_diff + mbcast
+        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
+        reduced_logps.append(dependent_logp.sum(vbcast_axis))
+    return pt.add(*reduced_logps)
+
+
 @_logprob.register(FiniteDiscreteMarginalRV)
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
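
The new helper picks the axes to reduce by right-aligning the broadcastable pattern of the marginalized RV against that of each dependent logp: wherever the marginalized RV broadcasts but the dependent logp does not, that axis is an extra batch dim and gets summed. A rough standalone illustration of the axis selection, assuming shapes (batch,) and (extra, batch):

mbcast = (False,)  # marginalized RV: shape (batch,)
dbcast = (False, False)  # dependent logp: shape (extra, batch)
mbcast_aligned = (True,) * (len(dbcast) - len(mbcast)) + mbcast  # (True, False)
axes = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
print(axes)  # [0] -> dependent_logp.sum(0) collapses the extra leading dim
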
@@ -662,17 +675,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
-    joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
-    for inner_rv, inner_value in inner_rvs_to_values.items():
-        if inner_rv is marginalized_rv:
-            continue
-        vbcast = inner_value.type.broadcastable
-        mbcast = marginalized_rv.type.broadcastable
-        mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
-        values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
-        joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
-
-    # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+        marginalized_rv.type, logps_dict.values()
+    )
+
+    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
     # values of the marginalized RV
     # Some inputs are not root inputs (such as transformed projections of value variables)
     # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
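
OpFromGraph reifies a PyTensor subgraph as a reusable Op, which is what lets the joint_logp be re-applied to each candidate value below. A minimal self-contained sketch (the quadratic graph is a stand-in, not the actual joint_logp):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
logp_op = OpFromGraph([x], [-(x**2) / 2])  # wrap a toy logp graph as an Op
# the same inner graph can now be evaluated at different inputs
print(logp_op(pt.constant(0.0)).eval(), logp_op(pt.constant(1.0)).eval())  # -0.0 -0.5
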
@@ -700,6 +708,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     )
 
     # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
+    # TODO: Try vectorize here
     if len(marginalized_rv_domain) <= 10:
         joint_logps = [
             joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
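
Whichever branch is taken, the per-domain-value joint logps are ultimately combined with a logsumexp over the domain axis to produce the marginal logp. A minimal sketch of that final step, with joint_logp_fn as a hypothetical stand-in for the OpFromGraph:

import pytensor.tensor as pt

def marginal_logp(joint_logp_fn, domain_tensor, n_domain, *args):
    # evaluate the joint logp at each candidate value, then sum
    # the probabilities in log space
    joint_logps = pt.stack([joint_logp_fn(domain_tensor[i], *args) for i in range(n_domain)])
    return pt.logsumexp(joint_logps, axis=0)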