Skip to content

Commit 38b8883

Browse files
committed
Restructure tests. Fix load().
1 parent 218761b commit 38b8883

File tree

2 files changed

+27
-85
lines changed

2 files changed

+27
-85
lines changed

pymc_experimental/model_builder.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The PyMC Developers
1+
# Copyright 2023 The PyMC Developers
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -189,13 +189,14 @@ def load(cls, fname):
189189
self = cls(
190190
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
191191
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
192-
idata.data,
192+
idata.fit_data.to_dataframe(),
193193
)
194194
self.idata = idata
195195
if self.id() != idata.attrs["id"]:
196196
raise ValueError(
197197
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'"
198198
)
199+
199200
return self
200201

201202
def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
@@ -307,7 +308,7 @@ def predict_posterior(
307308
>>> model = LinearModel(model_config, sampler_config)
308309
>>> idata = model.fit(data)
309310
>>> x_pred = []
310-
>>> prediction_data = pd.DataFrame({'input':x_pred})
311+
>>> prediction_data = pd.DataFrame({'input': x_pred})
311312
# samples
312313
>>> pred_mean = model.predict_posterior(prediction_data)
313314
"""
+23-82
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2022 The PyMC Developers
1+
# Copyright 2023 The PyMC Developers
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,6 +13,8 @@
1313
# limitations under the License.
1414

1515

16+
import tempfile
17+
1618
import numpy as np
1719
import pandas as pd
1820
import pymc as pm
@@ -77,93 +79,32 @@ def create_sample_input(cls):
7779

7880

7981
def test_fit():
80-
with pm.Model() as model:
81-
x = np.linspace(start=0, stop=1, num=100)
82-
y = 5 * x + 3
83-
x = pm.MutableData("x", x)
84-
y_data = pm.MutableData("y_data", y)
85-
86-
a_loc = 7
87-
a_scale = 3
88-
b_loc = 5
89-
b_scale = 3
90-
obs_error = 2
91-
92-
a = pm.Normal("a", a_loc, sigma=a_scale)
93-
b = pm.Normal("b", b_loc, sigma=b_scale)
94-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
95-
96-
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
97-
98-
idata = pm.sample(tune=100, draws=200, chains=1, cores=1, target_accept=0.5)
99-
idata.extend(pm.sample_prior_predictive())
100-
idata.extend(pm.sample_posterior_predictive(idata))
101-
10282
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
103-
model_2 = test_ModelBuilder(model_config, sampler_config, data)
104-
model_2.idata = model_2.fit()
105-
assert str(model_2.idata.groups) == str(idata.groups)
106-
83+
model = test_ModelBuilder(model_config, sampler_config, data)
84+
model.fit()
85+
assert model.idata is not None
86+
assert "posterior" in model.idata.groups()
10787

108-
def test_predict():
10988
x_pred = np.random.uniform(low=0, high=1, size=100)
11089
prediction_data = pd.DataFrame({"input": x_pred})
111-
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
112-
model_2 = test_ModelBuilder(model_config, sampler_config, data)
113-
model_2.idata = model_2.fit()
114-
model_2.predict(prediction_data)
115-
with pm.Model() as model:
116-
x = np.linspace(start=0, stop=1, num=100)
117-
y = 5 * x + 3
118-
x = pm.MutableData("x", x)
119-
y_data = pm.MutableData("y_data", y)
120-
a_loc = 7
121-
a_scale = 3
122-
b_loc = 5
123-
b_scale = 3
124-
obs_error = 2
125-
126-
a = pm.Normal("a", a_loc, sigma=a_scale)
127-
b = pm.Normal("b", b_loc, sigma=b_scale)
128-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
129-
130-
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
90+
pred = model.predict(prediction_data)
91+
assert "y_model" in pred.keys()
92+
post_pred = model.predict_posterior(prediction_data)
93+
assert "y_model" in post_pred.keys()
13194

132-
idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
133-
idata.extend(pm.sample_prior_predictive())
134-
idata.extend(pm.sample_posterior_predictive(idata))
135-
y_test = pm.sample_posterior_predictive(idata)
136-
137-
assert str(model_2.idata.groups) == str(idata.groups)
13895

96+
def test_save_load():
97+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
98+
model = test_ModelBuilder(model_config, sampler_config, data)
99+
temp = tempfile.TemporaryFile()
100+
model.fit()
101+
model.save(temp.name)
102+
model2 = test_ModelBuilder.load(temp.name)
103+
assert model.idata.groups() == model2.idata.groups()
139104

140-
def test_predict_posterior():
141105
x_pred = np.random.uniform(low=0, high=1, size=100)
142106
prediction_data = pd.DataFrame({"input": x_pred})
143-
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
144-
model_2 = test_ModelBuilder(model_config, sampler_config, data)
145-
model_2.idata = model_2.fit()
146-
model_2.predict_posterior(prediction_data)
147-
with pm.Model() as model:
148-
x = np.linspace(start=0, stop=1, num=100)
149-
y = 5 * x + 3
150-
x = pm.MutableData("x", x)
151-
y_data = pm.MutableData("y_data", y)
152-
a_loc = 7
153-
a_scale = 3
154-
b_loc = 5
155-
b_scale = 3
156-
obs_error = 2
157-
158-
a = pm.Normal("a", a_loc, sigma=a_scale)
159-
b = pm.Normal("b", b_loc, sigma=b_scale)
160-
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
161-
162-
y_model = pm.Normal("y_model", a + b * x, obs_error, observed=y_data)
163-
164-
idata = pm.sample(tune=10, draws=20, chains=3, cores=1)
165-
idata.extend(pm.sample_prior_predictive())
166-
idata.extend(pm.sample_posterior_predictive(idata))
167-
y_test = pm.sample_posterior_predictive(idata)
168-
169-
assert str(model_2.idata.groups) == str(idata.groups)
107+
pred1 = model.predict(prediction_data)
108+
pred2 = model2.predict(prediction_data)
109+
assert pred1["y_model"].shape == pred2["y_model"].shape
110+
temp.close()

0 commit comments

Comments
 (0)