Skip to content

Commit 84c38b2

Browse files
Update to pymc 5.5.0 (#191)
* removing not existing dependency * removing get_measurable_outputs import * bumping pymc version to 5.5.0 * changing factorized_joint_logp to conditional_logp
1 parent baa9363 commit 84c38b2

File tree

3 files changed

+4
-14
lines changed

3 files changed

+4
-14
lines changed

pymc_experimental/distributions/timeseries.py

-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
)
2020
from pymc.logprob.abstract import _logprob
2121
from pymc.logprob.basic import logp
22-
from pymc.logprob.utils import ignore_logprob
2322
from pymc.pytensorf import intX
2423
from pymc.util import check_dist_not_registered
2524
from pytensor.graph.basic import Node
@@ -166,9 +165,6 @@ def dist(cls, P=None, logit_P=None, steps=None, init_dist=None, n_lags=1, **kwar
166165
k = P.shape[-1]
167166
init_dist = pm.Categorical.dist(p=pt.full((k,), 1 / k))
168167

169-
# We can ignore init_dist, as it will be accounted for in the logp term
170-
init_dist = ignore_logprob(init_dist)
171-
172168
return super().dist([P, steps, init_dist], n_lags=n_lags, **kwargs)
173169

174170
@classmethod

pymc_experimental/marginal_model.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
from pymc import SymbolicRandomVariable
77
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
88
from pymc.distributions.transforms import Chain
9-
from pymc.logprob.abstract import _get_measurable_outputs, _logprob
10-
from pymc.logprob.basic import factorized_joint_logprob
9+
from pymc.logprob.abstract import _logprob
10+
from pymc.logprob.basic import conditional_logp
1111
from pymc.logprob.transforms import IntervalTransform
1212
from pymc.model import Model
1313
from pymc.pytensorf import constant_fold, inputvars
@@ -371,12 +371,6 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
371371
return rvs_to_marginalize, marginalized_rvs
372372

373373

374-
@_get_measurable_outputs.register(FiniteDiscreteMarginalRV)
375-
def _get_measurable_outputs_finite_discrete_marginal_rv(op, node):
376-
# Marginalized RVs are not measurable
377-
return node.outputs[1:]
378-
379-
380374
def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> Tuple[int, ...]:
381375
op = rv.owner.op
382376
if isinstance(op, Bernoulli):
@@ -403,7 +397,7 @@ def finite_discrete_marginal_rv_logp(op, values, *inputs, **kwargs):
403397

404398
# Obtain the joint_logp graph of the inner RV graph
405399
inner_rvs_to_values = {rv: rv.clone() for rv in inner_rvs}
406-
logps_dict = factorized_joint_logprob(rv_values=inner_rvs_to_values, **kwargs)
400+
logps_dict = conditional_logp(rv_values=inner_rvs_to_values, **kwargs)
407401

408402
# Reduce logp dimensions corresponding to broadcasted variables
409403
joint_logp = logps_dict[inner_rvs_to_values[marginalized_rv]]

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pymc>=5.4.1
1+
pymc>=5.5.0
22
scikit-learn

0 commit comments

Comments
 (0)