Skip to content

Allow copy and deepcopy of PYMC models #7492

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Oct 3, 2024
Merged
56 changes: 56 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,62 @@
def __contains__(self, key):
return key in self.named_vars or self.name_for(key) in self.named_vars

def __copy__(self):
    """Clone the model.

    To access variables in the cloned model use ``cloned_model["var_name"]``.

    Overrides the Python ``copy.copy`` protocol using ``clone_model`` from
    ``pymc.model.fgraph``. Constants are not cloned; if Gaussian process
    variables are detected, a ``UserWarning`` is emitted (see pymc issue
    #6883), since further use of the old GP objects may reintroduce
    variables from the original model.

    Returns
    -------
    Model
        An independent clone of this model.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        import copy

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))

        clone_m = copy.copy(m)

        # Access cloned variables by name
        clone_x = clone_m["x"]

        # z will be part of clone_m but not m
        z = pm.Deterministic("z", clone_x + 1)
    """
    # Imported locally to avoid a circular import between model.core and
    # model.fgraph at module load time.
    from pymc.model.fgraph import clone_model

    return clone_model(self)

def __deepcopy__(self, _):
    """Clone the model.

    To access variables in the cloned model use ``cloned_model["var_name"]``.

    Overrides the Python ``copy.deepcopy`` protocol using ``clone_model``
    from ``pymc.model.fgraph``. The memo dict is ignored because the clone
    is built from a fresh function graph rather than by recursive copying.
    Constants are not cloned; if Gaussian process variables are detected,
    a ``UserWarning`` is emitted (see pymc issue #6883).

    Returns
    -------
    Model
        An independent clone of this model.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        import copy

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))

        clone_m = copy.deepcopy(m)

        # Access cloned variables by name
        clone_x = clone_m["x"]

        # z will be part of clone_m but not m
        z = pm.Deterministic("z", clone_x + 1)
    """
    # Imported locally to avoid a circular import between model.core and
    # model.fgraph at module load time.
    from pymc.model.fgraph import clone_model

    return clone_model(self)

def replace_rvs_by_values(
self,
graphs: Sequence[TensorVariable],
Expand Down
10 changes: 10 additions & 0 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from copy import copy, deepcopy

import pytensor
Expand Down Expand Up @@ -158,6 +160,14 @@ def fgraph_from_model(
"Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
)

# Heuristic detection of GP-created variables: names containing
# "_rotated_" or "_hsgp_coeffs_" are generated internally by GP
# implementations. Cloning such a model is allowed, but reusing the old
# GP objects afterwards may reintroduce variables from the old model
# (see issue #6883), so warn the user up front.
if any(
    ("_rotated_" in var_name or "_hsgp_coeffs_" in var_name) for var_name in model.named_vars
):
    warnings.warn(
        "Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
        UserWarning,
    )

# Collect PyTensor variables
rvs_to_values = model.rvs_to_values
rvs = list(rvs_to_values.keys())
Expand Down
60 changes: 60 additions & 0 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pickle
import threading
import traceback
Expand Down Expand Up @@ -1761,3 +1762,62 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
figsize=None,
dpi=300,
)


class TestModelCopy:
    """Tests for ``copy.copy`` / ``copy.deepcopy`` support on pm.Model."""

    # Reviewer feedback applied: parametrize over both copy methods instead
    # of duplicating each test body for copy.copy and copy.deepcopy.
    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_copy_model(self, copy_method) -> None:
        with pm.Model() as simple_model:
            error = pm.HalfNormal("error", 0.5)
            alpha = pm.Normal("alpha", 0, 1)
            pm.Normal("y", alpha, error)

        copied_model = copy_method(simple_model)

        # A single draw is enough: with the same seed the clone must
        # reproduce the original draw exactly (same graph, same RNG path).
        with simple_model:
            simple_model_prior_predictive = pm.sample_prior_predictive(
                draws=1, random_seed=42
            )
        with copied_model:
            copied_model_prior_predictive = pm.sample_prior_predictive(
                draws=1, random_seed=42
            )

        # Exact equality, not np.isclose: the draws are bitwise identical.
        original_draw = simple_model_prior_predictive.prior["y"].values
        copied_draw = copied_model_prior_predictive.prior["y"].values
        assert (original_draw == copied_draw).all()

        # Adding a variable to the copy must not leak into the original
        # (the docstring example's contract).
        with copied_model:
            pm.Deterministic("z", copied_model["alpha"] + 1)
        assert "z" in copied_model.named_vars
        assert "z" not in simple_model.named_vars

    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_gaussian_process_copy_failure(self, copy_method) -> None:
        with pm.Model() as gaussian_process_model:
            ell = pm.Gamma("ell", alpha=2, beta=1)
            cov = 2 * pm.gp.cov.ExpQuad(1, ell)
            gp = pm.gp.Latent(cov_func=cov)
            f = gp.prior("f", X=np.arange(10)[:, None])
            pm.Normal("y", f * 2)

        # match= pins the warning to the GP-detection warning specifically,
        # rather than accepting any UserWarning.
        with pytest.warns(
            UserWarning, match="Detected variables likely created by GP objects"
        ):
            copy_method(gaussian_process_model)
Loading