Copy model-related shared variables in model_fgraph #218

Merged: 2 commits, Jul 25, 2023

61 changes: 46 additions & 15 deletions pymc_experimental/tests/utils/test_model_fgraph.py
@@ -2,6 +2,7 @@
import pymc as pm
import pytensor.tensor as pt
import pytest
from pytensor import config, shared
from pytensor.graph import Constant, FunctionGraph, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.exceptions import NotScalarConstantError
@@ -13,6 +14,7 @@
ModelObservedRV,
ModelPotential,
ModelVar,
clone_model,
fgraph_from_model,
model_deterministic,
model_free_rv,
@@ -76,17 +78,22 @@ def test_basic():
)


def same_storage(shared_1, shared_2) -> bool:
"""Check if two shared variables have the same storage containers (i.e., they point to the same memory)."""
return shared_1.container.storage is shared_2.container.storage


@pytest.mark.parametrize("inline_views", (False, True))
def test_data(inline_views):
"""Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
"""Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly.

Everything should be preserved across new and old models, except for shared RNGs
All model-related shared variables should be copied to become independent across models.
"""
with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
y = pm.MutableData("y", [10.0, 11.0, 12.0], dims=("test_dim",))
b0 = pm.ConstantData("b0", np.zeros(3))
b1 = pm.Normal("b1")
b0 = pm.ConstantData("b0", np.zeros((1,)))
b1 = pm.DiracDelta("b1", 1.0)
mu = pm.Deterministic("mu", b0 + b1 * x, dims=("test_dim",))
obs = pm.Normal("obs", mu, sigma=1e-5, observed=y, dims=("test_dim",))

@@ -109,22 +116,46 @@ def test_data(inline_views):

m_new = model_from_fgraph(m_fgraph)

# ConstantData is preserved
assert np.all(m_new["b0"].data == m_old["b0"].data)

# Shared non-rng shared variables are preserved
assert m_new["x"].container is x.container
assert m_new["y"].container is y.container
# The rv-data mapping is preserved
assert m_new.rvs_to_values[m_new["obs"]] is m_new["y"]

# Shared rng shared variables are not preserved
m_new["b1"].owner.inputs[0].container is not m_old["b1"].owner.inputs[0].container
# ConstantData is still accessible as a model variable
np.testing.assert_array_equal(m_new["b0"], m_old["b0"])

with m_old:
pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
# Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
assert not same_storage(m_new["x"], x)
assert not same_storage(m_new["y"], y)
assert not same_storage(m_new["b1"].owner.inputs[0], b1.owner.inputs[0])
assert not same_storage(m_new.dim_lengths["test_dim"], m_old.dim_lengths["test_dim"])

# Updating model shared variables in the new model doesn't affect the old one
with m_new:
pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})
assert m_new.dim_lengths["test_dim"].eval() == 2
np.testing.assert_array_almost_equal(pm.draw(m_new["x"], random_seed=63), [100.0, 200.0])
assert m_old.dim_lengths["test_dim"].eval() == 3
np.testing.assert_allclose(pm.draw(m_new["mu"]), [100.0, 200.0])
np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 1.0, 2.0], atol=1e-6)


@config.change_flags(floatX="float64") # Avoid downcasting Ops in the graph
def test_shared_variable():
"""Test that user defined shared variables (other than RNGs) aren't copied."""
x = shared(np.array([1, 2, 3.0]), name="x")
y = shared(np.array([1, 2, 3.0]), name="y")

with pm.Model() as m_old:
test = pm.Normal("test", mu=x, observed=y)

assert test.owner.inputs[3] is x
assert m_old.rvs_to_values[test] is y

m_new = clone_model(m_old)
test_new = m_new["test"]
# Shared Variables are cloned but still point to the same memory
assert test_new.owner.inputs[3] is not x
assert m_new.rvs_to_values[test_new] is not y
assert same_storage(test_new.owner.inputs[3], x)
assert same_storage(m_new.rvs_to_values[test_new], y)


@pytest.mark.parametrize("inline_views", (False, True))
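
For context, here is a minimal sketch (not part of the PR; the model and data values are illustrative) of the behavior the updated test exercises, using the clone_model helper imported above:

import numpy as np
import pymc as pm

from pymc_experimental.utils.model_fgraph import clone_model

with pm.Model(coords_mutable={"test_dim": range(3)}) as m_old:
    x = pm.MutableData("x", [0.0, 1.0, 2.0], dims=("test_dim",))
    mu = pm.Deterministic("mu", 2 * x, dims=("test_dim",))

# clone_model round-trips the model through fgraph_from_model/model_from_fgraph,
# so model-related shared variables (data, dim lengths, RNGs) are copied
m_new = clone_model(m_old)

# Resizing the data in the clone...
with m_new:
    pm.set_data({"x": [100.0, 200.0]}, coords={"test_dim": range(2)})

# ...leaves the original model untouched
assert m_new.dim_lengths["test_dim"].eval() == 2
assert m_old.dim_lengths["test_dim"].eval() == 3
np.testing.assert_allclose(pm.draw(m_old["mu"]), [0.0, 2.0, 4.0])
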
34 changes: 27 additions & 7 deletions pymc_experimental/utils/model_fgraph.py
@@ -1,14 +1,17 @@
from copy import copy
from typing import Dict, Optional, Sequence, Tuple

import pytensor
from pymc.logprob.transforms import RVTransform
from pymc.model import Model
from pymc.pytensorf import find_rng_nodes
from pytensor import Variable
from pytensor import Variable, shared
from pytensor.compile import SharedVariable
from pytensor.graph import Apply, FunctionGraph, Op, node_rewriter
from pytensor.graph.rewriting.basic import out2in
from pytensor.scalar import Identity
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.sharedvar import ScalarSharedVariable

from pymc_experimental.utils.pytensorf import StringType

@@ -182,11 +185,28 @@ def fgraph_from_model(

memo = {}

# Replace RNG nodes so that seeding does not interfere with old model
for rng in find_rng_nodes(model_vars):
new_rng = rng.clone()
new_rng.set_value(rng.get_value(borrow=False))
memo[rng] = new_rng
# Replace the following shared variables in the model:
# 1. RNGs
# 2. MutableData (could increase memory usage significantly)
Inline review comment (Member): We'll have to wait and get some feedback. If people complain about this, we can review the memory sharing.

# 3. Mutable coords dim lengths
shared_vars_to_copy = find_rng_nodes(model_vars)
shared_vars_to_copy += [v for v in model.dim_lengths.values() if isinstance(v, SharedVariable)]
shared_vars_to_copy += [v for v in model.named_vars.values() if isinstance(v, SharedVariable)]
for var in shared_vars_to_copy:
# FIXME: ScalarSharedVariables are converted to 0d numpy arrays internally,
# so calling shared(shared(5).get_value()) returns a different type: TensorSharedVariables!
# Furthermore, PyMC silently ignores mutable dim length updates when the dim lengths are TensorSharedVariables...
# https://github.com/pymc-devs/pytensor/issues/396
if isinstance(var, ScalarSharedVariable):
new_var = shared(var.get_value(borrow=False).item())
else:
new_var = shared(var.get_value(borrow=False))

assert new_var.type == var.type
new_var.name = var.name
new_var.tag = copy(var.tag)
# We can replace input variables by placing them in the memo
memo[var] = new_var

fgraph = FunctionGraph(
outputs=model_vars,
@@ -197,7 +217,7 @@
)
# Copy model meta-info to fgraph
fgraph._coords = model._coords.copy()
fgraph._dim_lengths = model._dim_lengths.copy()
fgraph._dim_lengths = {k: memo.get(v, v) for k, v in model._dim_lengths.items()}

rvs_to_transforms = model.rvs_to_transforms
named_vars_to_dims = model.named_vars_to_dims
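
As an aside, a small sketch of the PyTensor quirk behind the FIXME above (my own illustration based on that comment and pymc-devs/pytensor#396, not code from this PR):

from pytensor import shared
from pytensor.tensor.sharedvar import ScalarSharedVariable, TensorSharedVariable

s = shared(5)  # a Python scalar creates a ScalarSharedVariable
assert isinstance(s, ScalarSharedVariable)

# get_value() hands back a 0d numpy array, so re-wrapping it with shared()
# yields a plain TensorSharedVariable instead of a ScalarSharedVariable
roundtrip = shared(s.get_value(borrow=False))
assert isinstance(roundtrip, TensorSharedVariable)
assert not isinstance(roundtrip, ScalarSharedVariable)

# Extracting the Python scalar first preserves the ScalarSharedVariable class,
# which is why fgraph_from_model calls .item() for scalar shared variables
preserved = shared(s.get_value(borrow=False).item())
assert isinstance(preserved, ScalarSharedVariable)
assert preserved.type == s.type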