
Commit 9e7e8aa

mikkelbuetwiecki authored and committed
Fixed bug in delta_logp for MLDA that broke AEM and VR
1 parent 38295b7 commit 9e7e8aa


pymc/step_methods/mlda.py

Lines changed: 8 additions & 35 deletions
@@ -66,22 +66,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

     def reset_tuning(self):
         """
@@ -136,22 +128,14 @@ def __init__(self, *args, **kwargs):
         self.Q_last = np.nan
         self.Q_reg = [np.nan] * self.mlda_subsampling_rate_above

-        # extract some necessary variables
-        vars = kwargs.get("vars", None)
-        if vars is None:
-            vars = model.value_vars
-        else:
-            vars = [model.rvs_to_values.get(var, var) for var in vars]
-        vars = pm.inputvars(vars)
-        shared = pm.make_shared_replacements(initial_values, vars, model)
-
         # call parent class __init__
         super().__init__(*args, **kwargs)

         # modify the delta function and point to model if VR is used
         if self.mlda_variance_reduction:
-            self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
             self.model = model
+            self.delta_logp_factory = self.delta_logp
+            self.delta_logp = lambda q, q0: -self.delta_logp_factory(q0, q)

     def reset_tuning(self):
         """Skips resetting of tuned sampler parameters
@@ -556,7 +540,7 @@ def __init__(
         # Construct Aesara function for current-level model likelihood
         # (for use in acceptance)
         shared = pm.make_shared_replacements(initial_values, vars, model)
-        self.delta_logp = delta_logp_inverse(initial_values, model.logpt, vars, shared)
+        self.delta_logp = delta_logp(initial_values, model.logpt, vars, shared)

         # Construct Aesara function for below-level model likelihood
         # (for use in acceptance)
@@ -749,7 +733,9 @@ def astep(self, q0: RaveledVars) -> Tuple[RaveledVars, List[Dict[str, Any]]]:
             accept = np.float64(0.0)
             skipped_logp = True
         else:
-            accept = self.delta_logp(q.data, q0.data) + self.delta_logp_below(q0.data, q.data)
+            # NB! The order and sign of the first term are swapped compared
+            # to the convention to make sure the proposal is evaluated last.
+            accept = -self.delta_logp(q0.data, q.data) + self.delta_logp_below(q0.data, q.data)
             skipped_logp = False

         # Accept/reject sample - next sample is stored in q_new
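
The two terms are the two halves of the MLDA delayed-acceptance log-ratio: the current-level log-density difference for the proposal, plus the below-level difference with the roles of q and q0 reversed. The new first term, -self.delta_logp(q0.data, q.data), is numerically identical to the old self.delta_logp(q.data, q0.data); only the evaluation order changes. A hedged NumPy sketch with toy stand-ins for the two compiled likelihoods (the function names below are illustrative, not the pymc API):

    import numpy as np

    def logp_current(q):   # toy stand-in for the current-level likelihood
        return -0.5 * float(np.dot(q, q))

    def logp_below(q):     # toy stand-in for the below-level likelihood
        return -0.6 * float(np.dot(q, q))

    def delta(logp, a, b):
        # delta_logp convention: logp(a) - logp(b)
        return logp(a) - logp(b)

    q0, q = np.array([0.0, 1.0]), np.array([0.5, 0.5])

    # new form: sign and argument order of the first term swapped so the
    # proposal q is the last point whose log-density is evaluated
    accept = -delta(logp_current, q0, q) + delta(logp_below, q0, q)

    # identical in value to the conventional form
    assert np.isclose(accept, delta(logp_current, q, q0) + delta(logp_below, q0, q))
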
@@ -954,19 +940,6 @@ def update(self, x):
         self.t += 1


-def delta_logp_inverse(point, logp, vars, shared):
-    [logp0], inarray0 = pm.join_nonshared_inputs(point, [logp], vars, shared)
-
-    tensor_type = inarray0.type
-    inarray1 = tensor_type("inarray1")
-
-    logp1 = pm.CallableTensor(logp0)(inarray1)
-
-    f = compile_rv_inplace([inarray1, inarray0], -logp0 + logp1)
-    f.trust_input = True
-    return f
-
-
 def extract_Q_estimate(trace, levels):
     """
     Returns expectation and standard error of quantity of interest,
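
Note that the deleted factory's output, -logp0 + logp1, is algebraically the same difference that the retained delta_logp helper computes, which is why the commit can drop delta_logp_inverse entirely and recover the required evaluation order with the Python-level wrapper shown in the first two hunks. A plain-Python analogue of the factory pattern, with a closure standing in for the compiled Aesara function (make_delta_logp is illustrative only, not the pymc API):

    import numpy as np

    def make_delta_logp(logp):
        # build-once, call-many pattern: the returned closure plays the
        # role of the compiled function f(inarray1, inarray0)
        def f(inarray1, inarray0):
            return logp(inarray1) - logp(inarray0)
        return f

    toy_logp = lambda q: -0.5 * float(np.dot(q, q))  # toy log-density
    delta = make_delta_logp(toy_logp)
    print(delta(np.array([0.5, 0.5]), np.array([0.0, 1.0])))  # logp(q) - logp(q0)
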
