@@ -161,3 +161,32 @@ def test_get_batched_jittered_initial_points():
161
161
162
162
assert ips [0 ].shape == (2 , 2 , 3 )
163
163
assert np .all (ips [0 ][0 ] != ips [0 ][1 ])
164
+
165
+
166
+ @pytest .mark .parametrize ("random_seed" , (None , 123 ))
167
+ @pytest .mark .parametrize ("chains" , (1 , 2 ))
168
+ def test_seeding (chains , random_seed ):
169
+ sample_kwargs = dict (
170
+ tune = 100 ,
171
+ draws = 5 ,
172
+ chains = chains ,
173
+ random_seed = random_seed ,
174
+ )
175
+
176
+ with pm .Model (rng_seeder = 456 ) as m :
177
+ pm .Normal ("x" , mu = 0 , sigma = 1 )
178
+ result1 = sample_numpyro_nuts (** sample_kwargs )
179
+
180
+ with pm .Model (rng_seeder = 456 ) as m :
181
+ pm .Normal ("x" , mu = 0 , sigma = 1 )
182
+ result2 = sample_numpyro_nuts (** sample_kwargs )
183
+ result3 = sample_numpyro_nuts (** sample_kwargs )
184
+
185
+ assert np .all (result1 .posterior ["x" ] == result2 .posterior ["x" ])
186
+ expected_equal_result3 = random_seed is not None
187
+ assert np .all (result2 .posterior ["x" ] == result3 .posterior ["x" ]) == expected_equal_result3
188
+
189
+ if chains > 1 :
190
+ assert np .all (result1 .posterior ["x" ].sel (chain = 0 ) != result1 .posterior ["x" ].sel (chain = 1 ))
191
+ assert np .all (result2 .posterior ["x" ].sel (chain = 0 ) != result2 .posterior ["x" ].sel (chain = 1 ))
192
+ assert np .all (result3 .posterior ["x" ].sel (chain = 0 ) != result3 .posterior ["x" ].sel (chain = 1 ))
0 commit comments