Skip to content

Commit 5f7b185

Browse files
ModelBuilder.load versatility improvements (#210)
* fixing dims format, enabling input param preservation * adding additional test for implemented features * fixing typehinting in linearmodel * introducing dims to model_builder tests to check for dim format preservation
1 parent dd3c44d commit 5f7b185

File tree

4 files changed

+92
-42
lines changed

4 files changed

+92
-42
lines changed

pymc_experimental/linearmodel.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def default_sampler_config(self):
3737
"target_accept": 0.95,
3838
}
3939

40+
@property
41+
def _serializable_model_config(self) -> Dict:
42+
return self.model_config
43+
4044
@property
4145
def output_var(self):
4246
return "y_hat"

pymc_experimental/model_builder.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def sample_model(self, **kwargs):
297297
idata.extend(pm.sample_prior_predictive())
298298
idata.extend(pm.sample_posterior_predictive(idata))
299299

300-
self.set_idata_attrs(idata)
300+
idata = self.set_idata_attrs(idata)
301301
return idata
302302

303303
def set_idata_attrs(self, idata=None):
@@ -338,6 +338,10 @@ def set_idata_attrs(self, idata=None):
338338
idata.attrs["version"] = self.version
339339
idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
340340
idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
341+
# Only classes with non-dataset parameters will implement save_input_params
342+
if hasattr(self, "_save_input_params"):
343+
self._save_input_params(idata)
344+
return idata
341345

