-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow copy and deepcopy of PYMC models #7492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
b34742d
fe4e0c5
bcb4309
33c5766
88fde25
90419cb
07106ec
fb00f85
d057a9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1571,6 +1571,62 @@ | |||||||||||
def __contains__(self, key): | ||||||||||||
return key in self.named_vars or self.name_for(key) in self.named_vars | ||||||||||||
|
||||||||||||
def __copy__(self): | ||||||||||||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
""" | ||||||||||||
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | ||||||||||||
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | ||||||||||||
|
||||||||||||
Examples | ||||||||||||
-------- | ||||||||||||
.. code-block:: python | ||||||||||||
|
||||||||||||
import pymc as pm | ||||||||||||
import copy | ||||||||||||
|
||||||||||||
with pm.Model() as m: | ||||||||||||
p = pm.Beta("p", 1, 1) | ||||||||||||
x = pm.Bernoulli("x", p=p, shape=(3,)) | ||||||||||||
|
||||||||||||
clone_m = copy.copy(m) | ||||||||||||
|
||||||||||||
# Access cloned variables by name | ||||||||||||
clone_x = clone_m["x"] | ||||||||||||
|
||||||||||||
# z will be part of clone_m but not m | ||||||||||||
z = pm.Deterministic("z", clone_x + 1) | ||||||||||||
""" | ||||||||||||
from pymc.model.fgraph import clone_model | ||||||||||||
|
||||||||||||
return clone_model(self) | ||||||||||||
|
||||||||||||
def __deepcopy__(self, _): | ||||||||||||
""" | ||||||||||||
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | ||||||||||||
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | ||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
||||||||||||
Examples | ||||||||||||
-------- | ||||||||||||
.. code-block:: python | ||||||||||||
|
||||||||||||
import pymc as pm | ||||||||||||
import copy | ||||||||||||
|
||||||||||||
with pm.Model() as m: | ||||||||||||
p = pm.Beta("p", 1, 1) | ||||||||||||
x = pm.Bernoulli("x", p=p, shape=(3,)) | ||||||||||||
|
||||||||||||
clone_m = copy.deepcopy(m) | ||||||||||||
|
||||||||||||
# Access cloned variables by name | ||||||||||||
clone_x = clone_m["x"] | ||||||||||||
|
||||||||||||
# z will be part of clone_m but not m | ||||||||||||
z = pm.Deterministic("z", clone_x + 1) | ||||||||||||
""" | ||||||||||||
from pymc.model.fgraph import clone_model | ||||||||||||
|
||||||||||||
return clone_model(self) | ||||||||||||
|
||||||||||||
def replace_rvs_by_values( | ||||||||||||
self, | ||||||||||||
graphs: Sequence[TensorVariable], | ||||||||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,6 +11,7 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
import copy | ||||||
import pickle | ||||||
import threading | ||||||
import traceback | ||||||
|
@@ -1761,3 +1762,62 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: | |||||
figsize=None, | ||||||
dpi=300, | ||||||
) | ||||||
|
||||||
|
||||||
class TestModelCopy: | ||||||
@staticmethod | ||||||
def simple_model() -> pm.Model: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to use staticmethods as these aren't used in more than one single test? |
||||||
with pm.Model() as simple_model: | ||||||
error = pm.HalfNormal("error", 0.5) | ||||||
alpha = pm.Normal("alpha", 0, 1) | ||||||
pm.Normal("y", alpha, error) | ||||||
return simple_model | ||||||
|
||||||
@staticmethod | ||||||
def gp_model() -> pm.Model: | ||||||
with pm.Model() as gp_model: | ||||||
ell = pm.Gamma("ell", alpha=2, beta=1) | ||||||
cov = 2 * pm.gp.cov.ExpQuad(1, ell) | ||||||
gp = pm.gp.Latent(cov_func=cov) | ||||||
f = gp.prior("f", X=np.arange(10)[:, None]) | ||||||
pm.Normal("y", f * 2) | ||||||
return gp_model | ||||||
|
||||||
def test_copy_model(self) -> None: | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These tests have a lot of duplicated code that we can avoid if you use |
||||||
simple_model = self.simple_model() | ||||||
copy_simple_model = copy.copy(simple_model) | ||||||
deepcopy_simple_model = copy.deepcopy(simple_model) | ||||||
|
||||||
with simple_model: | ||||||
simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Taking a single draw should be enough. I would also test that adding a deterministic to the copy model does not introduce one in the original model (basically the example you had in the docstrings) |
||||||
|
||||||
with copy_simple_model: | ||||||
copy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) | ||||||
|
||||||
with deepcopy_simple_model: | ||||||
deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) | ||||||
|
||||||
simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean( | ||||||
("chain", "draw") | ||||||
) | ||||||
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][ | ||||||
"y" | ||||||
].mean(("chain", "draw")) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to take them mean, now that it's a single value. Just retrieve it with |
||||||
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive[ | ||||||
"prior" | ||||||
]["y"].mean(("chain", "draw")) | ||||||
|
||||||
assert np.isclose( | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can check exact equality, since the draws are exactly the same |
||||||
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean | ||||||
) | ||||||
assert np.isclose( | ||||||
simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean | ||||||
) | ||||||
|
||||||
def test_guassian_process_copy_failure(self) -> None: | ||||||
gaussian_process_model = self.gp_model() | ||||||
with pytest.warns(UserWarning): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a match kwarg to check the UserWarning is actually the one we care about
Suggested change
|
||||||
copy.copy(gaussian_process_model) | ||||||
|
||||||
with pytest.warns(UserWarning): | ||||||
copy.deepcopy(gaussian_process_model) |
Uh oh!
There was an error while loading. Please reload this page.