Skip to content

Commit 8efc137

Browse files
adding test for model_config_formatter helper function (#222)
* adding test for model_config_formatter helper function * fixing typo
1 parent cb6b864 commit 8efc137

File tree

1 file changed

+16
-14
lines changed

1 file changed

+16
-14
lines changed

pymc_experimental/tests/test_model_builder.py

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -150,20 +150,6 @@ def test_save_load(fitted_model_instance):
150150
temp.close()
151151

152152

153-
def test_convert_dims_to_tuple(fitted_model_instance):
154-
model_config = {
155-
"a": {
156-
"loc": 0,
157-
"scale": 10,
158-
"dims": [
159-
"x",
160-
],
161-
},
162-
}
163-
converted_model_config = fitted_model_instance._model_config_formatting(model_config)
164-
assert converted_model_config["a"]["dims"] == ("x",)
165-
166-
167153
def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
168154
if check_idata:
169155
assert fitted_model_instance.idata is not None
@@ -228,6 +214,22 @@ def test_sample_posterior_predictive(fitted_model_instance, combined):
228214
assert np.issubdtype(pred[fitted_model_instance.output_var].dtype, np.floating)
229215

230216

217+
def test_model_config_formatting():
218+
model_config = {
219+
"a": {
220+
"loc": [0, 0],
221+
"scale": 10,
222+
"dims": [
223+
"x",
224+
],
225+
},
226+
}
227+
model_builder = test_ModelBuilder()
228+
converted_model_config = model_builder._model_config_formatting(model_config)
229+
np.testing.assert_equal(converted_model_config["a"]["dims"], ("x",))
230+
np.testing.assert_equal(converted_model_config["a"]["loc"], np.array([0, 0]))
231+
232+
231233
def test_id():
232234
model_builder = test_ModelBuilder()
233235
expected_id = hashlib.sha256(

0 commit comments

Comments
 (0)