diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 8c5aa7577..6d2513349 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -14,6 +14,7 @@ import hashlib +import json from pathlib import Path from typing import Dict, Union @@ -187,8 +188,8 @@ def load(cls, fname): filepath = Path(str(fname)) idata = az.from_netcdf(filepath) self = cls( - dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])), - dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])), + json.loads(idata.attrs["model_config"]), + json.loads(idata.attrs["sampler_config"]), idata.fit_data.to_dataframe(), ) self.idata = idata @@ -237,10 +238,8 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None self.idata.attrs["id"] = self.id self.idata.attrs["model_type"] = self._model_type self.idata.attrs["version"] = self.version - self.idata.attrs["sample_config_keys"] = tuple(self.sample_config.keys()) - self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values()) - self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys()) - self.idata.attrs["model_config_values"] = tuple(self.model_config.values()) + self.idata.attrs["sampler_config"] = json.dumps(self.sample_config) + self.idata.attrs["model_config"] = json.dumps(self.model_config) self.idata.add_groups(fit_data=self.data.to_xarray()) return self.idata