342346
def save(self, fname: str) -> None:
343347
"""
@@ -375,6 +379,17 @@ def save(self, fname: str) -> None:
375379
else:
376380
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
377381

382+
@classmethod
383+
def _convert_dims_to_tuple(cls, model_config: Dict) -> Dict:
384+
for key in model_config:
385+
if (
386+
isinstance(model_config[key], dict)
387+
and "dims" in model_config[key]
388+
and isinstance(model_config[key]["dims"], list)
389+
):
390+
model_config[key]["dims"] = tuple(model_config[key]["dims"])
391+
return model_config
392+
378393
@classmethod
379394
def load(cls, fname: str):
380395
"""
@@ -403,8 +418,10 @@ def load(cls, fname: str):
403418
"""
404419
filepath = Path(str(fname))
405420
idata = az.from_netcdf(filepath)
421+
# needs to be converted, because json.loads was changing tuple to list
422+
model_config = cls._convert_dims_to_tuple(json.loads(idata.attrs["model_config"]))
406423
model = cls(
407-
model_config=json.loads(idata.attrs["model_config"]),
424+
model_config=model_config,
408425
sampler_config=json.loads(idata.attrs["sampler_config"]),
409426
)
410427
model.idata = idata
@@ -480,6 +497,7 @@ def fit(
480497
combined_data = pd.concat([X_df, y], axis=1)
481498
assert all(combined_data.columns), "All columns must have non-empty names"
482499
self.idata.add_groups(fit_data=combined_data.to_xarray()) # type: ignore
500+
483501
return self.idata # type: ignore
484502

485503
def predict(

pymc_experimental/tests/test_linearmodel.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,24 @@ def fitted_linear_model_instance(toy_X, toy_y):
6464
return model
6565

6666

67+
@pytest.mark.skipif(
68+
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
69+
)
70+
def test_save_load(fitted_linear_model_instance):
71+
model = fitted_linear_model_instance
72+
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
73+
model.save(temp.name)
74+
model2 = LinearModel.load(temp.name)
75+
assert model.idata.groups() == model2.idata.groups()
76+
77+
X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
78+
pred1 = model.predict(X_pred, random_seed=423)
79+
pred2 = model2.predict(X_pred, random_seed=423)
80+
# Predictions should be identical
81+
np.testing.assert_array_equal(pred1, pred2)
82+
temp.close()
83+
84+
6785
def test_save_without_fit_raises_runtime_error(toy_X, toy_y):
6886
test_model = LinearModel()
6987
with pytest.raises(RuntimeError):
@@ -83,24 +101,6 @@ def test_fit(fitted_linear_model_instance):
83101
assert isinstance(post_pred, xr.DataArray)
84102

85103

86-
@pytest.mark.skipif(
87-
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
88-
)
89-
def test_save_load(fitted_linear_model_instance):
90-
model = fitted_linear_model_instance
91-
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
92-
model.save(temp.name)
93-
model2 = LinearModel.load(temp.name)
94-
assert model.idata.groups() == model2.idata.groups()
95-
96-
X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
97-
pred1 = model.predict(X_pred, random_seed=423)
98-
pred2 = model2.predict(X_pred, random_seed=423)
99-
# Predictions should be identical
100-
np.testing.assert_array_equal(pred1, pred2)
101-
temp.close()
102-
103-
104104
def test_predict(fitted_linear_model_instance):
105105
model = fitted_linear_model_instance
106106
X_pred = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})

pymc_experimental/tests/test_model_builder.py

Lines changed: 50 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import hashlib
16+
import json
1617
import sys
1718
import tempfile
1819
from typing import Dict
@@ -43,29 +44,35 @@ def toy_y(toy_X):
4344
@pytest.fixture(scope="module")
4445
def fitted_model_instance(toy_X, toy_y):
4546
sampler_config = {
46-
"draws": 500,
47-
"tune": 300,
47+
"draws": 100,
48+
"tune": 100,
4849
"chains": 2,
4950
"target_accept": 0.95,
5051
}
5152
model_config = {
52-
"a": {"loc": 0, "scale": 10},
53+
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
5354
"b": {"loc": 0, "scale": 10},
5455
"obs_error": 2,
5556
}
56-
model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
57+
model = test_ModelBuilder(
58+
model_config=model_config, sampler_config=sampler_config, test_parameter="test_paramter"
59+
)
5760
model.fit(toy_X)
5861
return model
5962

6063

6164
class test_ModelBuilder(ModelBuilder):
65+
def __init__(self, model_config=None, sampler_config=None, test_parameter=None):
66+
self.test_parameter = test_parameter
67+
super().__init__(model_config=model_config, sampler_config=sampler_config)
6268

63-
_model_type = "LinearModel"
69+
_model_type = "test_model"
6470
version = "0.1"
6571

6672
def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
73+
coords = {"numbers": np.arange(len(X))}
6774
self.generate_and_preprocess_model_data(X, y)
68-
with pm.Model() as self.model:
75+
with pm.Model(coords=coords) as self.model:
6976
if model_config is None:
7077
model_config = self.default_model_config
7178
x = pm.MutableData("x", self.X["input"].values)
@@ -79,13 +86,16 @@ def build_model(self, X: pd.DataFrame, y: pd.Series, model_config=None):
7986
obs_error = model_config["obs_error"]
8087

8188
# priors
82-
a = pm.Normal("a", a_loc, sigma=a_scale)
89+
a = pm.Normal("a", a_loc, sigma=a_scale, dims=model_config["a"]["dims"])
8390
b = pm.Normal("b", b_loc, sigma=b_scale)
8491
obs_error = pm.HalfNormal("σ_model_fmc", obs_error)
8592

8693
# observed data
8794
output = pm.Normal("output", a + b * x, obs_error, shape=x.shape, observed=y_data)
8895

96+
def _save_input_params(self, idata):
97+
idata.attrs["test_paramter"] = json.dumps(self.test_parameter)
98+
8999
@property
90100
def output_var(self):
91101
return "output"
@@ -107,7 +117,7 @@ def generate_and_preprocess_model_data(self, X: pd.DataFrame, y: pd.Series):
107117
@property
108118
def default_model_config(self) -> Dict:
109119
return {
110-
"a": {"loc": 0, "scale": 10},
120+
"a": {"loc": 0, "scale": 10, "dims": ("numbers",)},
111121
"b": {"loc": 0, "scale": 10},
112122
"obs_error": 2,
113123
}
@@ -122,6 +132,38 @@ def default_sampler_config(self) -> Dict:
122132
}
123133

124134

135+
def test_save_input_params(fitted_model_instance):
136+
assert fitted_model_instance.idata.attrs["test_paramter"] == '"test_paramter"'
137+
138+
139+
def test_save_load(fitted_model_instance):
140+
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
141+
fitted_model_instance.save(temp.name)
142+
test_builder2 = test_ModelBuilder.load(temp.name)
143+
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
144+
assert fitted_model_instance.id == test_builder2.id
145+
x_pred = np.random.uniform(low=0, high=1, size=100)
146+
prediction_data = pd.DataFrame({"input": x_pred})
147+
pred1 = fitted_model_instance.predict(prediction_data["input"])
148+
pred2 = test_builder2.predict(prediction_data["input"])
149+
assert pred1.shape == pred2.shape
150+
temp.close()
151+
152+
153+
def test_convert_dims_to_tuple(fitted_model_instance):
154+
model_config = {
155+
"a": {
156+
"loc": 0,
157+
"scale": 10,
158+
"dims": [
159+
"x",
160+
],
161+
},
162+
}
163+
converted_model_config = fitted_model_instance._convert_dims_to_tuple(model_config)
164+
assert converted_model_config["a"]["dims"] == ("x",)
165+
166+
125167
def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> ModelBuilder:
126168
if check_idata:
127169
assert fitted_model_instance.idata is not None
@@ -162,20 +204,6 @@ def test_fit_no_y(toy_X):
162204
@pytest.mark.skipif(
163205
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
164206
)
165-
def test_save_load(fitted_model_instance):
166-
temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
167-
fitted_model_instance.save(temp.name)
168-
test_builder2 = test_ModelBuilder.load(temp.name)
169-
assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()
170-
171-
x_pred = np.random.uniform(low=0, high=1, size=100)
172-
prediction_data = pd.DataFrame({"input": x_pred})
173-
pred1 = fitted_model_instance.predict(prediction_data["input"])
174-
pred2 = test_builder2.predict(prediction_data["input"])
175-
assert pred1.shape == pred2.shape
176-
temp.close()
177-
178-
179207
def test_predict(fitted_model_instance):
180208
x_pred = np.random.uniform(low=0, high=1, size=100)
181209
prediction_data = pd.DataFrame({"input": x_pred})

0 commit comments

Comments
 (0)