File tree Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Expand file tree Collapse file tree 1 file changed +23
-0
lines changed Original file line number Diff line number Diff line change @@ -505,6 +505,29 @@ def test_normal_scalar(self):
505
505
)
506
506
assert ppc ["a" ].shape == (nchains * ndraws , 5 )
507
507
508
+ def test_normal_scalar_idata (self ):
509
+ nchains = 2
510
+ ndraws = 500
511
+ with pm .Model () as model :
512
+ mu = pm .Normal ("mu" , 0.0 , 1.0 )
513
+ a = pm .Normal ("a" , mu = mu , sigma = 1 , observed = 0.0 )
514
+ trace = pm .sample (
515
+ draws = ndraws ,
516
+ chains = nchains ,
517
+ return_inferencedata = False ,
518
+ discard_tuned_samples = False ,
519
+ )
520
+
521
+ assert not isinstance (trace , InferenceData )
522
+
523
+ with model :
524
+ # test keep_size parameter and idata input
525
+ idata = pm .to_inference_data (trace )
526
+ assert isinstance (idata , InferenceData )
527
+
528
+ ppc = pm .sample_posterior_predictive (idata , keep_size = True , return_inferencedata = False )
529
+ assert ppc ["a" ].shape == (nchains , ndraws )
530
+
508
531
def test_normal_vector (self , caplog ):
509
532
with pm .Model () as model :
510
533
mu = pm .Normal ("mu" , 0.0 , 1.0 )
You can’t perform that action at this time.
0 commit comments