@@ -107,14 +107,16 @@ def test_data(inline_views):
107
107
with pm .Model (coords_mutable = {"test_dim" : range (3 )}) as m_old :
108
108
x = pm .MutableData ("x" , [0.0 , 1.0 , 2.0 ], dims = ("test_dim" ,))
109
109
y = pm .MutableData ("y" , [10.0 , 11.0 , 12.0 ], dims = ("test_dim" ,))
110
+ sigma = pm .MutableData ("sigma" , [1.0 ], shape = (1 ,))
110
111
b0 = pm .ConstantData ("b0" , np .zeros ((1 ,)))
111
112
b1 = pm .DiracDelta ("b1" , 1.0 )
112
113
mu = pm .Deterministic ("mu" , b0 + b1 * x , dims = ("test_dim" ,))
113
- obs = pm .Normal ("obs" , mu , sigma = 1e-5 , observed = y , dims = ("test_dim" ,))
114
+ obs = pm .Normal ("obs" , mu = mu , sigma = sigma , observed = y , dims = ("test_dim" ,))
114
115
115
116
m_fgraph , memo = fgraph_from_model (m_old , inlined_views = inline_views )
116
117
assert isinstance (memo [x ].owner .op , ModelNamed )
117
118
assert isinstance (memo [y ].owner .op , ModelNamed )
119
+ assert isinstance (memo [sigma ].owner .op , ModelNamed )
118
120
assert isinstance (memo [b0 ].owner .op , ModelNamed )
119
121
mu_inp = memo [mu ].owner .inputs [0 ]
120
122
obs = memo [obs ]
@@ -124,10 +126,13 @@ def test_data(inline_views):
124
126
assert mu_inp .owner .inputs [1 ].owner .inputs [1 ] is memo [x ].owner .inputs [0 ]
125
127
# ObservedRV(obs, y, *dims) not ObservedRV(obs, Named(y), *dims)
126
128
assert obs .owner .inputs [1 ] is memo [y ].owner .inputs [0 ]
129
+ # ObservedRV(Normal(..., sigma), ...) not ObservedRV(Normal(..., Named(sigma)), ...)
130
+ assert obs .owner .inputs [0 ].owner .inputs [4 ] is memo [sigma ].owner .inputs [0 ]
127
131
else :
128
132
assert mu_inp .owner .inputs [0 ] is memo [b0 ]
129
133
assert mu_inp .owner .inputs [1 ].owner .inputs [1 ] is memo [x ]
130
134
assert obs .owner .inputs [1 ] is memo [y ]
135
+ assert obs .owner .inputs [0 ].owner .inputs [4 ] is memo [sigma ]
131
136
132
137
m_new = model_from_fgraph (m_fgraph )
133
138
@@ -140,9 +145,17 @@ def test_data(inline_views):
140
145
# Shared model variables, dim lengths, and rngs are copied and no longer point to the same memory
141
146
assert not same_storage (m_new ["x" ], x )
142
147
assert not same_storage (m_new ["y" ], y )
148
+ assert not same_storage (m_new ["sigma" ], sigma )
143
149
assert not same_storage (m_new ["b1" ].owner .inputs [0 ], b1 .owner .inputs [0 ])
144
150
assert not same_storage (m_new .dim_lengths ["test_dim" ], m_old .dim_lengths ["test_dim" ])
145
151
152
+ # Check they have the same type
153
+ assert m_new ["x" ].type == x .type
154
+ assert m_new ["y" ].type == y .type
155
+ assert m_new ["sigma" ].type == sigma .type
156
+ assert m_new ["b1" ].owner .inputs [0 ].type == b1 .owner .inputs [0 ].type
157
+ assert m_new .dim_lengths ["test_dim" ].type == m_old .dim_lengths ["test_dim" ].type
158
+
146
159
# Updating model shared variables in new model, doesn't affect old one
147
160
with m_new :
148
161
pm .set_data ({"x" : [100.0 , 200.0 ]}, coords = {"test_dim" : range (2 )})
@@ -155,22 +168,31 @@ def test_data(inline_views):
155
168
@config .change_flags (floatX = "float64" ) # Avoid downcasting Ops in the graph
156
169
def test_shared_variable ():
157
170
"""Test that user defined shared variables (other than RNGs) aren't copied."""
158
- x = shared (np .array ([1 , 2 , 3.0 ]), name = "x" )
159
- y = shared (np .array ([1 , 2 , 3.0 ]), name = "y" )
171
+ mu = shared (np .array ([1 , 2 , 3.0 ]), shape = (None ,), name = "mu" )
172
+ sigma = shared (np .array ([1.0 ]), shape = (1 ,), name = "sigma" )
173
+ obs = shared (np .array ([1 , 2 , 3.0 ]), shape = (3 ,), name = "obs" )
160
174
161
175
with pm .Model () as m_old :
162
- test = pm .Normal ("test" , mu = x , observed = y )
176
+ test = pm .Normal ("test" , mu = mu , sigma = sigma , observed = obs )
163
177
164
- assert test .owner .inputs [3 ] is x
165
- assert m_old .rvs_to_values [test ] is y
178
+ assert test .owner .inputs [3 ] is mu
179
+ assert test .owner .inputs [4 ] is sigma
180
+ assert m_old .rvs_to_values [test ] is obs
166
181
167
182
m_new = clone_model (m_old )
168
183
test_new = m_new ["test" ]
169
184
# Shared Variables are cloned but still point to the same memory
170
- assert test_new .owner .inputs [3 ] is not x
171
- assert m_new .rvs_to_values [test_new ] is not y
172
- assert same_storage (test_new .owner .inputs [3 ], x )
173
- assert same_storage (m_new .rvs_to_values [test_new ], y )
185
+ mu_new , sigma_new = test_new .owner .inputs [3 :5 ]
186
+ obs_new = m_new .rvs_to_values [test_new ]
187
+ assert mu_new is not mu
188
+ assert sigma_new is not sigma
189
+ assert obs_new is not obs
190
+ assert mu_new .type == mu .type
191
+ assert sigma_new .type == sigma .type
192
+ assert obs_new .type == obs .type
193
+ assert same_storage (mu , mu_new )
194
+ assert same_storage (sigma , sigma_new )
195
+ assert same_storage (obs , obs_new )
174
196
175
197
176
198
@pytest .mark .parametrize ("inline_views" , (False , True ))
0 commit comments