diff --git a/docs/api_reference.rst b/docs/api_reference.rst
index 0f633ac21..386d3166a 100644
--- a/docs/api_reference.rst
+++ b/docs/api_reference.rst
@@ -41,5 +41,8 @@ Utils
 .. autosummary::
    :toctree: generated/
 
+   clone_model
    spline.bspline_interpolation
    prior.prior_from_idata
+   model_fgraph.fgraph_from_model
+   model_fgraph.model_from_fgraph
diff --git a/pymc_experimental/tests/utils/__init__.py b/pymc_experimental/tests/utils/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/pymc_experimental/tests/utils/test_model_fgraph.py b/pymc_experimental/tests/utils/test_model_fgraph.py
new file mode 100644
index 000000000..284a9bfdf
--- /dev/null
+++ b/pymc_experimental/tests/utils/test_model_fgraph.py
@@ -0,0 +1,288 @@
+import numpy as np
+import pymc as pm
+import pytensor.tensor as pt
+import pytest
+from pytensor.graph import Constant, FunctionGraph, node_rewriter
+from pytensor.graph.rewriting.basic import in2out
+from pytensor.tensor.exceptions import NotScalarConstantError
+
+from pymc_experimental.utils.model_fgraph import (
+    ModelDeterministic,
+    ModelFreeRV,
+    ModelNamed,
+    ModelObservedRV,
+    ModelPotential,
+    ModelVar,
+    fgraph_from_model,
+    model_deterministic,
+    model_free_rv,
+    model_from_fgraph,
+)
+
+
+def test_basic():
+    """Test we can convert from a PyMC Model to a FunctionGraph and back"""
+    with pm.Model(coords={"test_dim": range(3)}) as m_old:
+        x = pm.Normal("x")
+        y = pm.Deterministic("y", x + 1)
+        w = pm.HalfNormal("w", pm.math.exp(y))
+        z = pm.Normal("z", y, w, observed=[0, 1, 2], dims=("test_dim",))
+        pot = pm.Potential("pot", x * 2)
+
+    m_fgraph, memo = fgraph_from_model(m_old)
+    assert isinstance(m_fgraph, FunctionGraph)
+
+    assert isinstance(memo[x].owner.op, ModelFreeRV)
+    assert isinstance(memo[y].owner.op, ModelDeterministic)
+    assert isinstance(memo[w].owner.op, ModelFreeRV)
+    assert isinstance(memo[z].owner.op, ModelObservedRV)
+    assert isinstance(memo[pot].owner.op, ModelPotential)
+
+    m_new = model_from_fgraph(m_fgraph)
+    assert isinstance(m_new, pm.Model)
+
+    assert m_new.coords == {"test_dim": tuple(range(3))}
+    assert m_new._dim_lengths["test_dim"].eval() == 3
+    assert m_new.named_vars_to_dims == {"z": ["test_dim"]}
+
+    named_vars = {"x", "y", "w", "z", "pot"}
+    assert set(m_new.named_vars) == named_vars
+    for named_var in named_vars:
+        assert m_new[named_var] is not m_old[named_var]
+    for value_new, value_old in zip(m_new.rvs_to_values.values(), m_old.rvs_to_values.values()):
+        # Constants are not cloned
+        if not isinstance(value_new, Constant):
+            assert value_new is not value_old
+    assert m_new["x"] in m_new.free_RVs
+    assert m_new["w"] in m_new.free_RVs
+    assert m_new["y"] in m_new.deterministics
+    assert m_new["z"] in m_new.observed_RVs
+    assert m_new["pot"] in m_new.potentials
+    assert m_new.rvs_to_transforms[m_new["x"]] is None
+    assert m_new.rvs_to_transforms[m_new["w"]] is pm.distributions.transforms.log
+    assert m_new.rvs_to_transforms[m_new["z"]] is None
+
+    # Test random
+    new_y_draw, new_z_draw = pm.draw([m_new["y"], m_new["z"]], draws=5, random_seed=1)
+    old_y_draw, old_z_draw = pm.draw([m_old["y"], m_old["z"]], draws=5, random_seed=1)
+    np.testing.assert_array_equal(new_y_draw, old_y_draw)
+    np.testing.assert_array_equal(new_z_draw, old_z_draw)
+
+    # Test logp
+    ip = m_new.initial_point()
+    np.testing.assert_equal(
+        m_new.compile_logp()(ip),
+        m_old.compile_logp()(ip),
+    )
+
+
+def test_data():
+    """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly.
+
+    Everything should be preserved across new and old models, except for shared RNGs.
+    """
+    with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
+        x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
+        y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
+        b0 = pm.ConstantData("b0", 0.0)
+        b1 = pm.Normal("b1")
+        mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
+        obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))
+
+    m_fgraph, memo = fgraph_from_model(m_old)
+    assert isinstance(memo[x].owner.op, ModelNamed)
+    assert isinstance(memo[y].owner.op, ModelNamed)
+    assert isinstance(memo[b0].owner.op, ModelNamed)
+
+    m_new = model_from_fgraph(m_fgraph)
+
+    # ConstantData is preserved
+    assert m_new["b0"].data == m_old["b0"].data
+
+    # Non-RNG shared variables are preserved
+    assert m_new["x"].container is x.container
+    assert m_new["y"].container is y.container
+    assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]
+
+    # RNG shared variables are not preserved
+    assert m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
+
+    with m_old:
+        pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
+
+    assert m_new.dim_lengths["test_dim"].eval() == 2
+    np.testing.assert_array_almost_equal(pm.draw(m_new["x"]), [100.0, 200.0])
+
+
+def test_deterministics():
+    """Test handling of deterministics.
+
+    We don't want Deterministics in the middle of the FunctionGraph, as they would make
+    rewrites cumbersome. However, we do want them in the middle of Model.basic_RVs, so
+    they display nicely in graphviz.
+
+    There is one edge case that has to be considered: when a Deterministic is just a copy
+    of an RV, we don't bother to reintroduce it in between other Model.basic_RVs.
+    """
+    with pm.Model() as m:
+        x = pm.Normal("x")
+        mu = pm.Deterministic("mu", pm.math.abs(x))
+        sigma = pm.math.exp(x)
+        pm.Deterministic("sigma", sigma)
+        y = pm.Normal("y", mu, sigma)
+        # Special case where the Deterministic
+        # is a direct view on another model variable
+        y_ = pm.Deterministic("y_", y)
+        # Just for kicks, make it a double one!
+        y__ = pm.Deterministic("y__", y_)
+        z = pm.Normal("z", y__)
+
+    # Deterministic mu is in the graph of x to y but not sigma
+    assert m["y"].owner.inputs[3] is m["mu"]
+    assert m["y"].owner.inputs[4] is not m["sigma"]
+
+    fg, _ = fgraph_from_model(m)
+
+    # Check that no Deterministics are in the graphs of x to y and y to z
+    x, y, z, det_mu, det_sigma, det_y_, det_y__ = fg.outputs
+    # [Det(mu), Det(sigma)]
+    mu = det_mu.owner.inputs[0]
+    sigma = det_sigma.owner.inputs[0]
+    # [FreeRV(y(mu, sigma))] not [FreeRV(y(Det(mu), Det(sigma)))]
+    assert y.owner.inputs[0].owner.inputs[3] is mu
+    assert y.owner.inputs[0].owner.inputs[4] is sigma
+    # [FreeRV(z(y))] not [FreeRV(z(Det(Det(y))))]
+    assert z.owner.inputs[0].owner.inputs[3] is y
+    # [Det(y), Det(y)], not [Det(y), Det(Det(y))]
+    assert det_y_.owner.inputs[0] is y
+    assert det_y__.owner.inputs[0] is y
+    assert det_y_ is not det_y__
+
+    # Both mu and sigma deterministics are now in the graph of x to y
+    m = model_from_fgraph(fg)
+    assert m["y"].owner.inputs[3] is m["mu"]
+    assert m["y"].owner.inputs[4] is m["sigma"]
+    # But not y_* in y to z, since there was no real Op in between
+    assert m["z"].owner.inputs[3] is m["y"]
+    assert m["y_"].owner.inputs[0] is m["y"]
+    assert m["y__"].owner.inputs[0] is m["y"]
+
+
+def test_context_error():
+    """Test that model_from_fgraph fails when called inside a Model context.
+
+    We can't allow it, because the new Model that's returned would be a child of
+    whatever Model context is active.
+    """
+    with pm.Model() as m:
+        x = pm.Normal("x")
+
+        fg, _ = fgraph_from_model(m)
+
+        with pytest.raises(RuntimeError, match="cannot be called inside a PyMC model context"):
+            model_from_fgraph(fg)
+
+
+def test_sub_model_error():
+    """Test Error is raised when trying to convert a sub-model to fgraph."""
+    with pm.Model() as m:
+        x = pm.Beta("x", 1, 1)
+        with pm.Model() as sub_m:
+            y = pm.Normal("y", x)
+
+    nodes = [v for v in fgraph_from_model(m)[0].toposort() if not isinstance(v.op, ModelVar)]
+    assert len(nodes) == 2
+    assert isinstance(nodes[0].op, pm.Beta)
+    assert isinstance(nodes[1].op, pm.Normal)
+
+    with pytest.raises(ValueError, match="Nested sub-models cannot be converted"):
+        fgraph_from_model(sub_m)
+
+
+@pytest.fixture()
+def non_centered_rewrite():
+    @node_rewriter(tracks=[ModelFreeRV])
+    def non_centered_param(fgraph: FunctionGraph, node):
+        """Rewrite that replaces centered normal by non-centered parametrization."""
+
+        rv, value, *dims = node.inputs
+        if not isinstance(rv.owner.op, pm.Normal):
+            return
+        rng, size, dtype, loc, scale = rv.owner.inputs
+
+        # Only apply rewrite if size information is explicit
+        if size.ndim == 0:
+            return None
+
+        try:
+            is_unit = (
+                pt.get_underlying_scalar_constant_value(loc) == 0
+                and pt.get_underlying_scalar_constant_value(scale) == 1
+            )
+        except NotScalarConstantError:
+            is_unit = False
+
+        # Nothing to do here
+        if is_unit:
+            return
+
+        raw_norm = pm.Normal.dist(0, 1, size=size, rng=rng)
+        raw_norm.name = f"{rv.name}_raw_"
+        raw_norm_value = raw_norm.clone()
+        fgraph.add_input(raw_norm_value)
+        raw_norm = model_free_rv(raw_norm, raw_norm_value, node.op.transform, *dims)
+
+        new_norm = loc + raw_norm * scale
+        new_norm.name = rv.name
+        new_norm_det = model_deterministic(new_norm, *dims)
+        fgraph.add_output(new_norm_det)
+
+        return [new_norm]
+
+    return in2out(non_centered_param)
+
+
+def test_fgraph_rewrite(non_centered_rewrite):
+    """Test we can apply a simple rewrite to a PyMC Model."""
+
+    with pm.Model(coords={"subject": range(10)}) as m_old:
+        group_mean = pm.Normal("group_mean")
+        group_std = pm.HalfNormal("group_std")
+        subject_mean = pm.Normal("subject_mean", group_mean, group_std, dims=("subject",))
+        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10), dims=("subject",))
+
+    fg, _ = fgraph_from_model(m_old)
+    non_centered_rewrite.apply(fg)
+
+    m_new = model_from_fgraph(fg)
+    assert m_new.named_vars_to_dims == {
+        "subject_mean": ["subject"],
+        "subject_mean_raw_": ["subject"],
+        "obs": ["subject"],
+    }
+    assert set(m_new.named_vars) == {
+        "group_mean",
+        "group_std",
+        "subject_mean_raw_",
+        "subject_mean",
+        "obs",
+    }
+    assert {rv.name for rv in m_new.free_RVs} == {"group_mean", "group_std", "subject_mean_raw_"}
+    assert {rv.name for rv in m_new.observed_RVs} == {"obs"}
+    assert {rv.name for rv in m_new.deterministics} == {"subject_mean"}
+
+    with pm.Model() as m_ref:
+        group_mean = pm.Normal("group_mean")
+        group_std = pm.HalfNormal("group_std")
+        subject_mean_raw = pm.Normal("subject_mean_raw_", 0, 1, shape=(10,))
+        subject_mean = pm.Deterministic("subject_mean", group_mean + subject_mean_raw * group_std)
+        obs = pm.Normal("obs", subject_mean, 1, observed=np.zeros(10))
+
+    np.testing.assert_array_equal(
+        pm.draw(m_new["subject_mean_raw_"], draws=7, random_seed=1),
+        pm.draw(m_ref["subject_mean_raw_"], draws=7, random_seed=1),
+    )
+
+    ip = m_new.initial_point()
+    np.testing.assert_equal(
+        m_new.compile_logp()(ip),
+        m_ref.compile_logp()(ip),
+    )
diff --git a/pymc_experimental/utils/__init__.py b/pymc_experimental/utils/__init__.py
index db751aa2a..705d21076 100644
--- a/pymc_experimental/utils/__init__.py
+++ b/pymc_experimental/utils/__init__.py
@@ -15,5 +15,11 @@
 
 from pymc_experimental.utils import prior, spline
 from pymc_experimental.utils.linear_cg import linear_cg
+from pymc_experimental.utils.model_fgraph import clone_model
 
-# from pymc_experimental.utils.pivoted_cholesky import pivoted_cholesky
+__all__ = (
+    "clone_model",
+    "linear_cg",
+    "prior",
+    "spline",
+)
diff --git a/pymc_experimental/utils/model_fgraph.py b/pymc_experimental/utils/model_fgraph.py
new file mode 100644
index 000000000..909cd9a07
--- /dev/null
+++ b/pymc_experimental/utils/model_fgraph.py
@@ -0,0 +1,328 @@
+from typing import Dict, Optional, Sequence, Tuple
+
+import pytensor
+from pymc.logprob.transforms import RVTransform
+from pymc.model import Model
+from pymc.pytensorf import find_rng_nodes
+from pytensor import Variable
+from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
+from pytensor.graph.rewriting.basic import out2in
+from pytensor.scalar import Identity
+from pytensor.tensor.elemwise import Elemwise
+
+from pymc_experimental.utils.pytensorf import StringType
+
+
+class ModelVar(Op):
+    """A dummy Op that describes the purpose of a Model variable and contains
+    meta-information as additional inputs (value and dims).
+    """
+
+    def make_node(self, rv, *dims):
+        assert isinstance(rv, Variable)
+        dims = self._parse_dims(rv, *dims)
+        return Apply(self, [rv, *dims], [rv.type(name=rv.name)])
+
+    def _parse_dims(self, rv, *dims):
+        if dims:
+            dims = [pytensor.as_symbolic(dim) for dim in dims]
+            assert all(isinstance(dim.type, StringType) for dim in dims)
+            assert len(dims) == rv.type.ndim
+        return dims
+
+    def infer_shape(self, fgraph, node, inputs_shape):
+        return [inputs_shape[0]]
+
+    def do_constant_folding(self, fgraph, node):
+        return False
+
+    def perform(self, *args, **kwargs):
+        raise RuntimeError("ModelVars should never be in a final graph!")
+
+
+class ModelValuedVar(ModelVar):
+
+    __props__ = ("transform",)
+
+    def __init__(self, transform: Optional[RVTransform] = None):
+        if transform is not None and not isinstance(transform, RVTransform):
+            raise TypeError(f"transform must be None or RVTransform type, got {type(transform)}")
+        self.transform = transform
+        super().__init__()
+
+    def make_node(self, rv, value, *dims):
+        assert isinstance(rv, Variable)
+        dims = self._parse_dims(rv, *dims)
+        if value is not None:
+            assert isinstance(value, Variable)
+            assert rv.type.in_same_class(value.type)
+        return Apply(self, [rv, value, *dims], [rv.type(name=rv.name)])
+
+
+class ModelFreeRV(ModelValuedVar):
+    pass
+
+
+class ModelObservedRV(ModelValuedVar):
+    pass
+
+
+class ModelPotential(ModelVar):
+    pass
+
+
+class ModelDeterministic(ModelVar):
+    pass
+
+
+class ModelNamed(ModelVar):
+    pass
+
+
+def model_free_rv(rv, value, transform, *dims):
+    return ModelFreeRV(transform=transform)(rv, value, *dims)
+
+
+model_observed_rv = ModelObservedRV()
+model_potential = ModelPotential()
+model_deterministic = ModelDeterministic()
+model_named = ModelNamed()
+
+
+def toposort_replace(
+    fgraph: FunctionGraph, replacements: Sequence[Tuple[Variable, Variable]]
+) -> None:
+    """Replace multiple variables in topological order."""
+    toposort = fgraph.toposort()
+    sorted_replacements = sorted(
+        replacements, key=lambda pair: toposort.index(pair[0].owner) if pair[0].owner else -1
+    )
+    fgraph.replace_all(tuple(sorted_replacements), import_missing=True)
+
+
+@node_rewriter([Elemwise])
+def local_remove_identity(fgraph, node):
+    if isinstance(node.op.scalar_op, Identity):
+        return [node.inputs[0]]
+
+
+remove_identity_rewrite = out2in(local_remove_identity)
+
+
+def fgraph_from_model(model: Model) -> Tuple[FunctionGraph, Dict[Variable, Variable]]:
+    """Convert Model to FunctionGraph.
+
+    See: model_from_fgraph
+
+    Returns
+    -------
+    fgraph: FunctionGraph
+        FunctionGraph that includes a copy of model variables, wrapped in dummy `ModelVar` Ops.
+        It should be possible to reconstruct a valid PyMC model using `model_from_fgraph`.
+
+    memo: Dict
+        A dictionary mapping original model variables to the equivalent nodes in the fgraph.
+    """
+
+    if any(v is not None for v in model.rvs_to_initial_values.values()):
+        raise NotImplementedError("Cannot convert models with non-default initial_values")
+
+    if model.parent is not None:
+        raise ValueError(
+            "Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
+        )
+
+    # Collect PyTensor variables
+    rvs_to_values = model.rvs_to_values
+    rvs = list(rvs_to_values.keys())
+    free_rvs = model.free_RVs
+    observed_rvs = model.observed_RVs
+    potentials = model.potentials
+    # We copy Deterministics (Identity Op) so that they don't show in between "main" variables
+    # We later remove these Identity Ops when we have a Deterministic ModelVar Op as a separator
+    old_deterministics = model.deterministics
+    deterministics = [det.copy(det.name) for det in old_deterministics]
+    # Other variables that are in model.named_vars but are not any of the categories above
+    # E.g., MutableData, ConstantData, _dim_lengths
+    # We use the same trick as deterministics!
+    accounted_for = free_rvs + observed_rvs + potentials + old_deterministics
+    old_other_named_vars = [var for var in model.named_vars.values() if var not in accounted_for]
+    other_named_vars = [var.copy(var.name) for var in old_other_named_vars]
+    value_vars = [val for val in rvs_to_values.values() if val not in old_other_named_vars]
+
+    model_vars = rvs + potentials + deterministics + other_named_vars + value_vars
+
+    memo = {}
+
+    # Replace RNG nodes so that seeding does not interfere with old model
+    for rng in find_rng_nodes(model_vars):
+        new_rng = rng.clone()
+        new_rng.set_value(rng.get_value(borrow=False))
+        memo[rng] = new_rng
+
+    fgraph = FunctionGraph(
+        outputs=model_vars,
+        clone=True,
+        memo=memo,
+        copy_orphans=True,
+        copy_inputs=True,
+    )
+    # Copy model meta-info to fgraph
+    fgraph._coords = model._coords.copy()
+    fgraph._dim_lengths = model._dim_lengths.copy()
+
+    rvs_to_transforms = model.rvs_to_transforms
+    named_vars_to_dims = model.named_vars_to_dims
+
+    # Introduce dummy `ModelVar` Ops
+    free_rvs_to_transforms = {memo[k]: tr for k, tr in rvs_to_transforms.items()}
+    free_rvs_to_values = {memo[k]: memo[v] for k, v in rvs_to_values.items() if k in free_rvs}
+    observed_rvs_to_values = {
+        memo[k]: memo[v] for k, v in rvs_to_values.items() if k in observed_rvs
+    }
+    potentials = [memo[k] for k in potentials]
+    deterministics = [memo[k] for k in deterministics]
+    other_named_vars = [memo[k] for k in other_named_vars]
+
+    vars = fgraph.outputs
+    new_vars = []
+    for var in vars:
+        dims = named_vars_to_dims.get(var.name, ())
+        if var in free_rvs_to_values:
+            new_var = model_free_rv(
+                var, free_rvs_to_values[var], free_rvs_to_transforms[var], *dims
+            )
+        elif var in observed_rvs_to_values:
+            new_var = model_observed_rv(var, observed_rvs_to_values[var], *dims)
+        elif var in potentials:
+            new_var = model_potential(var, *dims)
+        elif var in deterministics:
+            new_var = model_deterministic(var, *dims)
+        elif var in other_named_vars:
+            new_var = model_named(var, *dims)
+        else:
+            # Value variables
+            new_var = var
+        new_vars.append(new_var)
+
+    replacements = tuple(zip(vars, new_vars))
+    toposort_replace(fgraph, replacements)
+
+    # Reference model vars in memo
+    inverse_memo = {v: k for k, v in memo.items()}
+    for var, model_var in replacements:
+        if model_var.owner is not None and isinstance(
+            model_var.owner.op, (ModelDeterministic, ModelNamed)
+        ):
+            # Ignore the extra identity that will be removed at the end
+            var = var.owner.inputs[0]
+        original_var = inverse_memo[var]
+        memo[original_var] = model_var
+
+    # Remove value variables as outputs, now that they are graph inputs
+    first_value_idx = len(fgraph.outputs) - len(value_vars)
+    for _ in value_vars:
+        fgraph.remove_output(first_value_idx)
+
+    # Now that we have Deterministic dummy Ops, we remove the noisy `Identity`s from the graph
+    remove_identity_rewrite.apply(fgraph)
+
+    return fgraph, memo
+
+
+def model_from_fgraph(fgraph: FunctionGraph) -> Model:
+    """Convert FunctionGraph to PyMC model.
+
+    This requires nodes to be properly tagged with `ModelVar` dummy Ops.
+
+    See: fgraph_from_model
+    """
+    model = Model()
+    if model.parent is not None:
+        raise RuntimeError("model_from_fgraph cannot be called inside a PyMC model context")
+    model._coords = getattr(fgraph, "_coords", {})
+    model._dim_lengths = getattr(fgraph, "_dim_lengths", {})
+
+    # Replace dummy `ModelVar` Ops by the underlying variables,
+    # except for Deterministics, which could reintroduce the old graphs
+    fgraph = fgraph.clone()
+    model_dummy_vars = [
+        model_node.outputs[0]
+        for model_node in fgraph.toposort()
+        if isinstance(model_node.op, ModelVar)
+    ]
+    model_dummy_vars_to_vars = {
+        dummy_var: dummy_var.owner.inputs[0]
+        for dummy_var in model_dummy_vars
+        # Don't include Deterministics!
+        if not isinstance(dummy_var.owner.op, ModelDeterministic)
+    }
+    toposort_replace(fgraph, tuple(model_dummy_vars_to_vars.items()))
+
+    # Populate new PyMC model mappings
+    non_det_model_vars = set(model_dummy_vars_to_vars.values())
+    for model_var in model_dummy_vars:
+        if isinstance(model_var.owner.op, ModelFreeRV):
+            var, value, *dims = model_var.owner.inputs
+            transform = model_var.owner.op.transform
+            model.free_RVs.append(var)
+            # PyMC does not allow setting the transform when we pass a value_var. Why?
+            model.create_value_var(var, transform=None, value_var=value)
+            model.rvs_to_transforms[var] = transform
+            model.set_initval(var, initval=None)
+        elif isinstance(model_var.owner.op, ModelObservedRV):
+            var, value, *dims = model_var.owner.inputs
+            model.observed_RVs.append(var)
+            model.create_value_var(var, transform=None, value_var=value)
+        elif isinstance(model_var.owner.op, ModelPotential):
+            var, *dims = model_var.owner.inputs
+            model.potentials.append(var)
+        elif isinstance(model_var.owner.op, ModelDeterministic):
+            var, *dims = model_var.owner.inputs
+            # Register the original var (not the copy) as the Deterministic,
+            # so it shows in the expected place in graphviz,
+            # unless it's another model var, in which case we need a copy!
+            if var in non_det_model_vars:
+                var = var.copy()
+            model.deterministics.append(var)
+        elif isinstance(model_var.owner.op, ModelNamed):
+            var, *dims = model_var.owner.inputs
+        else:
+            raise TypeError(f"Unexpected ModelVar type {type(model_var)}")
+
+        var.name = model_var.name
+        dims = [dim.data for dim in dims] if dims else None
+        model.add_named_variable(var, dims=dims)
+
+    return model
+
+
+def clone_model(model: Model) -> Model:
+    """Clone a PyMC model.
+
+    Recreates a PyMC model with clones of the original variables.
+    Shared variables will point to the same container but be otherwise different objects.
+    Constants are not cloned.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import pymc as pm
+        from pymc_experimental.utils import clone_model
+
+        with pm.Model() as m:
+            p = pm.Beta("p", 1, 1)
+            x = pm.Bernoulli("x", p=p, shape=(3,))
+
+        with clone_model(m) as clone_m:
+            # Access cloned variables by name
+            clone_x = clone_m["x"]
+
+            # z will be part of clone_m but not m
+            z = pm.Deterministic("z", clone_x + 1)
+
+    """
+    return model_from_fgraph(fgraph_from_model(model)[0])
diff --git a/pymc_experimental/utils/pytensorf.py b/pymc_experimental/utils/pytensorf.py
new file mode 100644
index 000000000..76358c273
--- /dev/null
+++ b/pymc_experimental/utils/pytensorf.py
@@ -0,0 +1,33 @@
+import pytensor
+from pytensor.graph import Constant, Type
+
+
+class StringType(Type[str]):
+    def clone(self, **kwargs):
+        return type(self)()
+
+    def filter(self, x, strict=False, allow_downcast=None):
+        if isinstance(x, str):
+            return x
+        else:
+            raise TypeError("Expected a string!")
+
+    def __str__(self):
+        return "string"
+
+    @staticmethod
+    def may_share_memory(a, b):
+        return isinstance(a, str) and a is b
+
+
+stringtype = StringType()
+
+
+class StringConstant(Constant):
+    pass
+
+
+@pytensor._as_symbolic.register(str)
+def as_symbolic_string(x, **kwargs):
+    return StringConstant(stringtype, x)
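
For anyone trying the new utilities end to end, the sketch below mirrors the round trip covered by `test_basic`. It is not part of the patch; the model and coordinate names are illustrative, and it assumes this branch and a recent pymc (v5) are installed.

import numpy as np
import pymc as pm

from pymc_experimental.utils.model_fgraph import fgraph_from_model, model_from_fgraph

with pm.Model(coords={"city": ["A", "B", "C"]}) as m:
    mu = pm.Normal("mu")
    y = pm.Normal("y", mu, 1.0, observed=[0.0, 1.0, 2.0], dims=("city",))

# Convert the model to a FunctionGraph; `memo` maps each original
# variable to its dummy-wrapped clone inside the fgraph.
fg, memo = fgraph_from_model(m)

# (Graph rewrites could be applied to `fg` at this point.)

# Rebuild an independent but equivalent model.
m2 = model_from_fgraph(fg)
assert m2["mu"] is not m["mu"]

# Both models assign the same logp to the same point.
ip = m.initial_point()
np.testing.assert_allclose(m2.compile_logp()(ip), m.compile_logp()(ip))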