6
6
from pymc import SymbolicRandomVariable
7
7
from pymc .distributions .discrete import Bernoulli , Categorical , DiscreteUniform
8
8
from pymc .distributions .transforms import Chain
9
- from pymc .logprob .abstract import _get_measurable_outputs , _logprob
10
- from pymc .logprob .basic import factorized_joint_logprob
9
+ from pymc .logprob .abstract import _logprob
10
+ from pymc .logprob .basic import conditional_logp
11
11
from pymc .logprob .transforms import IntervalTransform
12
12
from pymc .model import Model
13
13
from pymc .pytensorf import constant_fold , inputvars
@@ -371,12 +371,6 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
371
371
return rvs_to_marginalize , marginalized_rvs
372
372
373
373
374
- @_get_measurable_outputs .register (FiniteDiscreteMarginalRV )
375
- def _get_measurable_outputs_finite_discrete_marginal_rv (op , node ):
376
- # Marginalized RVs are not measurable
377
- return node .outputs [1 :]
378
-
379
-
380
374
def get_domain_of_finite_discrete_rv (rv : TensorVariable ) -> Tuple [int , ...]:
381
375
op = rv .owner .op
382
376
if isinstance (op , Bernoulli ):
@@ -403,7 +397,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
403
397
404
398
# Obtain the joint_logp graph of the inner RV graph
405
399
inner_rvs_to_values = {rv : rv .clone () for rv in inner_rvs }
406
- logps_dict = factorized_joint_logprob (rv_values = inner_rvs_to_values , ** kwargs )
400
+ logps_dict = conditional_logp (rv_values = inner_rvs_to_values , ** kwargs )
407
401
408
402
# Reduce logp dimensions corresponding to broadcasted variables
409
403
joint_logp = logps_dict [inner_rvs_to_values [marginalized_rv ]]
0 commit comments