Skip to content

Commit a7f63e2

Browse files
committed
Change to providing logprobs for marginalised variables in
the transformed space
1 parent f5369c5 commit a7f63e2

File tree

2 files changed

+36
-11
lines changed

2 files changed

+36
-11
lines changed

pymc_experimental/model/marginal_model.py

+33-10
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,24 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
255255
# Raise errors and warnings immediately
256256
self.clone()._marginalize(user_warnings=True)
257257

258+
def _to_transformed(self):
259+
"Create a function from the untransformed space to the transformed space"
260+
transformed_rvs = []
261+
transformed_names = []
262+
263+
for rv in self.free_RVs:
264+
transform = self.rvs_to_transforms.get(rv)
265+
if transform is None:
266+
transformed_rvs.append(rv)
267+
transformed_names.append(rv.name)
268+
else:
269+
transformed_rv = transform.forward(rv, *rv.owner.inputs)
270+
transformed_rvs.append(transformed_rv)
271+
transformed_names.append(self.rvs_to_values[rv].name)
272+
273+
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
274+
return fn, transformed_names
275+
258276
def unmarginalize(self, rvs_to_unmarginalize):
259277
for rv in rvs_to_unmarginalize:
260278
self.marginalized_rvs.remove(rv)
@@ -263,12 +281,14 @@ def unmarginalize(self, rvs_to_unmarginalize):
263281
def recover_marginals(
264282
self, idata, var_names=None, include_samples=False, extend_inferencedata=True
265283
):
266-
"""Computes log-likelihoods of marginalized variables conditioned on parameters
267-
of the model given InferenceData with posterior group
284+
"""Computes unnormalized posterior probabilities of marginalized variables
285+
conditioned on parameters of the model given InferenceData with posterior group
268286
269287
When there are multiple marginalized variables, each marginalized variable is
270288
conditioned on both the parameters and the other variables still marginalized
271289
290+
All log-probabilities are within the transformed space
291+
272292
Parameters
273293
----------
274294
idata : InferenceData
@@ -298,18 +318,20 @@ def recover_marginals(
298318

299319
sample_dims = ("chain", "draw")
300320
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
321+
322+
# Handle Transforms
323+
transform_fn, transform_names = self._to_transformed()
324+
325+
def transform_input(inputs):
326+
return dict(zip(transform_names, transform_fn(inputs)))
327+
328+
posterior_pts = [transform_input(vs) for vs in posterior_pts]
329+
301330
rv_dict = {}
302331
rv_dims_dict = {}
303332

304-
# Disable all transforms
305-
model = self.clone()
306-
model.rvs_to_transforms = {k: None for k in model.rvs_to_transforms}
307-
for rv in model.free_RVs:
308-
model.rvs_to_values[rv].name = rv.name
309-
var_names = [model.vars_to_clone[rv] for rv in var_names]
310-
311333
for rv in var_names:
312-
m = model.clone()
334+
m = self.clone()
313335
rv = m.vars_to_clone[rv]
314336
m.unmarginalize([rv])
315337
joint_logp = m.logp()
@@ -331,6 +353,7 @@ def recover_marginals(
331353
other_values = [v for v in m.value_vars if v is not marginalized_value]
332354

333355
# TODO: Handle constants
356+
334357
joint_logps = vectorize_graph(
335358
joint_logp,
336359
replace={marginalized_value: rv_domain_tensor},

pymc_experimental/tests/model/test_marginal_model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,9 @@ def test_recover_marginals_basic():
283283
def true_logp(y, sigma):
284284
y = y.repeat(len(p)).reshape(len(y), -1)
285285
sigma = sigma.repeat(len(p)).reshape(len(sigma), -1)
286-
return np.log(p) + norm.logpdf(y, loc=mu, scale=sigma) + halfnorm.logpdf(sigma)
286+
return (
287+
np.log(p) + norm.logpdf(y, loc=mu, scale=sigma) + halfnorm.logpdf(sigma) + np.log(sigma)
288+
)
287289

288290
np.testing.assert_almost_equal(
289291
true_logp(idata.y.values.flatten(), idata.sigma.values.flatten()),

0 commit comments

Comments
 (0)