Skip to content

Commit 84aa791

Browse files
michaelraczyckiMichal Raczycki
and
Michal Raczycki
authored
Model builder tests (pymc-devs#114)
* changed id of model_builder to property, added type hinting * added id property test * autolinting * added test for model_builder predict function * extended test_model_builder.py with test_extract_samples * last test for model_builder * addressing requested changes, created support function from code repetitions --------- Co-authored-by: Michal Raczycki <[email protected]>
1 parent 6f67dec commit 84aa791

File tree

1 file changed

+66
-9
lines changed

1 file changed

+66
-9
lines changed

pymc_experimental/tests/test_model_builder.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515

16+
import hashlib
1617
import sys
1718
import tempfile
1819

@@ -79,14 +80,19 @@ def create_sample_input(cls):
7980

8081
return data, model_config, sampler_config
8182

83+
@staticmethod
84+
def initial_build_and_fit(check_idata=True):
85+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
86+
model = test_ModelBuilder(model_config, sampler_config, data)
87+
model.fit()
88+
if check_idata:
89+
assert model.idata is not None
90+
assert "posterior" in model.idata.groups()
91+
return model
8292

83-
def test_fit():
84-
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
85-
model = test_ModelBuilder(model_config, sampler_config, data)
86-
model.fit()
87-
assert model.idata is not None
88-
assert "posterior" in model.idata.groups()
8993

94+
def test_fit():
95+
model = test_ModelBuilder.initial_build_and_fit()
9096
x_pred = np.random.uniform(low=0, high=1, size=100)
9197
prediction_data = pd.DataFrame({"input": x_pred})
9298
pred = model.predict(prediction_data)
@@ -99,9 +105,7 @@ def test_fit():
99105
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
100106
)
101107
def test_save_load():
102-
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
103-
model = test_ModelBuilder(model_config, sampler_config, data)
104-
model.fit()
108+
model = test_ModelBuilder.initial_build_and_fit(False)
105109
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
106110
model.save(temp.name)
107111
model2 = test_ModelBuilder.load(temp.name)
@@ -113,3 +117,56 @@ def test_save_load():
113117
pred2 = model2.predict(prediction_data)
114118
assert pred1["y_model"].shape == pred2["y_model"].shape
115119
temp.close()
120+
121+
122+
def test_predict():
123+
model = test_ModelBuilder.initial_build_and_fit()
124+
x_pred = np.random.uniform(low=0, high=1, size=100)
125+
prediction_data = pd.DataFrame({"input": x_pred})
126+
pred = model.predict(prediction_data)
127+
assert "y_model" in pred
128+
assert isinstance(pred, dict)
129+
assert len(prediction_data.input.values) == len(pred["y_model"])
130+
assert isinstance(pred["y_model"][0], float)
131+
132+
133+
def test_predict_posterior():
134+
model = test_ModelBuilder.initial_build_and_fit()
135+
x_pred = np.random.uniform(low=0, high=1, size=100)
136+
prediction_data = pd.DataFrame({"input": x_pred})
137+
pred = model.predict_posterior(prediction_data)
138+
assert "y_model" in pred
139+
assert isinstance(pred, dict)
140+
assert len(prediction_data.input.values) == len(pred["y_model"][0])
141+
assert isinstance(pred["y_model"][0], np.ndarray)
142+
143+
144+
def test_extract_samples():
145+
# create a fake InferenceData object
146+
with pm.Model() as model:
147+
x = pm.Normal("x", mu=0, sigma=1)
148+
intercept = pm.Normal("intercept", mu=0, sigma=1)
149+
y_model = pm.Normal("y_model", mu=x * intercept, sigma=1, observed=[0, 1, 2])
150+
151+
idata = pm.sample(1000, tune=1000)
152+
post_pred = pm.sample_posterior_predictive(idata)
153+
154+
# call the function and get the output
155+
samples_dict = test_ModelBuilder._extract_samples(post_pred)
156+
157+
# assert that the keys and values are correct
158+
assert len(samples_dict) == len(post_pred.posterior_predictive)
159+
for key in post_pred.posterior_predictive:
160+
expected_value = post_pred.posterior_predictive[key].to_numpy()[0]
161+
assert np.array_equal(samples_dict[key], expected_value)
162+
163+
164+
def test_id():
165+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
166+
model = test_ModelBuilder(model_config, sampler_config, data)
167+
168+
expected_id = hashlib.sha256(
169+
str(model_config.values()).encode() + model.version.encode() + model._model_type.encode()
170+
).hexdigest()[:16]
171+
172+
assert model.id == expected_id

0 commit comments

Comments
 (0)