Commit 28fdda5

brandonwillard authored and twiecki committed
Make Metropolis, Slice, PGBART, MetropolisMLDA use point values
1 parent 622e12d commit 28fdda5

10 files changed: +199 -89 lines changed

pymc3/aesaraf.py

Lines changed: 21 additions & 11 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List
 
 import aesara
 import numpy as np
@@ -222,7 +223,7 @@ def __hash__(self):
         return hash(type(self))
 
 
-def make_shared_replacements(vars, model):
+def make_shared_replacements(point, vars, model):
     """
     Makes shared replacements for all *other* variables than the ones passed.
 
@@ -231,6 +232,7 @@ def make_shared_replacements(vars, model):
 
     Parameters
     ----------
+    point: dictionary mapping variable names to sample values
    vars: list of variables not to make shared
    model: model
 
@@ -240,19 +242,24 @@ def make_shared_replacements(vars, model):
     """
     othervars = set(model.vars) - set(vars)
     return {
-        var: aesara.shared(
-            var.tag.test_value, var.name + "_shared", broadcastable=var.broadcastable
-        )
+        var: aesara.shared(point[var.name], var.name + "_shared", broadcastable=var.broadcastable)
         for var in othervars
     }
 
 
-def join_nonshared_inputs(xs, vars, shared, make_shared=False):
+def join_nonshared_inputs(
+    point: Dict[str, np.ndarray],
+    xs: List[TensorVariable],
+    vars: List[TensorVariable],
+    shared,
+    make_shared: bool = False,
+):
     """
     Takes a list of aesara Variables and joins their non shared inputs into a single input.
 
     Parameters
     ----------
+    point: a sample point
    xs: list of aesara tensors
    vars: list of variables to join
 
@@ -271,17 +278,20 @@ def join_nonshared_inputs(xs, vars, shared, make_shared=False):
         tensor_type = joined.type
         inarray = tensor_type("inarray")
     else:
-        inarray = aesara.shared(joined.tag.test_value, "inarray")
+        if point is None:
+            raise ValueError("A point is required when `make_shared` is True")
+        joined_values = np.concatenate([point[var.name].ravel() for var in vars])
+        inarray = aesara.shared(joined_values, "inarray")
 
-    inarray.tag.test_value = joined.tag.test_value
+    if aesara.config.compute_test_value != "off":
+        inarray.tag.test_value = joined.tag.test_value
 
     replace = {}
     last_idx = 0
     for var in vars:
-        arr_len = at.prod(var.shape)
-        replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], var.shape).astype(
-            var.dtype
-        )
+        shape = point[var.name].shape
+        arr_len = np.prod(shape, dtype=int)
+        replace[var] = reshape_t(inarray[last_idx : last_idx + arr_len], shape).astype(var.dtype)
         last_idx += arr_len
 
     replace.update(shared)

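Note on the new API: both helpers now take an explicit dictionary of point values instead of reading `var.tag.test_value`. Below is a minimal usage sketch, assuming a toy two-variable model and using `model.test_point` as the point source; the model and the choice of free variable are illustrative and not part of this commit.

import pymc3 as pm
from pymc3.aesaraf import join_nonshared_inputs, make_shared_replacements

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    sigma = pm.HalfNormal("sigma", 1.0)

# Point values are passed explicitly instead of being read from test values
point = model.test_point  # dict of {variable name: np.ndarray starting value}

# Keep `mu` as a free input; every other free variable becomes an aesara
# shared variable initialized from its value in `point`
free = [model["mu"]]
shared = make_shared_replacements(point, free, model)

# Join the remaining non-shared inputs of the model logp into one flat vector,
# using `point` to determine each variable's shape
[logp], inarray = join_nonshared_inputs(point, [model.logpt], free, shared)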
pymc3/smc/smc.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,12 +108,15 @@ def initialize_population(self):
108108

109109
def setup_kernel(self):
110110
"""Set up the likelihood logp function based on the chosen kernel."""
111-
shared = make_shared_replacements(self.variables, self.model)
111+
initial_values = self.model.test_point
112+
shared = make_shared_replacements(initial_values, self.variables, self.model)
112113

113114
if self.kernel == "abc":
114115
factors = [var.logpt for var in self.model.free_RVs]
115116
factors += [at.sum(factor) for factor in self.model.potentials]
116-
self.prior_logp_func = logp_forw([at.sum(factors)], self.variables, shared)
117+
self.prior_logp_func = logp_forw(
118+
initial_values, [at.sum(factors)], self.variables, shared
119+
)
117120
simulator = self.model.observed_RVs[0]
118121
distance = simulator.distribution.distance
119122
sum_stat = simulator.distribution.sum_stat
@@ -132,8 +135,12 @@ def setup_kernel(self):
132135
self.save_log_pseudolikelihood,
133136
)
134137
elif self.kernel == "metropolis":
135-
self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared)
136-
self.likelihood_logp_func = logp_forw([self.model.datalogpt], self.variables, shared)
138+
self.prior_logp_func = logp_forw(
139+
initial_values, [self.model.varlogpt], self.variables, shared
140+
)
141+
self.likelihood_logp_func = logp_forw(
142+
initial_values, [self.model.datalogpt], self.variables, shared
143+
)
137144

138145
def initialize_logp(self):
139146
"""Initialize the prior and likelihood log probabilities."""
@@ -271,7 +278,7 @@ def posterior_to_trace(self):
271278
return strace
272279

273280

274-
def logp_forw(out_vars, vars, shared):
281+
def logp_forw(point, out_vars, vars, shared):
275282
"""Compile Aesara function of the model and the input and output variables.
276283
277284
Parameters
@@ -283,7 +290,7 @@ def logp_forw(out_vars, vars, shared):
283290
shared: List
284291
containing :class:`aesara.tensor.Tensor` for depended shared data
285292
"""
286-
out_list, inarray0 = join_nonshared_inputs(out_vars, vars, shared)
293+
out_list, inarray0 = join_nonshared_inputs(point, out_vars, vars, shared)
287294
f = aesara_function([inarray0], out_list[0])
288295
f.trust_input = True
289296
return f

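For orientation, `logp_forw` now threads the point values down to `join_nonshared_inputs`, and the compiled function takes a single raveled array of the free variables. A rough sketch of how the SMC kernel's call pattern could be reproduced by hand, assuming a small toy model that is not part of the commit:

import numpy as np
import pymc3 as pm
from pymc3.aesaraf import make_shared_replacements
from pymc3.smc.smc import logp_forw

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("y", mu, 1.0, observed=np.array([0.1, -0.3, 0.2]))

initial_values = model.test_point
variables = model.vars  # free variables in transformed space

# All free variables stay free here, so no shared replacements are created
shared = make_shared_replacements(initial_values, variables, model)

# Compile the prior log-probability as a function of one flat input array
prior_logp = logp_forw(initial_values, [model.varlogpt], variables, shared)

flat = np.concatenate([initial_values[v.name].ravel() for v in variables])
print(prior_logp(flat))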
pymc3/step_methods/arraystep.py

Lines changed: 48 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,12 +145,24 @@ def step(self, point: Dict[str, np.ndarray]):
145145
if self.allvars:
146146
inputs.append(point)
147147

148+
apoint = DictToArrayBijection.map(point)
149+
step_res = self.astep(apoint, *inputs)
150+
148151
if self.generates_stats:
149-
apoint, stats = self.astep(DictToArrayBijection.map(point), *inputs)
150-
return DictToArrayBijection.rmap(apoint), stats
152+
apoint_new, stats = step_res
151153
else:
152-
apoint = self.astep(DictToArrayBijection.map(point), *inputs)
153-
return DictToArrayBijection.rmap(apoint)
154+
apoint_new = step_res
155+
156+
if not isinstance(apoint_new, RaveledVars):
157+
# We assume that the mapping has stayed the same
158+
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
159+
160+
point_new = DictToArrayBijection.rmap(apoint_new)
161+
162+
if self.generates_stats:
163+
return point_new, stats
164+
165+
return point_new
154166

155167
def astep(self, apoint, point):
156168
raise NotImplementedError()
@@ -177,19 +189,41 @@ def __init__(self, vars, shared, blocked=True):
177189
self.blocked = blocked
178190

179191
def step(self, point):
180-
for var, share in self.shared.items():
181-
share.set_value(point[var])
192+
193+
# Remove shared variables from the sample point
194+
point_no_shared = point.copy()
195+
for name, shared_var in self.shared.items():
196+
shared_var.set_value(point[name])
197+
if name in point_no_shared:
198+
del point_no_shared[name]
199+
200+
q = DictToArrayBijection.map(point_no_shared)
201+
202+
step_res = self.astep(q)
182203

183204
if self.generates_stats:
184-
apoint, stats = self.astep(DictToArrayBijection.map(point))
185-
return DictToArrayBijection.rmap(apoint), stats
205+
apoint, stats = step_res
186206
else:
187-
array = DictToArrayBijection.map(point)
188-
apoint = self.astep(array)
189-
if not isinstance(apoint, RaveledVars):
190-
# We assume that the mapping has stayed the same
191-
apoint = RaveledVars(apoint, array.point_map_info)
192-
return DictToArrayBijection.rmap(apoint)
207+
apoint = step_res
208+
209+
if not isinstance(apoint, RaveledVars):
210+
# We assume that the mapping has stayed the same
211+
apoint = RaveledVars(apoint, q.point_map_info)
212+
213+
# We need to re-add the shared variables to the new sample point
214+
a_point = DictToArrayBijection.rmap(apoint)
215+
new_point = {}
216+
for name in point.keys():
217+
shared_value = self.shared.get(name, None)
218+
if shared_value is not None:
219+
new_point[name] = shared_value.get_value()
220+
else:
221+
new_point[name] = a_point[name]
222+
223+
if self.generates_stats:
224+
return new_point, stats
225+
226+
return new_point
193227

194228
def astep(self, apoint):
195229
raise NotImplementedError()

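The ArrayStep and ArrayStepShared changes above revolve around round-tripping a point dict through `DictToArrayBijection` and re-wrapping bare arrays in `RaveledVars`. A small standalone sketch of that round trip follows, with made-up variable names and assuming the RaveledVars tuple exposes its flat array as `.data`; the step method itself is omitted.

import numpy as np
from pymc3.blocking import DictToArrayBijection, RaveledVars

# A toy sample point; only non-shared variables would be raveled by step()
point = {"mu": np.array(0.5), "tau_log__": np.array([0.1, -0.2])}

q = DictToArrayBijection.map(point)  # flat data plus per-variable map info
proposal = q.data + 1.0              # stand-in for what astep() might return

# A bare array is re-wrapped with the original mapping before unraveling,
# mirroring the `isinstance(apoint, RaveledVars)` check in step()
q_new = RaveledVars(proposal, q.point_map_info)
print(DictToArrayBijection.rmap(q_new))  # back to {"mu": ..., "tau_log__": ...}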
0 commit comments
