diff --git a/pymc_experimental/distributions/timeseries.py b/pymc_experimental/distributions/timeseries.py
index ddd5ff16..b591f7b0 100644
--- a/pymc_experimental/distributions/timeseries.py
+++ b/pymc_experimental/distributions/timeseries.py
@@ -19,7 +19,6 @@
 )
 from pymc.logprob.abstract import _logprob
 from pymc.logprob.basic import logp
-from pymc.logprob.utils import ignore_logprob
 from pymc.pytensorf import intX
 from pymc.util import check_dist_not_registered
 from pytensor.graph.basic import Node
@@ -166,9 +165,6 @@ def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwar
             k = P.shape[-1]
             init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
 
-        # We can ignore init_dist, as it will be accounted for in the logp term
-        init_dist = ignore_logprob(init_dist)
-
         return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)
 
     @classmethod
diff --git a/pymc_experimental/marginal_model.py b/pymc_experimental/marginal_model.py
index ee74ae00..5596c4a0 100644
--- a/pymc_experimental/marginal_model.py
+++ b/pymc_experimental/marginal_model.py
@@ -6,8 +6,8 @@
 from pymc import SymbolicRandomVariable
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
-from pymc.logprob.abstract import _get_measurable_outputs, _logprob
-from pymc.logprob.basic import factorized_joint_logprob
+from pymc.logprob.abstract import _logprob
+from pymc.logprob.basic import conditional_logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
 from pymc.pytensorf import constant_fold, inputvars
@@ -371,12 +371,6 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
     return rvs_to_marginalize, marginalized_rvs
 
 
-@_get_measurable_outputs.register(FiniteDiscreteMarginalRV)
-def _get_measurable_outputs_finite_discrete_marginal_rv(op, node):
-    # Marginalized RVs are not measurable
-    return node.outputs[1:]
-
-
 def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
     op = rv.owner.op
     if isinstance(op, Bernoulli):
@@ -403,7 +397,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
 
     # Obtain the joint_logp graph of the inner RV graph
     inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
-    logps_dict = factorized_joint_logprob(rv_values=inner_rvs_to_values, **kwargs)
+    logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
 
     # Reduce logp dimensions corresponding to broadcasted variables
     joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]
diff --git a/requirements.txt b/requirements.txt
index f00e5caa..4b306c2f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-pymc>=5.4.1
+pymc>=5.5.0
 scikit-learn