Skip to content

Commit f5369c5

Browse files
committed
Don't have to worry about transforms if you never transform
1 parent 8494887 commit f5369c5

File tree

2 files changed

+10
-23
lines changed

2 files changed

+10
-23
lines changed

pymc_experimental/model/marginal_model.py

+9-22
Original file line numberDiff line numberDiff line change
@@ -255,25 +255,6 @@ def marginalize(self, rvs_to_marginalize: Union[TensorVariable, Sequence[TensorV
255255
# Raise errors and warnings immediately
256256
self.clone()._marginalize(user_warnings=True)
257257

258-
def _transform_input(self, inputs):
259-
"Create a function from the untransformed space to the transformed space"
260-
transformed_rvs = []
261-
transformed_names = []
262-
263-
for rv in self.free_RVs:
264-
transform = self.rvs_to_transforms.get(rv)
265-
if transform is None:
266-
transformed_rvs.append(rv)
267-
transformed_names.append(rv.name)
268-
else:
269-
transformed_rv = transform.forward(rv, *rv.owner.inputs)
270-
transformed_rvs.append(transformed_rv)
271-
transformed_names.append(self.rvs_to_values[rv].name)
272-
273-
fn = self.compile_fn(inputs=self.free_RVs, outs=transformed_rvs)
274-
rets = fn(inputs)
275-
return dict(zip(transformed_names, rets))
276-
277258
def unmarginalize(self, rvs_to_unmarginalize):
278259
for rv in rvs_to_unmarginalize:
279260
self.marginalized_rvs.remove(rv)
@@ -320,8 +301,15 @@ def recover_marginals(
320301
rv_dict = {}
321302
rv_dims_dict = {}
322303

304+
# Disable all transforms
305+
model = self.clone()
306+
model.rvs_to_transforms = {k: None for k in model.rvs_to_transforms}
307+
for rv in model.free_RVs:
308+
model.rvs_to_values[rv].name = rv.name
309+
var_names = [model.vars_to_clone[rv] for rv in var_names]
310+
323311
for rv in var_names:
324-
m = self.clone()
312+
m = model.clone()
325313
rv = m.vars_to_clone[rv]
326314
m.unmarginalize([rv])
327315
joint_logp = m.logp()
@@ -343,7 +331,6 @@ def recover_marginals(
343331
other_values = [v for v in m.value_vars if v is not marginalized_value]
344332

345333
# TODO: Handle constants
346-
347334
joint_logps = vectorize_graph(
348335
joint_logp,
349336
replace={marginalized_value: rv_domain_tensor},
@@ -364,7 +351,7 @@ def recover_marginals(
364351
on_unused_input="ignore",
365352
)
366353

367-
logvs = [rv_loglike_fn(**self._transform_input(vs)) for vs in posterior_pts]
354+
logvs = [rv_loglike_fn(**vs) for vs in posterior_pts]
368355

369356
if include_samples:
370357
logps, samples = zip(*logvs)

pymc_experimental/tests/model/test_marginal_model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def test_marginalized_change_point_model_sampling(disaster_model):
255255

256256
def test_recover_marginals_basic():
257257
with MarginalModel() as m:
258-
sigma = pm.HalfNormal("sigma", transform=None)
258+
sigma = pm.HalfNormal("sigma")
259259
p = np.array([0.5, 0.2, 0.3])
260260
k = pm.Categorical("k", p=p)
261261
mu = np.array([-3.0, 0.0, 3.0])

0 commit comments

Comments
 (0)