Skip to content

Commit 51cb782

Browse files
merging changes from main
2 parents d3a74ec + d811561 commit 51cb782

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

pymc_experimental/model_builder.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515

1616
import hashlib
17+
import json
1718
from abc import abstractmethod
1819
from pathlib import Path
1920
from typing import Dict, Union
@@ -188,8 +189,8 @@ def load(cls, fname: str):
188189
filepath = Path(str(fname))
189190
idata = az.from_netcdf(filepath)
190191
model_builder = cls(
191-
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
192-
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
192+
json.loads(idata.attrs["model_config"]),
193+
json.loads(idata.attrs["sampler_config"]),
193194
idata.fit_data.to_dataframe(),
194195
)
195196
model_builder.idata = idata
@@ -239,10 +240,8 @@ def fit(
239240
self.idata.attrs["id"] = self.id
240241
self.idata.attrs["model_type"] = self._model_type
241242
self.idata.attrs["version"] = self.version
242-
self.idata.attrs["sample_config_keys"] = tuple(self.sample_config.keys())
243-
self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values())
244-
self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys())
245-
self.idata.attrs["model_config_values"] = tuple(self.model_config.values())
243+
self.idata.attrs["sampler_config"] = json.dumps(self.sample_config)
244+
self.idata.attrs["model_config"] = json.dumps(self.model_config)
246245
self.idata.add_groups(fit_data=self.data.to_xarray())
247246
return self.idata
248247

0 commit comments

Comments
 (0)