Skip to content

Commit e89724c

Browse files
new test for save, fit allows custom configs
1 parent 4c99877 commit e89724c

File tree

2 files changed

+37
-12
lines changed

2 files changed

+37
-12
lines changed

pymc_experimental/model_builder.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,9 +180,11 @@ def save(self, fname: str) -> None:
180180
>>> name = './mymodel.nc'
181181
>>> model.save(name)
182182
"""
183-
184-
file = Path(str(fname))
185-
self.idata.to_netcdf(file)
183+
if self.idata is not None and "fit_data" in self.idata:
184+
file = Path(str(fname))
185+
self.idata.to_netcdf(file)
186+
else:
187+
raise RuntimeError("The model hasn't been fit yet, call .fit() first")
186188

187189
@classmethod
188190
def load(cls, fname: str):
@@ -220,7 +222,7 @@ def load(cls, fname: str):
220222
data=idata.fit_data.to_dataframe(),
221223
)
222224
model_builder.idata = idata
223-
model_builder.idata = model_builder.fit()
225+
model_builder.build_model(model_builder.data, model_builder.model_config)
224226
if model_builder.id != idata.attrs["id"]:
225227
raise ValueError(
226228
f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
@@ -261,11 +263,12 @@ def fit(
261263
# If a new data was provided, assign it to the model
262264
if data is not None:
263265
self.data = data
264-
self.model_data, self.model_config, self.sampler_config = self.create_sample_input(
265-
data=self.data
266-
)
266+
self.model_data, model_config, sampler_config = self.create_sample_input(data=self.data)
267+
if self.model_config is None:
268+
self.model_config = model_config
269+
if self.sampler_config is None:
270+
self.sampler_config = sampler_config
267271
self.build_model(self.model_data, self.model_config)
268-
self._data_setter(self.model_data)
269272
with self.model:
270273
self.idata = pm.sample(**self.sampler_config, **kwargs)
271274
self.idata.extend(pm.sample_prior_predictive())
@@ -275,7 +278,7 @@ def fit(
275278
self.idata.attrs["model_type"] = self._model_type
276279
self.idata.attrs["version"] = self.version
277280
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
278-
self.idata.attrs["model_config"] = json.dumps(self.serializable_model_config)
281+
self.idata.attrs["model_config"] = json.dumps(self._serializable_model_config)
279282
self.idata.add_groups(fit_data=self.data.to_xarray())
280283
return self.idata
281284

@@ -386,6 +389,19 @@ def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[st
386389

387390
return post_pred_dict
388391

392+
@property
393+
@abstractmethod
394+
def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
395+
"""
396+
Converts non-serializable values from model_config to their serializable reversable equivalent.
397+
Data types like pandas DataFrame, Series or datetime aren't JSON serializable,
398+
so in order to save the model they need to be formatted.
399+
400+
Returns
401+
-------
402+
model_config: dict
403+
"""
404+
389405
@property
390406
def id(self) -> str:
391407
"""

pymc_experimental/tests/test_model_builder.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def _data_setter(self, data: pd.DataFrame):
5757
pm.set_data({"y_data": data["output"].values})
5858

5959
@property
60-
def serializable_model_config(self):
60+
def _serializable_model_config(self):
6161
return self.model_config
6262

6363
@classmethod
@@ -95,7 +95,16 @@ def initial_build_and_fit(check_idata=True) -> ModelBuilder:
9595
return model_builder
9696

9797

98-
def test_empty_model_config():
98+
def test_save_without_fit_raises_runtime_error():
99+
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
100+
model_builder = test_ModelBuilder(
101+
model_config=model_config, sampler_config=sampler_config, data=data
102+
)
103+
with pytest.raises(RuntimeError):
104+
model_builder.save("saved_model")
105+
106+
107+
def test_empty_sampler_config_fit():
99108
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
100109
sampler_config = {}
101110
model_builder = test_ModelBuilder(
@@ -106,7 +115,7 @@ def test_empty_model_config():
106115
assert "posterior" in model_builder.idata.groups()
107116

108117

109-
def test_empty_model_config():
118+
def test_empty_model_config_fit():
110119
data, model_config, sampler_config = test_ModelBuilder.create_sample_input()
111120
model_config = {}
112121
model_builder = test_ModelBuilder(

0 commit comments

Comments
 (0)