@@ -395,44 +395,38 @@ def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
395
395
def finite_discrete_marginal_rv_logp (op , values , * inputs , ** kwargs ):
396
396
# Clone the inner RV graph of the Marginalized RV
397
397
marginalized_rvs_node = op .make_node (* inputs )
398
- marginalized_rv , * dependent_rvs = clone_replace (
398
+ inner_rvs = clone_replace (
399
399
op .inner_outputs ,
400
400
replace = {u : v for u , v in zip (op .inner_inputs , marginalized_rvs_node .inputs )},
401
401
)
402
+ marginalized_rv = inner_rvs [0 ]
402
403
403
404
# Obtain the joint_logp graph of the inner RV graph
404
- # Some inputs are not root inputs (such as transformed projections of value variables)
405
- # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
406
- inputs = list (inputvars (inputs ))
407
- rvs_to_values = {}
408
- dummy_marginalized_value = marginalized_rv .clone ()
409
- rvs_to_values [marginalized_rv ] = dummy_marginalized_value
410
- rvs_to_values .update (zip (dependent_rvs , values ))
411
- logps_dict = factorized_joint_logprob (rv_values = rvs_to_values , ** kwargs )
405
+ inner_rvs_to_values = {rv : rv .clone () for rv in inner_rvs }
406
+ logps_dict = factorized_joint_logprob (rv_values = inner_rvs_to_values , ** kwargs )
412
407
413
408
# Reduce logp dimensions corresponding to broadcasted variables
414
- values_axis_bcast = []
415
- for value in values :
416
- vbcast = value .type .broadcastable
417
- mbcast = dummy_marginalized_value .type .broadcastable
409
+ joint_logp = logps_dict [inner_rvs_to_values [marginalized_rv ]]
410
+ for inner_rv , inner_value in inner_rvs_to_values .items ():
411
+ if inner_rv is marginalized_rv :
412
+ continue
413
+ vbcast = inner_value .type .broadcastable
414
+ mbcast = marginalized_rv .type .broadcastable
418
415
mbcast = (True ,) * (len (vbcast ) - len (mbcast )) + mbcast
419
- values_axis_bcast .append ([i for i , (m , v ) in enumerate (zip (mbcast , vbcast )) if m != v ])
420
- joint_logp = logps_dict [dummy_marginalized_value ]
421
- for value , values_axis_bcast in zip (values , values_axis_bcast ):
422
- joint_logp += logps_dict [value ].sum (values_axis_bcast , keepdims = True )
416
+ values_axis_bcast = [i for i , (m , v ) in enumerate (zip (mbcast , vbcast )) if m != v ]
417
+ joint_logp += logps_dict [inner_value ].sum (values_axis_bcast , keepdims = True )
423
418
424
419
# Wrap the joint_logp graph in an OpFromGrah, so that we can evaluate it at different
425
420
# values of the marginalized RV
426
- # OpFromGraph does not accept constant inputs
427
- non_const_values = [
428
- value
429
- for value in rvs_to_values .values ()
430
- if not isinstance (value , (Constant , SharedVariable ))
431
- ]
432
- joint_logp_op = OpFromGraph ([* non_const_values , * inputs ], [joint_logp ], inline = True )
421
+ # Some inputs are not root inputs (such as transformed projections of value variables)
422
+ # Or cannot be used as inputs to an OpFromGraph (shared variables and constants)
423
+ inputs = list (inputvars (inputs ))
424
+ joint_logp_op = OpFromGraph (
425
+ list (inner_rvs_to_values .values ()) + inputs , [joint_logp ], inline = True
426
+ )
433
427
434
428
# Compute the joint_logp for all possible n values of the marginalized RV. We assume
435
- # each original dimension is independent so that it sufficies to evaluate the graph
429
+ # each original dimension is independent so that it suffices to evaluate the graph
436
430
# n times, once with each possible value of the marginalized RV replicated across
437
431
# batched dimensions of the marginalized RV
438
432
@@ -449,18 +443,14 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
449
443
axis2 = - 1 ,
450
444
)
451
445
452
- # OpFromGraph does not accept constant inputs
453
- non_const_values = [
454
- value for value in values if not isinstance (value , (Constant , SharedVariable ))
455
- ]
456
446
# Arbitrary cutoff to switch to Scan implementation to keep graph size under control
457
447
if len (marginalized_rv_domain ) <= 10 :
458
448
joint_logps = [
459
- joint_logp_op (marginalized_rv_domain_tensor [i ], * non_const_values , * inputs )
449
+ joint_logp_op (marginalized_rv_domain_tensor [i ], * values , * inputs )
460
450
for i in range (len (marginalized_rv_domain ))
461
451
]
462
452
else :
463
- # Make sure this is rewrite is registered
453
+ # Make sure this rewrite is registered
464
454
from pymc .pytensorf import local_remove_check_parameter
465
455
466
456
def logp_fn (marginalized_rv_const , * non_sequences ):
@@ -469,7 +459,7 @@ def logp_fn(marginalized_rv_const, *non_sequences):
469
459
joint_logps , _ = scan_map (
470
460
fn = logp_fn ,
471
461
sequences = marginalized_rv_domain_tensor ,
472
- non_sequences = [* non_const_values , * inputs ],
462
+ non_sequences = [* values , * inputs ],
473
463
mode = Mode ().including ("local_remove_check_parameter" ),
474
464
)
475
465
0 commit comments