Skip to content

Commit e3f5828

Browse files
committed
Implement utility to change value variable transforms
1 parent 430c3c8 commit e3f5828

File tree

3 files changed

+212
-3
lines changed

3 files changed

+212
-3
lines changed

pymc_experimental/model_transform/basic.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing import List, Sequence, Union
2+
13
from pymc import Model
4+
from pytensor import Variable
25
from pytensor.graph import ancestors
36

47
from pymc_experimental.utils.model_fgraph import (
@@ -8,6 +11,8 @@
811
model_from_fgraph,
912
)
1013

14+
ModelVariable = Union[Variable, str]
15+
1116

1217
def prune_vars_detached_from_observed(model: Model) -> Model:
1318
"""Prune model variables that are not related to any observed variable in the Model."""
@@ -33,3 +38,9 @@ def prune_vars_detached_from_observed(model: Model) -> Model:
3338
for node_to_remove in nodes_to_remove:
3439
fgraph.remove_node(node_to_remove)
3540
return model_from_fgraph(fgraph)
41+
42+
43+
def parse_vars(model: Model, vars: Union[ModelVariable, Sequence[ModelVariable]]) -> List[Variable]:
44+
if not isinstance(vars, (list, tuple)):
45+
vars = (vars,)
46+
return [model[var] if isinstance(var, str) else var for var in vars]

pymc_experimental/model_transform/conditioning.py

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,23 @@
1-
from typing import Any, Dict, List, Sequence, Union
1+
from typing import Any, Dict, List, Optional, Sequence, Union
22

33
from pymc import Model
4+
from pymc.logprob.transforms import RVTransform
45
from pymc.pytensorf import _replace_vars_in_graphs
6+
from pymc.util import get_transformed_name, get_untransformed_name
57
from pytensor.tensor import TensorVariable
68

