Skip to content

Commit d811561

Browse files
michaelraczyckitwiecki
authored andcommitted
storing model_config and sampler_config changed to json string
1 parent 84aa791 commit d811561

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

pymc_experimental/model_builder.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import hashlib
17+
import json
1718
from pathlib import Path
1819
from typing import Dict, Union
1920

@@ -187,8 +188,8 @@ def load(cls, fname):
187188
filepath = Path(str(fname))
188189
idata = az.from_netcdf(filepath)
189190
self = cls(
190-
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
191-
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
191+
json.loads(idata.attrs["model_config"]),
192+
json.loads(idata.attrs["sampler_config"]),
192193
idata.fit_data.to_dataframe(),
193194
)
194195
self.idata = idata
@@ -237,10 +238,8 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
237238
self.idata.attrs["id"] = self.id
238239
self.idata.attrs["model_type"] = self._model_type
239240
self.idata.attrs["version"] = self.version
240-
self.idata.attrs["sample_config_keys"] = tuple(self.sample_config.keys())
241-
self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values())
242-
self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys())
243-
self.idata.attrs["model_config_values"] = tuple(self.model_config.values())
241+
self.idata.attrs["sampler_config"] = json.dumps(self.sample_config)
242+
self.idata.attrs["model_config"] = json.dumps(self.model_config)
244243
self.idata.add_groups(fit_data=self.data.to_xarray())
245244
return self.idata
246245

0 commit comments

Comments
 (0)