2
2
import pymc as pm
3
3
import pytensor .tensor as pt
4
4
import pytest
5
+ from pytensor import config , shared
5
6
from pytensor .graph import Constant , FunctionGraph , node_rewriter
6
7
from pytensor .graph .rewriting .basic import in2out
7
8
from pytensor .tensor .exceptions import NotScalarConstantError
13
14
ModelObservedRV ,
14
15
ModelPotential ,
15
16
ModelVar ,
17
+ clone_model ,
16
18
fgraph_from_model ,
17
19
model_deterministic ,
18
20
model_free_rv ,
@@ -76,17 +78,22 @@ def test_basic():
76
78
)
77
79
78
80
81
+ def same_storage (shared_1 , shared_2 ) -> bool :
82
+ """Check if two shared variables have the same storage containers (i.e., they point to the same memory)."""
83
+ return shared_1 .container .storage is shared_2 .container .storage
84
+
85
+
79
86
@pytest .mark .parametrize ("inline_views" , (False , True ))
80
87
def test_data (inline_views ):
81
- """Test shared RNGs, MutableData, ConstantData and Dim lengths are handled correctly.
88
+ """Test shared RNGs, MutableData, ConstantData and dim lengths are handled correctly.
82
89
83
- Everything should be preserved across new and old models, except for shared RNGs
90
+ All model-related shared variables should be copied to become independent across models.
84
91
"""
85
92
with pm .Model (coords_mutable = {"test_dim" : range (3 )}) as m_old :
86
93
x = pm .MutableData ("x" , [0.0 , 1.0 , 2.0 ], dims = ("test_dim" ,))
87
94
y = pm .MutableData ("y" , [10.0 , 11.0 , 12.0 ], dims = ("test_dim" ,))
88
- b0 = pm .ConstantData ("b0" , np .zeros (3 ))
89
- b1 = pm .Normal ("b1" )
95
+ b0 = pm .ConstantData ("b0" , np .zeros (( 1 ,) ))
96
+ b1 = pm .DiracDelta ("b1" , 1.0 )
90
97
mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
91
98
obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
92
99
@@ -109,22 +116,46 @@ def test_data(inline_views):
109
116
110
117
m_new = model_from_fgraph (m_fgraph )
111
118
112
- # ConstantData is preserved
113
- assert np .all (m_new ["b0" ].data == m_old ["b0" ].data )
114
-
115
- # Shared non-rng shared variables are preserved
116
- assert m_new ["x" ].container is x .container
117
- assert m_new ["y" ].container is y .container
119
+ # The rv-data mapping is preserved
118
120
assert m_new .rvs_to_values [m_new ["obs" ]] is m_new ["y" ]
119
121
120
- # Shared rng shared variables are not preserved
121
- assert m_new ["b1" ]. owner . inputs [ 0 ]. container is not m_old ["b1" ]. owner . inputs [ 0 ]. container
122
+ # ConstantData is still accessible as a model variable
123
+ np . testing . assert_array_equal ( m_new ["b0" ], m_old ["b0" ])
122
124
123
- with m_old :
124
- pm .set_data ({"x" : [100.0 , 200.0 ]}, coords = {"test_dim" : range (2 )})
125
+ # Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
126
+ assert not same_storage (m_new ["x" ], x )
127
+ assert not same_storage (m_new ["y" ], y )
128
+ assert not same_storage (m_new ["b1" ].owner .inputs [0 ], b1 .owner .inputs [0 ])
129
+ assert not same_storage (m_new .dim_lengths ["test_dim" ], m_old .dim_lengths ["test_dim" ])
125
130
131
+ # Updating model shared variables in new model, doesn't affect old one
132
+ with m_new :
133
+ pm .set_data ({"x" : [100.0 , 200.0 ]}, coords = {"test_dim" : range (2 )})
126
134
assert m_new .dim_lengths ["test_dim" ].eval () == 2
127
- np .testing .assert_array_almost_equal (pm .draw (m_new ["x" ], random_seed = 63 ), [100.0 , 200.0 ])
135
+ assert m_old .dim_lengths ["test_dim" ].eval () == 3
136
+ np .testing .assert_allclose (pm .draw (m_new ["mu" ]), [100.0 , 200.0 ])
137
+ np .testing .assert_allclose (pm .draw (m_old ["mu" ]), [0.0 , 1.0 , 2.0 ], atol = 1e-6 )
138
+
139
+
140
+ @config .change_flags (floatX = "float64" ) # Avoid downcasting Ops in the graph
141
+ def test_shared_variable ():
142
+ """Test that user defined shared variables (other than RNGs) aren't copied."""
143
+ x = shared (np .array ([1 , 2 , 3.0 ]), name = "x" )
144
+ y = shared (np .array ([1 , 2 , 3.0 ]), name = "y" )
145
+
146
+ with pm .Model () as m_old :
147
+ test = pm .Normal ("test" , mu = x , observed = y )
148
+
149
+ assert test .owner .inputs [3 ] is x
150
+ assert m_old .rvs_to_values [test ] is y
151
+
152
+ m_new = clone_model (m_old )
153
+ test_new = m_new ["test" ]
154
+ # Shared Variables are cloned but still point to the same memory
155
+ assert test_new .owner .inputs [3 ] is not x
156
+ assert m_new .rvs_to_values [test_new ] is not y
157
+ assert same_storage (test_new .owner .inputs [3 ], x )
158
+ assert same_storage (m_new .rvs_to_values [test_new ], y )
128
159
129
160
130
161
@pytest .mark .parametrize ("inline_views" , (False , True ))
0 commit comments