Commit fbda63e

Commit message: black
1 parent: cd4ce13

1 file changed (+96, -75 lines)


pymc/tests/test_sampling.py

Lines changed: 96 additions & 75 deletions
@@ -204,7 +204,7 @@ def test_return_inferencedata(self):
             assert len(result._groups_warmup) > 0

             # inferencedata without tuning, with idata_kwargs
-            prior = pm.sample_prior_predictive()
+            prior = pm.sample_prior_predictive(return_inferencedata=False)
             result = pm.sample(
                 **kwargs,
                 return_inferencedata=True,
@@ -472,7 +472,6 @@ def test_normal_scalar(self):
             trace = pm.sample(
                 draws=ndraws,
                 chains=nchains,
-                return_inferencedata=False,
             )

         with model:
@@ -486,43 +485,23 @@ def test_normal_scalar(self):

             # test keep_size parameter
             ppc = pm.sample_posterior_predictive(trace, keep_size=True)
-            assert ppc["a"].shape == (nchains, ndraws)
+            assert ppc.posterior_predictive["a"].shape == (1, nchains, ndraws)

             # test default case
             ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
-            assert "a" in ppc
-            assert ppc["a"].shape == (nchains * ndraws,)
+            assert "a" in ppc.posterior_predictive.data_vars
+            assert ppc.posterior_predictive["a"].shape == (1, nchains * ndraws)
             # mu's standard deviation may have changed thanks to a's observed
-            _, pval = stats.kstest(ppc["a"] - trace["mu"], stats.norm(loc=0, scale=1).cdf)
+            _, pval = stats.kstest(
+                ppc.posterior_predictive["a"] - trace.posterior["mu"],
+                stats.norm(loc=0, scale=1).cdf,
+            )
             assert pval > 0.001

         # size argument not introduced to fast version [2019/08/20:rpg]
         with model:
             ppc = pm.sample_posterior_predictive(trace, size=5, var_names=["a"])
-            assert ppc["a"].shape == (nchains * ndraws, 5)
-
-    def test_normal_scalar_idata(self):
-        nchains = 2
-        ndraws = 500
-        with pm.Model() as model:
-            mu = pm.Normal("mu", 0.0, 1.0)
-            a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)
-            trace = pm.sample(
-                draws=ndraws,
-                chains=nchains,
-                return_inferencedata=False,
-                discard_tuned_samples=False,
-            )
-
-        assert not isinstance(trace, InferenceData)
-
-        with model:
-            # test keep_size parameter and idata input
-            idata = pm.to_inference_data(trace)
-            assert isinstance(idata, InferenceData)
-
-            ppc = pm.sample_posterior_predictive(idata, keep_size=True)
-            assert ppc["a"].shape == (nchains, ndraws)
+            assert ppc.posterior_predictive["a"].shape == (1, nchains * ndraws, 5)

     def test_normal_vector(self, caplog):
         with pm.Model() as model:
@@ -532,25 +511,35 @@ def test_normal_vector(self, caplog):

         with model:
             # test list input
-            ppc0 = pm.sample_posterior_predictive([model.initial_point], samples=10)
-            ppc = pm.sample_posterior_predictive(trace, samples=12, var_names=[])
+            ppc0 = pm.sample_posterior_predictive(
+                [model.initial_point], return_inferencedata=False, samples=10
+            )
+            ppc = pm.sample_posterior_predictive(
+                trace, return_inferencedata=False, samples=12, var_names=[]
+            )
             assert len(ppc) == 0

             # test keep_size parameter
-            ppc = pm.sample_posterior_predictive(trace, keep_size=True)
+            ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False, keep_size=True)
             assert ppc["a"].shape == (trace.nchains, len(trace), 2)
             with pytest.warns(UserWarning):
-                ppc = pm.sample_posterior_predictive(trace, samples=12, var_names=["a"])
+                ppc = pm.sample_posterior_predictive(
+                    trace, return_inferencedata=False, samples=12, var_names=["a"]
+                )
             assert "a" in ppc
             assert ppc["a"].shape == (12, 2)

             with pytest.warns(UserWarning):
-                ppc = pm.sample_posterior_predictive(trace, samples=12, var_names=["a"])
+                ppc = pm.sample_posterior_predictive(
+                    trace, return_inferencedata=False, samples=12, var_names=["a"]
+                )
             assert "a" in ppc
             assert ppc["a"].shape == (12, 2)

             # size unsupported by fast_ version argument. [2019/08/19:rpg]
-            ppc = pm.sample_posterior_predictive(trace, samples=10, var_names=["a"], size=4)
+            ppc = pm.sample_posterior_predictive(
+                trace, return_inferencedata=False, samples=10, var_names=["a"], size=4
+            )
             assert "a" in ppc
             assert ppc["a"].shape == (10, 4, 2)

@@ -567,7 +556,7 @@ def test_normal_vector_idata(self, caplog):
             idata = pm.to_inference_data(trace)
             assert isinstance(idata, InferenceData)

-            ppc = pm.sample_posterior_predictive(idata, keep_size=True)
+            ppc = pm.sample_posterior_predictive(idata, return_inferencedata=False, keep_size=True)
             assert ppc["a"].shape == (trace.nchains, len(trace), 2)

     def test_exceptions(self, caplog):
@@ -600,11 +589,15 @@ def test_vector_observed(self):
             # TODO: Assert something about the output
             # ppc = pm.sample_posterior_predictive(idata, samples=12, var_names=[])
             # assert len(ppc) == 0
-            ppc = pm.sample_posterior_predictive(idata, samples=12, var_names=["a"])
+            ppc = pm.sample_posterior_predictive(
+                idata, return_inferencedata=False, samples=12, var_names=["a"]
+            )
             assert "a" in ppc
             assert ppc["a"].shape == (12, 2)

-            ppc = pm.sample_posterior_predictive(idata, samples=10, var_names=["a"], size=4)
+            ppc = pm.sample_posterior_predictive(
+                idata, return_inferencedata=False, samples=10, var_names=["a"], size=4
+            )
             assert "a" in ppc
             assert ppc["a"].shape == (10, 4, 2)

@@ -616,9 +609,13 @@ def test_sum_normal(self):

         with model:
             # test list input
-            ppc0 = pm.sample_posterior_predictive([model.initial_point], samples=10)
+            ppc0 = pm.sample_posterior_predictive(
+                [model.initial_point], return_inferencedata=False, samples=10
+            )
             assert ppc0 == {}
-            ppc = pm.sample_posterior_predictive(idata, samples=1000, var_names=["b"])
+            ppc = pm.sample_posterior_predictive(
+                idata, return_inferencedata=False, samples=1000, var_names=["b"]
+            )
             assert len(ppc) == 1
             assert ppc["b"].shape == (1000,)
             scale = np.sqrt(1 + 0.2 ** 2)
@@ -637,7 +634,7 @@ def test_model_not_drawable_prior(self):
             with pytest.raises(NotImplementedError) as excinfo:
                 pm.sample_prior_predictive(50)
             assert "Cannot sample" in str(excinfo.value)
-            samples = pm.sample_posterior_predictive(idata, 40)
+            samples = pm.sample_posterior_predictive(idata, 40, return_inferencedata=False)
             assert samples["foo"].shape == (40, 200)

     def test_model_shared_variable(self):
@@ -660,7 +657,7 @@ def test_model_shared_variable(self):
         samples = 100
         with model:
             post_pred = pm.sample_posterior_predictive(
-                trace, samples=samples, var_names=["p", "obs"]
+                trace, return_inferencedata=False, samples=samples, var_names=["p", "obs"]
             )

         expected_p = np.array([logistic.eval({coeff: val}) for val in trace["x"][:samples]])
@@ -694,6 +691,7 @@ def test_deterministic_of_observed(self):
             rtol = 1e-5 if aesara.config.floatX == "float64" else 1e-4

             ppc = pm.sample_posterior_predictive(
+                return_inferencedata=False,
                 model=model,
                 trace=trace,
                 samples=len(trace) * nchains,
@@ -728,6 +726,7 @@ def test_deterministic_of_observed_modified_interface(self):
                 trace, varnames=[n for n in trace.varnames if n != "out"]
             ).to_dict("records")
             ppc = pm.sample_posterior_predictive(
+                return_inferencedata=False,
                 model=model,
                 trace=ppc_trace,
                 samples=len(ppc_trace),
@@ -745,7 +744,7 @@ def test_variable_type(self):
             trace = pm.sample(compute_convergence_checks=False, return_inferencedata=False)

         with model:
-            ppc = pm.sample_posterior_predictive(trace, samples=1)
+            ppc = pm.sample_posterior_predictive(trace, return_inferencedata=False, samples=1)
             assert ppc["a"].dtype.kind == "f"
             assert ppc["b"].dtype.kind == "i"

@@ -918,7 +917,7 @@ def test_ignores_observed(self):
             positive_mu = pm.Deterministic("positive_mu", np.abs(mu))
             z = -1 - positive_mu
             pm.Normal("x_obs", mu=z, sigma=1, observed=observed_data)
-            prior = pm.sample_prior_predictive()
+            prior = pm.sample_prior_predictive(return_inferencedata=False)

         assert "observed_data" not in prior
         assert (prior["mu"] < -90).all()
@@ -932,8 +931,12 @@ def test_respects_shape(self):
         with pm.Model():
             mu = pm.Gamma("mu", 3, 1, size=1)
             goals = pm.Poisson("goals", mu, size=shape)
-            trace1 = pm.sample_prior_predictive(10, var_names=["mu", "mu", "goals"])
-            trace2 = pm.sample_prior_predictive(10, var_names=["mu", "goals"])
+            trace1 = pm.sample_prior_predictive(
+                10, return_inferencedata=False, var_names=["mu", "mu", "goals"]
+            )
+            trace2 = pm.sample_prior_predictive(
+                10, return_inferencedata=False, var_names=["mu", "goals"]
+            )
         if shape == 2:  # want to test shape as an int
             shape = (2,)
         assert trace1["goals"].shape == (10,) + shape
@@ -944,7 +947,7 @@ def test_multivariate(self):
             m = pm.Multinomial("m", n=5, p=np.array([0.25, 0.25, 0.25, 0.25]))
             trace = pm.sample_prior_predictive(10)

-        assert trace["m"].shape == (10, 4)
+        assert trace.prior["m"].shape == (1, 10, 4)

     def test_multivariate2(self):
         # Added test for issue #3271
@@ -955,8 +958,12 @@ def test_multivariate2(self):
             burned_trace = pm.sample(
                 20, tune=10, cores=1, return_inferencedata=False, compute_convergence_checks=False
             )
-        sim_priors = pm.sample_prior_predictive(samples=20, model=dm_model)
-        sim_ppc = pm.sample_posterior_predictive(burned_trace, samples=20, model=dm_model)
+        sim_priors = pm.sample_prior_predictive(
+            return_inferencedata=False, samples=20, model=dm_model
+        )
+        sim_ppc = pm.sample_posterior_predictive(
+            burned_trace, return_inferencedata=False, samples=20, model=dm_model
+        )
         assert sim_priors["probs"].shape == (20, 6)
         assert sim_priors["obs"].shape == (20,) + mn_data.shape
         assert sim_ppc["obs"].shape == (20,) + mn_data.shape
@@ -987,9 +994,9 @@ def test_transformed(self):
             y = pm.Binomial("y", n=at_bats, p=thetas, observed=hits)
             gen = pm.sample_prior_predictive(draws)

-        assert gen["phi"].shape == (draws,)
-        assert gen["y"].shape == (draws, n)
-        assert "thetas" in gen
+        assert gen.prior["phi"].shape == (1, draws)
+        assert gen.prior_predictive["y"].shape == (1, draws, n)
+        assert "thetas" in gen.prior.data_vars

     def test_shared(self):
         n1 = 10
@@ -1002,16 +1009,16 @@ def test_shared(self):
             o = pm.Deterministic("o", obs)
             gen1 = pm.sample_prior_predictive(draws)

-        assert gen1["y"].shape == (draws, n1)
-        assert gen1["o"].shape == (draws, n1)
+        assert gen1.prior["y"].shape == (1, draws, n1)
+        assert gen1.prior["o"].shape == (1, draws, n1)

         n2 = 20
         obs.set_value(np.random.rand(n2) < 0.5)
         with m:
             gen2 = pm.sample_prior_predictive(draws)

-        assert gen2["y"].shape == (draws, n2)
-        assert gen2["o"].shape == (draws, n2)
+        assert gen2.prior["y"].shape == (1, draws, n2)
+        assert gen2.prior["o"].shape == (1, draws, n2)

     def test_density_dist(self):
         obs = np.random.normal(-1, 0.1, size=10)
@@ -1025,7 +1032,7 @@ def test_density_dist(self):
                 random=lambda mu, sd, rng=None, size=None: rng.normal(loc=mu, scale=sd, size=size),
                 observed=obs,
             )
-            prior = pm.sample_prior_predictive()
+            prior = pm.sample_prior_predictive(return_inferencedata=False)

         npt.assert_almost_equal(prior["a"].mean(), 0, decimal=1)

@@ -1035,17 +1042,17 @@ def test_shape_edgecase(self):
             sd = pm.Uniform("sd", lower=2, upper=3)
             x = pm.Normal("x", mu=mu, sigma=sd, size=5)
             prior = pm.sample_prior_predictive(10)
-        assert prior["mu"].shape == (10, 5)
+        assert prior.prior["mu"].shape == (1, 10, 5)

     def test_zeroinflatedpoisson(self):
         with pm.Model():
             theta = pm.Beta("theta", alpha=1, beta=1)
             psi = pm.HalfNormal("psi", sd=1)
             pm.ZeroInflatedPoisson("suppliers", psi=psi, theta=theta, size=20)
             gen_data = pm.sample_prior_predictive(samples=5000)
-        assert gen_data["theta"].shape == (5000,)
-        assert gen_data["psi"].shape == (5000,)
-        assert gen_data["suppliers"].shape == (5000, 20)
+        assert gen_data.prior["theta"].shape == (1, 5000)
+        assert gen_data.prior["psi"].shape == (1, 5000)
+        assert gen_data.prior["suppliers"].shape == (1, 5000, 20)

     def test_potentials_warning(self):
         warning_msg = "The effect of Potentials on other parameters is ignored during"
@@ -1075,10 +1082,10 @@ def ub_interval_forward(x, ub):
             )

         # Check values are correct
-        assert np.allclose(prior["ub_log__"], np.log(prior["ub"]))
+        assert np.allclose(prior.prior["ub_log__"].data, np.log(prior.prior["ub"].data))
         assert np.allclose(
-            prior["x_interval__"],
-            ub_interval_forward(prior["x"], prior["ub"]),
+            prior.prior["x_interval__"].data,
+            ub_interval_forward(prior.prior["x"].data, prior.prior["ub"].data),
         )

         # Check that it works when the original RVs are not mentioned in var_names
@@ -1090,9 +1097,16 @@ def ub_interval_forward(x, ub):
                 var_names=["ub_log__", "x_interval__"],
                 samples=10,
             )
-        assert "ub" not in prior_transformed_only and "x" not in prior_transformed_only
-        assert np.allclose(prior["ub_log__"], prior_transformed_only["ub_log__"])
-        assert np.allclose(prior["x_interval__"], prior_transformed_only["x_interval__"])
+        assert (
+            "ub" not in prior_transformed_only.prior.data_vars
+            and "x" not in prior_transformed_only.prior.data_vars
+        )
+        assert np.allclose(
+            prior.prior["ub_log__"].data, prior_transformed_only.prior["ub_log__"].data
+        )
+        assert np.allclose(
+            prior.prior["x_interval__"], prior_transformed_only.prior["x_interval__"].data
+        )

     def test_issue_4490(self):
         # Test that samples do not depend on var_name order or, more fundamentally,
@@ -1112,27 +1126,34 @@ def test_issue_4490(self):
             d = pm.Normal("d")
             prior2 = pm.sample_prior_predictive(samples=1, var_names=["b", "a", "d", "c"])

-        assert prior1["a"] == prior2["a"]
-        assert prior1["b"] == prior2["b"]
-        assert prior1["c"] == prior2["c"]
-        assert prior1["d"] == prior2["d"]
+        assert prior1.prior["a"] == prior2.prior["a"]
+        assert prior1.prior["b"] == prior2.prior["b"]
+        assert prior1.prior["c"] == prior2.prior["c"]
+        assert prior1.prior["d"] == prior2.prior["d"]


 class TestSamplePosteriorPredictive:
     def test_point_list_arg_bug_spp(self, point_list_arg_bug_fixture):
         pmodel, trace = point_list_arg_bug_fixture
         with pmodel:
-            pp = pm.sample_posterior_predictive([trace[15]], var_names=["d"])
+            pp = pm.sample_posterior_predictive(
+                [trace[15]], return_inferencedata=False, var_names=["d"]
+            )

     def test_sample_from_xarray_prior(self, point_list_arg_bug_fixture):
         pmodel, trace = point_list_arg_bug_fixture

         with pmodel:
-            prior = pm.sample_prior_predictive(samples=20)
+            prior = pm.sample_prior_predictive(
+                samples=20,
+                return_inferencedata=False,
+            )
             idat = pm.to_inference_data(trace, prior=prior)

         with pmodel:
-            pp = pm.sample_posterior_predictive(idat.prior, var_names=["d"])
+            pp = pm.sample_posterior_predictive(
+                idat.prior, return_inferencedata=False, var_names=["d"]
+            )

     def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
         pmodel, trace = point_list_arg_bug_fixture
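
For context, here is a minimal sketch (not part of the commit) of the API behaviour these updated tests exercise on the PyMC v4 development branch: pm.sample_prior_predictive and pm.sample_posterior_predictive now return an arviz.InferenceData by default, whose groups carry a leading chain dimension of 1, while return_inferencedata=False keeps the previous dict-of-arrays return value. The model is copied from test_normal_scalar above; the draw count of 50 is an arbitrary choice for illustration.

import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    a = pm.Normal("a", mu=mu, sigma=1, observed=0.0)

    # New default: an InferenceData whose prior group has shape (chain=1, draw=50)
    idata_prior = pm.sample_prior_predictive(samples=50)
    assert idata_prior.prior["mu"].shape == (1, 50)

    # Old behaviour, kept in the tests via return_inferencedata=False:
    # a plain dict keyed by variable name, without the chain dimension
    dict_prior = pm.sample_prior_predictive(samples=50, return_inferencedata=False)
    assert dict_prior["mu"].shape == (50,)

The same flag is threaded through pm.sample_posterior_predictive in the hunks above, which is why most call sites gain return_inferencedata=False and the remaining ones switch their assertions to the .prior / .posterior_predictive groups.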
