|
1 |
| -from typing import Any, Dict, List, Sequence, Union |
| 1 | +from typing import Any, Dict, List, Optional, Sequence, Union |
2 | 2 |
|
3 | 3 | from pymc import Model
|
| 4 | +from pymc.logprob.transforms import RVTransform |
4 | 5 | from pymc.pytensorf import _replace_vars_in_graphs
|
| 6 | +from pymc.util import get_transformed_name, get_untransformed_name |
5 | 7 | from pytensor.tensor import TensorVariable
|
6 | 8 |
|
7 |
| -from pymc_experimental.model_transform.basic import prune_vars_detached_from_observed |
| 9 | +from pymc_experimental.model_transform.basic import ( |
| 10 | + ModelVariable, |
| 11 | + parse_vars, |
| 12 | + prune_vars_detached_from_observed, |
| 13 | +) |
8 | 14 | from pymc_experimental.utils.model_fgraph import (
|
9 | 15 | ModelDeterministic,
|
10 | 16 | ModelFreeRV,
|
11 | 17 | extract_dims,
|
12 | 18 | fgraph_from_model,
|
13 | 19 | model_deterministic,
|
| 20 | + model_free_rv, |
14 | 21 | model_from_fgraph,
|
15 | 22 | model_named,
|
16 | 23 | model_observed_rv,
|
@@ -206,3 +213,132 @@ def do(
|
206 | 213 | if prune_vars:
|
207 | 214 | return prune_vars_detached_from_observed(model)
|
208 | 215 | return model
|
| 216 | + |
| 217 | + |
| 218 | +def change_value_transforms( |
| 219 | + model: Model, |
| 220 | + vars_to_transforms: Dict[ModelVariable, Union[RVTransform, None]], |
| 221 | +) -> Model: |
| 222 | + """Change the value variables transforms in the model |
| 223 | +
|
| 224 | + Parameters |
| 225 | + ---------- |
| 226 | + model : Model |
| 227 | + vars_to_transforms : Dict |
| 228 | + Dictionary that maps RVs to new transforms to be applied to the respective value variables |
| 229 | +
|
| 230 | + Returns |
| 231 | + ------- |
| 232 | + new_model : Model |
| 233 | + Model with the updated transformed value variables |
| 234 | +
|
| 235 | + Examples |
| 236 | + -------- |
| 237 | + Extract untransformed space Hessian after finding transformed space MAP |
| 238 | +
|
| 239 | + .. code-block:: python |
| 240 | +
|
| 241 | + import pymc as pm |
| 242 | + from pymc.distributions.transforms import logodds |
| 243 | + from pymc_experimental.model_transform.conditioning import change_value_transforms |
| 244 | +
|
| 245 | + with pm.Model() as base_m: |
| 246 | + p = pm.Uniform("p", 0, 1, transform=None) |
| 247 | + w = pm.Binomial("w", n=9, p=p, observed=6) |
| 248 | +
|
| 249 | + with change_value_transforms(base_m, {"p": logodds}) as transformed_p: |
| 250 | + mean_q = pm.find_MAP() |
| 251 | +
|
| 252 | + with change_value_transforms(transformed_p, {"p": None}) as untransformed_p: |
| 253 | + new_p = untransformed_p['p'] |
| 254 | + std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] |
| 255 | +
|
| 256 | + print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}") |
| 257 | + # Mean, Standard deviation |
| 258 | + # p 0.67, 0.16 |
| 259 | +
|
| 260 | + """ |
| 261 | + vars_to_transforms = { |
| 262 | + parse_vars(model, var)[0]: transform for var, transform in vars_to_transforms.items() |
| 263 | + } |
| 264 | + |
| 265 | + if set(vars_to_transforms.keys()) - set(model.free_RVs): |
| 266 | + raise ValueError(f"All keys must be free variables in the model: {model.free_RVs}") |
| 267 | + |
| 268 | + fgraph, memo = fgraph_from_model(model) |
| 269 | + |
| 270 | + vars_to_transforms = {memo[var]: transform for var, transform in vars_to_transforms.items()} |
| 271 | + replacements = {} |
| 272 | + for node in fgraph.apply_nodes: |
| 273 | + if not isinstance(node.op, ModelFreeRV): |
| 274 | + continue |
| 275 | + |
| 276 | + [dummy_rv] = node.outputs |
| 277 | + if dummy_rv not in vars_to_transforms: |
| 278 | + continue |
| 279 | + |
| 280 | + transform = vars_to_transforms[dummy_rv] |
| 281 | + |
| 282 | + rv, value, *dims = node.inputs |
| 283 | + |
| 284 | + new_value = rv.type() |
| 285 | + try: |
| 286 | + untransformed_name = get_untransformed_name(value.name) |
| 287 | + except ValueError: |
| 288 | + untransformed_name = value.name |
| 289 | + if transform: |
| 290 | + new_name = get_transformed_name(untransformed_name, transform) |
| 291 | + else: |
| 292 | + new_name = untransformed_name |
| 293 | + new_value.name = new_name |
| 294 | + |
| 295 | + new_dummy_rv = model_free_rv(rv, new_value, transform, *dims) |
| 296 | + replacements[dummy_rv] = new_dummy_rv |
| 297 | + |
| 298 | + toposort_replace(fgraph, tuple(replacements.items())) |
| 299 | + return model_from_fgraph(fgraph) |
| 300 | + |
| 301 | + |
| 302 | +def remove_value_transforms( |
| 303 | + model: Model, |
| 304 | + vars: Optional[Sequence[ModelVariable]] = None, |
| 305 | +) -> Model: |
| 306 | + """Remove the value variables transforms in the model |
| 307 | +
|
| 308 | + Parameters |
| 309 | + ---------- |
| 310 | + model : Model |
| 311 | + vars : Model variables, optional |
| 312 | + Model variables for which to remove transforms. Defaults to all transformed variables |
| 313 | +
|
| 314 | + Returns |
| 315 | + ------- |
| 316 | + new_model : Model |
| 317 | + Model with the removed transformed value variables |
| 318 | +
|
| 319 | + Examples |
| 320 | + -------- |
| 321 | + Extract untransformed space Hessian after finding transformed space MAP |
| 322 | +
|
| 323 | + .. code-block:: python |
| 324 | +
|
| 325 | + import pymc as pm |
| 326 | + from pymc_experimental.model_transform.conditioning import remove_value_transforms |
| 327 | +
|
| 328 | + with pm.Model() as transformed_m: |
| 329 | + p = pm.Uniform("p", 0, 1) |
| 330 | + w = pm.Binomial("w", n=9, p=p, observed=6) |
| 331 | + mean_q = pm.find_MAP() |
| 332 | +
|
| 333 | + with remove_value_transforms(transformed_m) as untransformed_m: |
| 334 | + new_p = untransformed_m["p"] |
| 335 | + std_q = ((1 / pm.find_hessian(mean_q, vars=[new_p])) ** 0.5)[0] |
| 336 | + print(f" Mean, Standard deviation\np {mean_q['p']:.2}, {std_q[0]:.2}") |
| 337 | +
|
| 338 | + # Mean, Standard deviation |
| 339 | + # p 0.67, 0.16 |
| 340 | +
|
| 341 | + """ |
| 342 | + if vars is None: |
| 343 | + vars = model.free_RVs |
| 344 | + return change_value_transforms(model, {var: None for var in vars}) |
0 commit comments