Skip to content

Commit f2bb88b

Browse files
committed
Make metropolis elemwise updates independent of each other
Not updating q0 after each elemwise update rendered subsequent proposals dependent on the previous ones.
1 parent 1ed4475 commit f2bb88b

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

pymc/step_methods/metropolis.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,14 +257,15 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, StatsType]:
257257
q = floatX(q0d + delta)
258258

259259
if self.elemwise_update:
260+
q0d = q0d.copy()
260261
q_temp = q0d.copy()
261262
# Shuffle order of updates (probably we don't need to do this in every step)
262263
np.random.shuffle(self.enum_dims)
263264
for i in self.enum_dims:
264265
q_temp[i] = q[i]
265266
accept_rate_i = self.delta_logp(q_temp, q0d)
266267
q_temp_, accepted_i = metrop_select(accept_rate_i, q_temp, q0d)
267-
q_temp[i] = q_temp_[i]
268+
q_temp[i] = q0d[i] = q_temp_[i]
268269
self.accept_rate_iter[i] = accept_rate_i
269270
self.accepted_iter[i] = accepted_i
270271
self.accepted_sum[i] += accepted_i

tests/step_methods/test_metropolis.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,16 @@ def test_elemwise_update(self, batched_dist):
120120
assert az.rhat(trace).max()["batched_dist"].values < 1.1
121121
assert az.ess(trace).min()["batched_dist"].values > 50
122122

123+
def test_elemwise_update_different_scales(self):
124+
mu = [1, 2, 3, 4, 5, 100, 1_000, 10_000]
125+
with pm.Model() as m:
126+
x = pm.Poisson("x", mu=mu)
127+
step = pm.Metropolis([x])
128+
trace = pm.sample(draws=1000, chains=2, step=step, random_seed=128).posterior
129+
130+
np.testing.assert_allclose(trace["x"].mean(("draw", "chain")), mu, rtol=0.1)
131+
np.testing.assert_allclose(trace["x"].var(("draw", "chain")), mu, rtol=0.2)
132+
123133
def test_multinomial_no_elemwise_update(self):
124134
with pm.Model() as m:
125135
batched_dist = pm.Multinomial("batched_dist", n=5, p=np.ones(4) / 4, shape=(10, 4))

0 commit comments

Comments
 (0)