Skip to content

Commit 58b2f11

Browse files
Disable use of Arviz in pymc3.tests.test_data_container
1 parent 02e170a commit 58b2f11

File tree

1 file changed

+53
-8
lines changed

1 file changed

+53
-8
lines changed

pymc3/tests/test_data_container.py

Lines changed: 53 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,13 @@ def test_sample_posterior_predictive_after_set_data(self):
7070
y = pm.Data("y", [1.0, 2.0, 3.0])
7171
beta = pm.Normal("beta", 0, 10.0)
7272
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
73-
trace = pm.sample(1000, tune=1000, chains=1)
73+
trace = pm.sample(
74+
1000,
75+
tune=1000,
76+
chains=1,
77+
return_inferencedata=False,
78+
compute_convergence_checks=False,
79+
)
7480
# Predict on new data.
7581
with model:
7682
x_test = [5, 6, 9]
@@ -86,13 +92,27 @@ def test_sample_after_set_data(self):
8692
y = pm.Data("y", [1.0, 2.0, 3.0])
8793
beta = pm.Normal("beta", 0, 10.0)
8894
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
89-
pm.sample(1000, init=None, tune=1000, chains=1)
95+
pm.sample(
96+
1000,
97+
init=None,
98+
tune=1000,
99+
chains=1,
100+
return_inferencedata=False,
101+
compute_convergence_checks=False,
102+
)
90103
# Predict on new data.
91104
new_x = [5.0, 6.0, 9.0]
92105
new_y = [5.0, 6.0, 9.0]
93106
with model:
94107
pm.set_data(new_data={"x": new_x, "y": new_y})
95-
new_trace = pm.sample(1000, init=None, tune=1000, chains=1)
108+
new_trace = pm.sample(
109+
1000,
110+
init=None,
111+
tune=1000,
112+
chains=1,
113+
return_inferencedata=False,
114+
compute_convergence_checks=False,
115+
)
96116
pp_trace = pm.sample_posterior_predictive(new_trace, 1000)
97117

98118
assert pp_trace["obs"].shape == (1000, 3)
@@ -110,7 +130,14 @@ def test_shared_data_as_index(self):
110130
pm.Normal("obs", alpha[index], np.sqrt(1e-2), observed=y)
111131

112132
prior_trace = pm.sample_prior_predictive(1000, var_names=["alpha"])
113-
trace = pm.sample(1000, init=None, tune=1000, chains=1)
133+
trace = pm.sample(
134+
1000,
135+
init=None,
136+
tune=1000,
137+
chains=1,
138+
return_inferencedata=False,
139+
compute_convergence_checks=False,
140+
)
114141

115142
# Predict on new data
116143
new_index = np.array([0, 1, 2])
@@ -132,14 +159,18 @@ def test_shared_data_as_rv_input(self):
132159
with pm.Model() as m:
133160
x = pm.Data("x", [1.0, 2.0, 3.0])
134161
_ = pm.Normal("y", mu=x, size=3)
135-
trace = pm.sample(chains=1)
162+
trace = pm.sample(
163+
chains=1, return_inferencedata=False, compute_convergence_checks=False
164+
)
136165

137166
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), x.get_value(), atol=1e-1)
138167
np.testing.assert_allclose(np.array([1.0, 2.0, 3.0]), trace["y"].mean(0), atol=1e-1)
139168

140169
with m:
141170
pm.set_data({"x": np.array([2.0, 4.0, 6.0])})
142-
trace = pm.sample(chains=1)
171+
trace = pm.sample(
172+
chains=1, return_inferencedata=False, compute_convergence_checks=False
173+
)
143174

144175
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
145176
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)
@@ -175,7 +206,14 @@ def test_set_data_to_non_data_container_variables(self):
175206
y = np.array([1.0, 2.0, 3.0])
176207
beta = pm.Normal("beta", 0, 10.0)
177208
pm.Normal("obs", beta * x, np.sqrt(1e-2), observed=y)
178-
pm.sample(1000, init=None, tune=1000, chains=1)
209+
pm.sample(
210+
1000,
211+
init=None,
212+
tune=1000,
213+
chains=1,
214+
return_inferencedata=False,
215+
compute_convergence_checks=False,
216+
)
179217
with pytest.raises(TypeError) as error:
180218
pm.set_data({"beta": [1.1, 2.2, 3.3]}, model=model)
181219
error.match("defined as `pymc3.Data` inside the model")
@@ -188,7 +226,14 @@ def test_model_to_graphviz_for_model_with_data_container(self):
188226
beta = pm.Normal("beta", 0, 10.0)
189227
obs_sigma = floatX(np.sqrt(1e-2))
190228
pm.Normal("obs", beta * x, obs_sigma, observed=y)
191-
pm.sample(1000, init=None, tune=1000, chains=1)
229+
pm.sample(
230+
1000,
231+
init=None,
232+
tune=1000,
233+
chains=1,
234+
return_inferencedata=False,
235+
compute_convergence_checks=False,
236+
)
192237

193238
for formatting in {"latex", "latex_with_params"}:
194239
with pytest.raises(ValueError, match="Unsupported formatting"):

0 commit comments

Comments
 (0)