Skip to content

Commit 8494887

Browse files
committed
Add transform handling code
1 parent 28dfb0f commit 8494887

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

pymc_experimental/model/marginal_model.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,25 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
255255
# Raise errors and warnings immediately
256256
self.clone()._marginalize(user_warnings=True)
257257

258+
def _transform_input(self, inputs):
259+
"Create a function from the untransformed space to the transformed space"
260+
transformed_rvs = []
261+
transformed_names = []
262+
263+
for rv in self.free_RVs:
264+
transform = self.rvs_to_transforms.get(rv)
265+
if transform is None:
266+
transformed_rvs.append(rv)
267+
transformed_names.append(rv.name)
268+
else:
269+
transformed_rv = transform.forward(rv, *rv.owner.inputs)
270+
transformed_rvs.append(transformed_rv)
271+
transformed_names.append(self.rvs_to_values[rv].name)
272+
273+
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
274+
rets = fn(inputs)
275+
return dict(zip(transformed_names, rets))
276+
258277
def unmarginalize(self, rvs_to_unmarginalize):
259278
for rv in rvs_to_unmarginalize:
260279
self.marginalized_rvs.remove(rv)
@@ -324,7 +343,6 @@ def recover_marginals(
324343
other_values = [v for v in m.value_vars if v is not marginalized_value]
325344

326345
# TODO: Handle constants
327-
# TODO: Handle transformed variables
328346

329347
joint_logps = vectorize_graph(
330348
joint_logp,
@@ -346,7 +364,7 @@ def recover_marginals(
346364
on_unused_input="ignore",
347365
)
348366

349-
logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
367+
logvs = [rv_loglike_fn(**self._transform_input(vs)) for vs in posterior_pts]
350368

351369
if include_samples:
352370
logps, samples = zip(*logvs)

0 commit comments

Comments
 (0)