@@ -26,7 +26,7 @@
     vectorize_graph,
 )
 from pytensor.scan import map as scan_map
-from pytensor.tensor import TensorVariable
+from pytensor.tensor import TensorType, TensorVariable
 from pytensor.tensor.elemwise import Elemwise
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.special import log_softmax
@@ -379,41 +379,36 @@ def transform_input(inputs):
 
         rv_dict = {}
         rv_dims = {}
-        for seed, rv in zip(seeds, vars_to_recover):
+        for seed, marginalized_rv in zip(seeds, vars_to_recover):
             supported_dists = (Bernoulli, Categorical, DiscreteUniform)
-            if not isinstance(rv.owner.op, supported_dists):
+            if not isinstance(marginalized_rv.owner.op, supported_dists):
                 raise NotImplementedError(
-                    f"RV with distribution {rv.owner.op} cannot be recovered. "
+                    f"RV with distribution {marginalized_rv.owner.op} cannot be recovered. "
                     f"Supported distribution include {supported_dists}"
                 )
 
             m = self.clone()
-            rv = m.vars_to_clone[rv]
-            m.unmarginalize([rv])
-            dependent_vars = find_conditional_dependent_rvs(rv, m.basic_RVs)
-            joint_logps = m.logp(vars=dependent_vars + [rv], sum=False)
+            marginalized_rv = m.vars_to_clone[marginalized_rv]
+            m.unmarginalize([marginalized_rv])
+            dependent_vars = find_conditional_dependent_rvs(marginalized_rv, m.basic_RVs)
+            joint_logps = m.logp(vars=[marginalized_rv] + dependent_vars, sum=False)
 
-            marginalized_value = m.rvs_to_values[rv]
+            marginalized_value = m.rvs_to_values[marginalized_rv]
             other_values = [v for v in m.value_vars if v is not marginalized_value]
 
             # Handle batch dims for marginalized value and its dependent RVs
-            joint_logp = joint_logps[-1]
-            for dv in joint_logps[:-1]:
-                dbcast = dv.type.broadcastable
-                mbcast = marginalized_value.type.broadcastable
-                mbcast = (True,) * (len(dbcast) - len(mbcast)) + mbcast
-                values_axis_bcast = [
-                    i for i, (m, v) in enumerate(zip(mbcast, dbcast)) if m and not v
-                ]
-                joint_logp += dv.sum(values_axis_bcast)
+            marginalized_logp, *dependent_logps = joint_logps
+            joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+                marginalized_rv.type, dependent_logps
+            )
 
-            rv_shape = constant_fold(tuple(rv.shape))
-            rv_domain = get_domain_of_finite_discrete_rv(rv)
+            rv_shape = constant_fold(tuple(marginalized_rv.shape))
+            rv_domain = get_domain_of_finite_discrete_rv(marginalized_rv)
             rv_domain_tensor = pt.moveaxis(
                 pt.full(
                     (*rv_shape, len(rv_domain)),
                     rv_domain,
-                    dtype=rv.dtype,
+                    dtype=marginalized_rv.dtype,
                 ),
                 -1,
                 0,
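
For context on the `rv_domain_tensor` built above: `pt.full` broadcasts the 1-D domain against the RV's batch shape, and `pt.moveaxis` brings the domain axis to the front, so each leading slice is the whole RV filled with one domain value. A minimal standalone sketch, with hypothetical shapes:

```python
import numpy as np
import pytensor.tensor as pt

rv_shape = (2, 3)         # hypothetical batch shape of the marginalized RV
rv_domain = np.arange(4)  # e.g. a Categorical over 4 values

domain_tensor = pt.moveaxis(
    pt.full((*rv_shape, len(rv_domain)), rv_domain, dtype="int64"),
    -1,
    0,
)
print(domain_tensor.shape.eval())  # [4 2 3]
print(domain_tensor[1].eval())     # every entry equals rv_domain[1] == 1
```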
@@ -429,7 +424,7 @@ def transform_input(inputs):
             joint_logps_norm = log_softmax(joint_logps, axis=-1)
             if return_samples:
                 sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
-                if isinstance(rv.owner.op, DiscreteUniform):
+                if isinstance(marginalized_rv.owner.op, DiscreteUniform):
                     sample_rv_outs += rv_domain[0]
 
             rv_loglike_fn = compile_pymc(
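
A note on the normalization a few lines up: `log_softmax` over the domain axis turns the joint log-probabilities log p(rv = k, dependents) into the normalized posterior log p(rv = k | dependents), while `Categorical.dist(logit_p=...)` accepts the unnormalized values directly. A tiny numeric check with made-up values:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.tensor.special import log_softmax

joint_logps = pt.as_tensor([-3.2, -1.1, -2.4])     # hypothetical log p(rv=k, dependents)
log_posterior = log_softmax(joint_logps, axis=-1)  # log p(rv=k | dependents)
assert np.isclose(np.exp(log_posterior.eval()).sum(), 1.0)
```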
@@ -454,18 +449,20 @@ def transform_input(inputs):
                 logps, samples = zip(*logvs)
                 logps = np.array(logps)
                 samples = np.array(samples)
-                rv_dict[rv.name] = samples.reshape(
+                rv_dict[marginalized_rv.name] = samples.reshape(
                     tuple(len(coord) for coord in stacked_dims.values()) + samples.shape[1:],
                 )
             else:
                 logps = np.array(logvs)
 
-            rv_dict["lp_" + rv.name] = logps.reshape(
+            rv_dict["lp_" + marginalized_rv.name] = logps.reshape(
                 tuple(len(coord) for coord in stacked_dims.values()) + logps.shape[1:],
             )
-            if rv.name in m.named_vars_to_dims:
-                rv_dims[rv.name] = list(m.named_vars_to_dims[rv.name])
-                rv_dims["lp_" + rv.name] = rv_dims[rv.name] + ["lp_" + rv.name + "_dim"]
+            if marginalized_rv.name in m.named_vars_to_dims:
+                rv_dims[marginalized_rv.name] = list(m.named_vars_to_dims[marginalized_rv.name])
+                rv_dims["lp_" + marginalized_rv.name] = rv_dims[marginalized_rv.name] + [
+                    "lp_" + marginalized_rv.name + "_dim"
+                ]
 
         coords, dims = coords_and_dims_for_inferencedata(self)
         dims.update(rv_dims)
@@ -645,6 +642,22 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     raise NotImplementedError(f"Cannot compute domain for op {op}")
 
 
+def _add_reduce_batch_dependent_logps(
+    marginalized_type: TensorType, dependent_logps: Sequence[TensorVariable]
+):
+    """Add the logps of dependent RVs while reducing extra batch dims as assessed from the `marginalized_type`."""
+
+    mbcast = marginalized_type.broadcastable
+    reduced_logps = []
+    for dependent_logp in dependent_logps:
+        dbcast = dependent_logp.type.broadcastable
+        dim_diff = len(dbcast) - len(mbcast)
+        mbcast_aligned = (True,) * dim_diff + mbcast
+        vbcast_axis = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
+        reduced_logps.append(dependent_logp.sum(vbcast_axis))
+    return pt.add(*reduced_logps)
+
+
 @_logprob.register(FiniteDiscreteMarginalRV)
 def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     # Clone the inner RV graph of the Marginalized RV
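
To make the axis logic of the new `_add_reduce_batch_dependent_logps` helper concrete: after left-padding the marginalized pattern with `True`, any axis where the marginalized RV is broadcastable but the dependent logp is not is an extra batch dim and gets summed away. A plain-Python walk-through with hypothetical patterns:

```python
# Marginalized RV with shape (3,), dependent logp with shape (5, 3)
mbcast = (False,)        # broadcastable pattern of the marginalized type
dbcast = (False, False)  # broadcastable pattern of the dependent logp

# Left-pad the marginalized pattern so both have the same rank
mbcast_aligned = (True,) * (len(dbcast) - len(mbcast)) + mbcast  # (True, False)

# Axes present only on the dependent logp are reduced
axes = [i for i, (m, v) in enumerate(zip(mbcast_aligned, dbcast)) if m and not v]
print(axes)  # [0]: dependent_logp.sum(0) leaves shape (3,), matching the marginalized RV
```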
@@ -660,17 +673,12 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
-    joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
-    for inner_rv, inner_value in inner_rvs_to_values.items():
-        if inner_rv is marginalized_rv:
-            continue
-        vbcast = inner_value.type.broadcastable
-        mbcast = marginalized_rv.type.broadcastable
-        mbcast = (True,) * (len(vbcast) - len(mbcast)) + mbcast
-        values_axis_bcast = [i for i, (m, v) in enumerate(zip(mbcast, vbcast)) if m != v]
-        joint_logp += logps_dict[inner_value].sum(values_axis_bcast, keepdims=True)
-
-    # Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
+    marginalized_logp = logps_dict.pop(inner_rvs_to_values[marginalized_rv])
+    joint_logp = marginalized_logp + _add_reduce_batch_dependent_logps(
+        marginalized_rv.type, logps_dict.values()
+    )
+
+    # Wrap the joint_logp graph in an OpFromGraph, so that we can evaluate it at different
     # values of the marginalized RV
     # Some inputs are not root inputs (such as transformed projections of value variables)
     # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
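
For readers unfamiliar with the wrapping described in the comment above: `OpFromGraph` packages a graph with explicit inputs so the same logp graph can be re-applied at each point of the marginalized domain. A minimal sketch of the mechanism, using stand-in graphs rather than the actual joint_logp:

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

x = pt.scalar("x")
y = pt.scalar("y")
logp_like = -((x - y) ** 2)  # stand-in for the joint_logp graph

op = OpFromGraph([x, y], [logp_like])
# Re-apply the wrapped graph at several "domain" values of x
outs = [op(pt.constant(float(k)), y) for k in range(3)]
print(pt.stack(outs).eval({y: 1.0}))  # [-1.  0. -1.]
```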
@@ -698,6 +706,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
     )
 
     # Arbitrary cutoff to switch to Scan implementation to keep graph size under control
+    # TODO: Try vectorize here
     if len(marginalized_rv_domain) <= 10:
         joint_logps = [
             joint_logp_op(marginalized_rv_domain_tensor[i], *values, *inputs)
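
On the new TODO: rather than one `joint_logp_op` call per domain value (or a Scan above the cutoff), `vectorize_graph`, already imported at the top of this module, could batch the logp over the whole domain in a single rewrite. A self-contained sketch of the idea with stand-in graphs, not the actual implementation:

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

value = pt.scalar("value")  # stand-in for the marginalized value variable
logp = -(value**2)          # stand-in for the joint_logp graph

domain = pt.vector("domain")  # all domain values at once
batched_logp = vectorize_graph(logp, replace={value: domain})
print(batched_logp.eval({domain: np.arange(3.0)}))  # [-0. -1. -4.]
```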