Skip to content

Commit c434469

Browse files
committed
Replace deprecated tag.ignore_logprob
1 parent a94221e commit c434469

File tree

6 files changed

+72
-15
lines changed

6 files changed

+72
-15
lines changed

pymc/distributions/bound.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from pymc.distributions.continuous import BoundedContinuous, bounded_cont_transform
2424
from pymc.distributions.dist_math import check_parameters
2525
from pymc.distributions.distribution import Continuous, Discrete
26-
from pymc.distributions.logprob import logp
26+
from pymc.distributions.logprob import ignore_logprob, logp
2727
from pymc.distributions.shape_utils import to_tuple
2828
from pymc.distributions.transforms import _default_transform
2929
from pymc.model import modelcontext
@@ -193,7 +193,7 @@ def __new__(
193193
raise ValueError("Given dims do not exist in model coordinates.")
194194

195195
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
196-
dist.tag.ignore_logprob = True
196+
dist = ignore_logprob(dist)
197197

198198
if isinstance(dist.owner.op, Continuous):
199199
res = _ContinuousBounded(
@@ -228,7 +228,7 @@ def dist(
228228

229229
cls._argument_checks(dist, **kwargs)
230230
lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
231-
dist.tag.ignore_logprob = True
231+
dist = ignore_logprob(dist)
232232
if isinstance(dist.owner.op, Continuous):
233233
res = _ContinuousBounded.dist(
234234
[dist, lower, upper],

pymc/distributions/logprob.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import numpy as np
2121

2222
from aeppl import factorized_joint_logprob
23+
from aeppl.abstract import assign_custom_measurable_outputs
2324
from aeppl.logprob import logcdf as logcdf_aeppl
2425
from aeppl.logprob import logprob as logp_aeppl
2526
from aeppl.transforms import TransformValuesOpt
@@ -221,7 +222,11 @@ def joint_logpt(
221222

222223
transform_opt = TransformValuesOpt(transform_map)
223224
temp_logp_var_dict = factorized_joint_logprob(
224-
tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
225+
tmp_rvs_to_values,
226+
extra_rewrites=transform_opt,
227+
use_jacobian=jacobian,
228+
warn_missing_rvs=False,
229+
**kwargs,
225230
)
226231

227232
# Raise if there are unexpected RandomVariables in the logp graph
@@ -276,3 +281,20 @@ def logcdf(rv, value):
276281

277282
value = at.as_tensor_variable(value, dtype=rv.dtype)
278283
return logcdf_aeppl(rv, value)
284+
285+
286+
def ignore_logprob(rv):
287+
"""Return a duplicated variable that is ignored when creating Aeppl logprob graphs
288+
289+
This is used in SymbolicDistributions that use other RVs as inputs but account
290+
for their logp terms explicitly.
291+
292+
If the variable is already ignored, it is returned directly.
293+
"""
294+
prefix = "Unmeasurable"
295+
node = rv.owner
296+
op_type = type(node.op)
297+
if op_type.__name__.startswith(prefix):
298+
return rv
299+
new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
300+
return new_node.outputs[node.outputs.index(rv)]

pymc/distributions/mixture.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pymc.distributions.continuous import Normal, get_tau_sigma
3131
from pymc.distributions.dist_math import check_parameters
3232
from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
33-
from pymc.distributions.logprob import logcdf, logp
33+
from pymc.distributions.logprob import ignore_logprob, logcdf, logp
3434
from pymc.distributions.shape_utils import to_tuple
3535
from pymc.distributions.transforms import _default_transform
3636
from pymc.util import check_dist_not_registered
@@ -252,6 +252,10 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
252252

253253
assert weights_ndim_batch == 0
254254

255+
# Component RVs terms are accounted by the Mixture logprob, so they can be
256+
# safely ignored by Aeppl
257+
components = [ignore_logprob(component) for component in components]
258+
255259
# Create a OpFromGraph that encapsulates the random generating process
256260
# Create dummy input variables with the same type as the ones provided
257261
weights_ = weights.type()
@@ -299,11 +303,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
299303
mix_out.tag.components = components
300304
mix_out.tag.choices_rng = mix_indexes_rng
301305

302-
# Component RVs terms are accounted by the Mixture logprob, so they can be
303-
# safely ignore by Aeppl (this tag prevents UserWarning)
304-
for component in components:
305-
component.tag.ignore_logprob = True
306-
307306
return mix_out
308307

309308
@classmethod

pymc/distributions/multivariate.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
multigammaln,
5858
)
5959
from pymc.distributions.distribution import Continuous, Discrete, moment
60+
from pymc.distributions.logprob import ignore_logprob
6061
from pymc.distributions.shape_utils import (
6162
broadcast_dist_samples_to,
6263
rv_size_is_none,
@@ -1182,11 +1183,9 @@ def dist(cls, eta, n, sd_dist, **kwargs):
11821183

11831184
# sd_dist is part of the generative graph, but should be completely ignored
11841185
# by the logp graph, since the LKJ logp explicitly includes these terms.
1185-
# Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about
1186-
# an unnacounted RandomVariable in the graph
11871186
# TODO: Things could be simplified a bit if we managed to extract the
11881187
# sd_dist prior components from the logp expression.
1189-
sd_dist.tag.ignore_logprob = True
1188+
sd_dist = ignore_logprob(sd_dist)
11901189

11911190
return super().dist([n, eta, sd_dist], **kwargs)
11921191

pymc/distributions/timeseries.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from pymc.distributions import distribution, logprob, multivariate
2727
from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
2828
from pymc.distributions.dist_math import check_parameters
29+
from pymc.distributions.logprob import ignore_logprob
2930
from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
3031
from pymc.util import check_dist_not_registered
3132

@@ -206,7 +207,7 @@ def dist(
206207
raise TypeError("init must be a univariate distribution variable")
207208

208209
# Ignores logprob of init var because that's accounted for in the logp method
209-
init.tag.ignore_logprob = True
210+
init = ignore_logprob(init)
210211

211212
return super().dist([mu, sigma, init, steps], size=size, **kwargs)
212213

pymc/tests/test_logprob.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818
import scipy.stats.distributions as sp
1919

20+
from aeppl.abstract import get_measurable_outputs
2021
from aesara.graph.basic import ancestors
2122
from aesara.tensor.random.op import RandomVariable
2223
from aesara.tensor.subtensor import (
@@ -32,7 +33,7 @@
3233
from pymc.aesaraf import floatX, walk_model
3334
from pymc.distributions.continuous import HalfFlat, Normal, TruncatedNormal, Uniform
3435
from pymc.distributions.discrete import Bernoulli
35-
from pymc.distributions.logprob import joint_logpt, logcdf, logp
36+
from pymc.distributions.logprob import ignore_logprob, joint_logpt, logcdf, logp
3637
from pymc.model import Model, Potential
3738
from pymc.tests.helpers import select_by_precision
3839

@@ -227,3 +228,38 @@ def test_unexpected_rvs():
227228

228229
with pytest.raises(ValueError, match="^Random variables detected in the logp graph"):
229230
model.logpt()
231+
232+
233+
def test_ignore_logprob_basic():
234+
x = Normal.dist()
235+
(measurable_x_out,) = get_measurable_outputs(x.owner.op, x.owner)
236+
assert measurable_x_out is x.owner.outputs[1]
237+
238+
new_x = ignore_logprob(x)
239+
assert new_x is not x
240+
assert isinstance(new_x.owner.op, Normal)
241+
assert type(new_x.owner.op).__name__ == "UnmeasurableNormalRV"
242+
# Confirm that it does not have measurable output
243+
assert get_measurable_outputs(new_x.owner.op, new_x.owner) is None
244+
245+
# Test that it will not clone a variable that is already unmeasurable
246+
new_new_x = ignore_logprob(new_x)
247+
assert new_new_x is new_x
248+
249+
250+
def test_ignore_logprob_model():
251+
# logp that does not depend on input
252+
def logp(value, x):
253+
return value
254+
255+
with Model() as m:
256+
x = Normal.dist()
257+
y = DensityDist("y", x, logp=logp)
258+
# Aeppl raises a KeyError when it finds an unexpected RV
259+
with pytest.raises(KeyError):
260+
joint_logpt([y], {y: y.type()})
261+
262+
with Model() as m:
263+
x = ignore_logprob(Normal.dist())
264+
y = DensityDist("y", x, logp=logp)
265+
assert joint_logpt([y], {y: y.type()})

0 commit comments

Comments
 (0)