7-
from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed
9+
from pymc_experimental.model_transform.basic import (
10+
ModelVariable,
11+
parse_vars,
12+
prune_vars_detached_from_observed,
13+
)
814
from pymc_experimental.utils.model_fgraph import (
915
ModelDeterministic,
1016
ModelFreeRV,
1117
extract_dims,
1218
fgraph_from_model,
1319
model_deterministic,
20+
model_free_rv,
1421
model_from_fgraph,
1522
model_named,
1623
model_observed_rv,
@@ -206,3 +213,132 @@ def do(
206213
if prune_vars:
207214
return prune_vars_detached_from_observed(model)
208215
return model
216+
217+
218+
def change_value_transforms(
219+
model: Model,
220+
vars_to_transforms: Dict[ModelVariable, Union[RVTransform, None]],
221+
) -> Model:
222+
"""Change the value variables transforms in the model
223+
224+
Parameters
225+
----------
226+
model: Model
227+
vars_to_transforms: Dict
228+
Mapping between RVs and new transforms to be applied to the respective value variables
229+
230+
Returns
231+
-------
232+
new_model: Model
233+
Model with the updated transformed value variables
234+
235+
Examples
236+
--------
237+
Extract untransformed space Hessian after finding transformed space MAP
238+
239+
.. code-block:: python
240+
241+
import pymc as pm
242+
from pymc.distributions.transforms import logodds
243+
from pymc_experimental.model_transform.conditioning import change_value_transforms
244+
245+
with pm.Model() as base_m:
246+
p = pm.Uniform("p", 0, 1, transform=None)
247+
w = pm.Binomial("w", n=9, p=p, observed=6)
248+
249+
with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
250+
mean_q = pm.find_MAP()
251+
252+
with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
253+
new_p = untransformed_p['p']
254+
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
255+
256+
print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
257+
# Mean, Standard deviation
258+
# p 0.67, 0.16
259+
260+
"""
261+
vars_to_transforms = {
262+
parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items()
263+
}
264+
265+
if set(vars_to_transforms.keys()) - set(model.free_RVs):
266+
raise ValueError(f"All keys must be free variables in the model: {vars_to_transforms}")
267+
268+
fgraph, memo = fgraph_from_model(model)
269+
270+
vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()}
271+
replacements = {}
272+
for node in fgraph.apply_nodes:
273+
if not isinstance(node.op, ModelFreeRV):
274+
continue
275+
276+
[dummy_rv] = node.outputs
277+
if dummy_rv not in vars_to_transforms:
278+
continue
279+
280+
transform = vars_to_transforms[dummy_rv]
281+
282+
rv, value, *dims = node.inputs
283+
284+
new_value = rv.type()
285+
try:
286+
untransformed_name = get_untransformed_name(value.name)
287+
except ValueError:
288+
untransformed_name = value.name
289+
if transform:
290+
new_name = get_transformed_name(untransformed_name, transform)
291+
else:
292+
new_name = untransformed_name
293+
new_value.name = new_name
294+
295+
new_dummy_rv = model_free_rv(rv, new_value, transform, *dims)
296+
replacements[dummy_rv] = new_dummy_rv
297+
298+
toposort_replace(fgraph, tuple(replacements.items()))
299+
return model_from_fgraph(fgraph)
300+
301+
302+
def remove_value_transforms(
303+
model: Model,
304+
vars: Optional[Sequence[ModelVariable]] = None,
305+
) -> Model:
306+
"""Remove the value variables transforms in the model
307+
308+
Parameters
309+
----------
310+
model: Model
311+
vars: Model variables, optional
312+
Model variables for which to remove transforms. Defaults to all transformed variables
313+
314+
Returns
315+
-------
316+
new_model: Model
317+
Model with the removed transformed value variables
318+
319+
Examples
320+
--------
321+
Extract untransformed space Hessian after finding transformed space MAP
322+
323+
.. code-block:: python
324+
325+
import pymc as pm
326+
from pymc_experimental.model_transform.conditioning import remove_value_transforms
327+
328+
with pm.Model() as transformed_m:
329+
p = pm.Uniform("p", 0, 1)
330+
w = pm.Binomial("w", n=9, p=p, observed=6)
331+
mean_q = pm.find_MAP()
332+
333+
with remove_value_transforms(transformed_m) as untransformed_m:
334+
new_p = untransformed_m["p"]
335+
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
336+
print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}")
337+
338+
# Mean, Standard deviation
339+
# p 0.67, 0.16
340+
341+
"""
342+
if vars is None:
343+
vars = model.free_RVs
344+
return change_value_transforms(model, {var: None for var in vars})

pymc_experimental/tests/model_transform/test_conditioning.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
import numpy as np
33
import pymc as pm
44
import pytest
5+
from pymc.distributions.transforms import logodds
56
from pymc.variational.minibatch_rv import create_minibatch_rv
67
from pytensor import config
78

8-
from pymc_experimental.model_transform.conditioning import do, observe
9+
from pymc_experimental.model_transform.conditioning import (
10+
change_value_transforms,
11+
do,
12+
observe,
13+
remove_value_transforms,
14+
)
915

1016

1117
def test_observe():
@@ -214,3 +220,59 @@ def test_do_prune(prune):
214220
assert set(do_m.named_vars) == {"x1", "z", "llike"}
215221
else:
216222
assert set(do_m.named_vars) == orig_named_vars
223+
224+
225+
def test_change_value_transforms():
226+
with pm.Model() as base_m:
227+
p = pm.Uniform("p", 0, 1, transform=None)
228+
w = pm.Binomial("w", n=9, p=p, observed=6)
229+
assert base_m.rvs_to_transforms == {p: None, w: None}
230+
231+
with change_value_transforms(base_m, {"p": logodds}) as transformed_p:
232+
new_p = transformed_p["p"]
233+
new_w = transformed_p["w"]
234+
assert transformed_p.rvs_to_transforms == {new_p: logodds, new_w: None}
235+
mean_q = pm.find_MAP(progressbar=False)
236+
237+
with change_value_transforms(transformed_p, {"p": None}) as untransformed_p:
238+
new_p = untransformed_p["p"]
239+
new_w = untransformed_p["w"]
240+
assert untransformed_p.rvs_to_transforms == {new_p: None, new_w: None}
241+
std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0]
242+
243+
assert np.round(mean_q["p"], 2) == 0.67
244+
assert np.round(std_q[0], 2) == 0.16
245+
246+
247+
def test_change_value_transforms_error():
248+
with pm.Model() as m:
249+
x = pm.Uniform("x", observed=5.0)
250+
251+
with pytest.raises(ValueError, match="All keys must be free variables in the model"):
252+
change_value_transforms(m, {x: logodds})
253+
254+
255+
def test_remove_value_transforms():
256+
with pm.Model() as base_m:
257+
p = pm.Uniform("p", transform=logodds)
258+
q = pm.Uniform("q", transform=logodds)
259+
260+
new_m = remove_value_transforms(base_m)
261+
new_p = new_m["p"]
262+
new_q = new_m["q"]
263+
assert new_m.rvs_to_transforms == {new_p: None, new_q: None}
264+
265+
new_m = remove_value_transforms(base_m, [p, q])
266+
new_p = new_m["p"]
267+
new_q = new_m["q"]
268+
assert new_m.rvs_to_transforms == {new_p: None, new_q: None}
269+
270+
new_m = remove_value_transforms(base_m, [p])
271+
new_p = new_m["p"]
272+
new_q = new_m["q"]
273+
assert new_m.rvs_to_transforms == {new_p: None, new_q: logodds}
274+
275+
new_m = remove_value_transforms(base_m, ["q"])
276+
new_p = new_m["p"]
277+
new_q = new_m["q"]
278+
assert new_m.rvs_to_transforms == {new_p: logodds, new_q: None}

0 commit comments

Comments
 (0)