Skip to content

Commit d6183b0

Browse files
committed
fixed laod functio
1 parent 51d1ff0 commit d6183b0

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

pymc_experimental/model_builder.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,12 @@ def load(cls, fname):
186186

187187
filepath = Path(str(fname))
188188
idata = az.from_netcdf(filepath)
189-
self = cls(dict(zip(idata.attrs['model_config_keys'],idata.attrs['model_config_values'])), dict(zip(idata.attrs['sample_config_keys'],idata.attrs['sample_config_values'])), idata.data)
190-
self.idata=idata
189+
self = cls(
190+
dict(zip(idata.attrs["model_config_keys"], idata.attrs["model_config_values"])),
191+
dict(zip(idata.attrs["sample_config_keys"], idata.attrs["sample_config_values"])),
192+
idata.data,
193+
)
194+
self.idata = idata
191195
if self.id() != idata.attrs["id"]:
192196
raise ValueError(
193197
f"The route '{fname}' does not contain an inference data of the same model '{self._model_type}'"
@@ -236,7 +240,7 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None
236240
self.idata.attrs["sample_config_values"] = tuple(self.sample_config.values())
237241
self.idata.attrs["model_config_keys"] = tuple(self.model_config.keys())
238242
self.idata.attrs["model_config_values"] = tuple(self.model_config.values())
239-
self.idata.add_groups(data = self.data.to_xarray())
243+
self.idata.add_groups(data=self.data.to_xarray())
240244
return self.idata
241245

242246
def predict(

0 commit comments

Comments
 (0)