7
7
import pytensor .tensor as pt
8
8
import pytest
9
9
from arviz import InferenceData , dict_to_dataset
10
- from pymc import ImputationWarning , inputvars
11
10
from pymc .distributions import transforms
12
11
from pymc .logprob .abstract import _logprob
13
12
from pymc .model .fgraph import fgraph_from_model
13
+ from pymc .pytensorf import inputvars
14
14
from pymc .util import UNSET
15
15
from scipy .special import log_softmax , logsumexp
16
16
from scipy .stats import halfnorm , norm
@@ -45,9 +45,7 @@ def disaster_model():
45
45
early_rate = pm .Exponential ("early_rate" , 1.0 , initval = 3 )
46
46
late_rate = pm .Exponential ("late_rate" , 1.0 , initval = 1 )
47
47
rate = pm .math .switch (switchpoint >= years , early_rate , late_rate )
48
- with pytest .warns (ImputationWarning ), pytest .warns (
49
- RuntimeWarning , match = "invalid value encountered in cast"
50
- ):
48
+ with pytest .warns (Warning ):
51
49
disasters = pm .Poisson ("disasters" , rate , observed = disaster_data )
52
50
53
51
return disaster_model , years
@@ -294,7 +292,7 @@ def test_recover_marginals_basic():
294
292
295
293
with m :
296
294
prior = pm .sample_prior_predictive (
297
- samples = 20 ,
295
+ draws = 20 ,
298
296
random_seed = rng ,
299
297
return_inferencedata = False ,
300
298
)
@@ -337,7 +335,7 @@ def test_recover_marginals_coords():
337
335
338
336
with m :
339
337
prior = pm .sample_prior_predictive (
340
- samples = 20 ,
338
+ draws = 20 ,
341
339
random_seed = rng ,
342
340
return_inferencedata = False ,
343
341
)
@@ -364,7 +362,7 @@ def test_recover_batched_marginal():
364
362
365
363
with m :
366
364
prior = pm .sample_prior_predictive (
367
- samples = 20 ,
365
+ draws = 20 ,
368
366
random_seed = rng ,
369
367
return_inferencedata = False ,
370
368
)
@@ -394,7 +392,7 @@ def test_nested_recover_marginals():
394
392
395
393
with m :
396
394
prior = pm .sample_prior_predictive (
397
- samples = 20 ,
395
+ draws = 20 ,
398
396
random_seed = rng ,
399
397
return_inferencedata = False ,
400
398
)
@@ -565,7 +563,7 @@ def test_marginalized_transforms(transform, expected_warning):
565
563
w = w ,
566
564
comp_dists = pm .HalfNormal .dist ([1 , 2 , 3 ]),
567
565
initval = initval ,
568
- transform = transform ,
566
+ default_transform = transform ,
569
567
)
570
568
y = pm .Normal ("y" , 0 , sigma , observed = data )
571
569
@@ -583,7 +581,7 @@ def test_marginalized_transforms(transform, expected_warning):
583
581
),
584
582
),
585
583
initval = initval ,
586
- transform = transform ,
584
+ default_transform = transform ,
587
585
)
588
586
y = pm .Normal ("y" , 0 , sigma , observed = data )
589
587
0 commit comments