Skip to content

Commit 19be124

Browse files
committed
Reintroduce dummy intermediate variables in implementation of TruncatedRV
Partially reverts 9d4a3d7 and 3888d53 The logprob derivation(s) in the icdf implementation of `Truncated` can duplicate nodes and cause spurious input variables to be marked as missing. We replace these by dummies so the graph above is hidden, and variables cannot be accidentally cloned/modified during logprob inference.
1 parent 43c5a8e commit 19be124

File tree

2 files changed

+44
-24
lines changed

2 files changed

+44
-24
lines changed

pymc/distributions/truncated.py

Lines changed: 30 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -106,28 +106,34 @@ def rv_op(cls, dist, lower, upper, max_n_steps, *, size=None):
106106
]
107107
graph_inputs = [*rv_inputs, lower, upper]
108108

109-
rv = dist.owner.op.make_node(*rv_inputs).default_output()
109+
# Variables with `_` suffix identify dummy inputs for the OpFromGraph
110+
graph_inputs_ = [
111+
inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
112+
]
113+
*rv_inputs_, lower_, upper_ = graph_inputs_
114+
115+
rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()
110116

111117
# Try to use inverted cdf sampling
112118
# truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
113119
try:
114-
logcdf_lower, logcdf_upper = cls._create_logcdf_exprs(rv, rv, lower, upper)
120+
logcdf_lower_, logcdf_upper_ = TruncatedRV._create_logcdf_exprs(
121+
rv_, rv_, lower_, upper_
122+
)
115123
# We use the first RNG from the base RV, so we don't have to introduce a new one
116124
# This is not problematic because the RNG won't be used in the RV logcdf graph
117-
uniform_rng = next(inp for inp in rv_inputs if isinstance(inp.type, RandomType))
118-
uniform_next_rng, uniform = pt.random.uniform(
119-
pt.exp(logcdf_lower),
120-
pt.exp(logcdf_upper),
121-
rng=uniform_rng,
122-
size=rv.shape,
125+
uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
126+
uniform_next_rng_, uniform_ = pt.random.uniform(
127+
pt.exp(logcdf_lower_),
128+
pt.exp(logcdf_upper_),
129+
rng=uniform_rng_,
130+
size=rv_.shape,
123131
).owner.outputs
124-
# So icdf does not see the random graph of uniform
125-
uniform_type = uniform.type()
126-
truncated_rv = graph_replace(icdf(rv, uniform_type), {uniform_type: uniform})
132+
truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
127133
return TruncatedRV(
128134
base_rv_op=dist.owner.op,
129-
inputs=graph_inputs,
130-
outputs=[truncated_rv, uniform_next_rng],
135+
inputs=graph_inputs_,
136+
outputs=[truncated_rv_, uniform_next_rng_],
131137
ndim_supp=0,
132138
max_n_steps=max_n_steps,
133139
)(*graph_inputs)
@@ -154,25 +160,25 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
154160

155161
return (
156162
(truncated_rv, reject_draws),
157-
collect_default_updates(new_truncated_rv, inputs=rv_inputs),
163+
collect_default_updates(new_truncated_rv),
158164
until(~pt.any(reject_draws)),
159165
)
160166

161-
(truncated_rv, reject_draws_), updates = scan(
167+
(truncated_rv_, reject_draws_), updates = scan(
162168
loop_fn,
163169
outputs_info=[
164-
pt.zeros_like(rv),
165-
pt.ones_like(rv, dtype=bool),
170+
pt.zeros_like(rv_),
171+
pt.ones_like(rv_, dtype=bool),
166172
],
167-
non_sequences=[lower, upper, *rv_inputs],
173+
non_sequences=[lower_, upper_, *rv_inputs_],
168174
n_steps=max_n_steps,
169175
strict=True,
170176
)
171177

172-
truncated_rv = truncated_rv[-1]
173-
convergence = ~pt.any(reject_draws_[-1])
174-
truncated_rv = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
175-
truncated_rv, convergence
178+
truncated_rv_ = truncated_rv_[-1]
179+
convergence_ = ~pt.any(reject_draws_[-1])
180+
truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
181+
truncated_rv_, convergence_
176182
)
177183

178184
# Sort updates of each RNG so that they show in the same order as the input RNGs
@@ -184,8 +190,8 @@ def sort_updates(update):
184190

185191
return TruncatedRV(
186192
base_rv_op=dist.owner.op,
187-
inputs=graph_inputs,
188-
outputs=[truncated_rv, *next_rngs],
193+
inputs=graph_inputs_,
194+
outputs=[truncated_rv_, *next_rngs],
189195
ndim_supp=0,
190196
max_n_steps=max_n_steps,
191197
)(*graph_inputs)

tests/distributions/test_truncated.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,3 +585,17 @@ def test_truncated_identity_input(dist_op):
585585

586586
rv_out = Truncated.dist(dist=dist_op(mu_identity, 5), lower=0, upper=1)
587587
assert np.ptp(draw(rv_out, draws=500)) < 1
588+
589+
590+
@pytest.mark.parametrize("rv_op", [icdf_normal, rejection_normal])
591+
def test_truncated_custom_dist_indexed_argument(rv_op):
592+
# Regression test for https://github.com/pymc-devs/pymc/issues/7312
593+
594+
def dist(scale, size):
595+
return pt.exp(rv_op(scale=scale, size=size))
596+
597+
scale = Exponential.dist(scale=[1, 2, 3])
598+
latent = CustomDist.dist(scale[[0, 0, 1, 1, 2, 2]], dist=dist)
599+
rv_out = Truncated.dist(latent, upper=7)
600+
601+
assert np.ptp(draw(rv_out, draws=100)) < 7

0 commit comments

Comments
 (0)