Skip to content

Commit b4e6349

Browse files
changed default sampler_config from None to {}
1 parent 9c481aa commit b4e6349

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

pymc_experimental/model_builder.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def __init__(
5959
"""
6060

6161
super().__init__()
62+
if sampler_config is None:
63+
sampler_config = {}
6264
self.model_config = model_config # parameters for priors etc.
6365
self.sampler_config = sampler_config # parameters for sampling
6466
self.data = data
@@ -193,7 +195,7 @@ def load(cls, fname: str):
193195
if "sampler_config" in idata.attrs:
194196
sampler_config = json.loads(idata.attrs["sampler_config"])
195197
else:
196-
sampler_config = None
198+
sampler_config = {}
197199
model_builder = cls(
198200
model_config=json.loads(idata.attrs["model_config"]),
199201
sampler_config=sampler_config,
@@ -239,7 +241,7 @@ def fit(
239241
self._data_setter(data)
240242

241243
with self.model:
242-
if self.sampler_config is not None:
244+
if self.sampler_config:
243245
self.idata = pm.sample(**self.sampler_config)
244246
else:
245247
self.idata = pm.sample()
@@ -249,7 +251,7 @@ def fit(
249251
self.idata.attrs["id"] = self.id
250252
self.idata.attrs["model_type"] = self._model_type
251253
self.idata.attrs["version"] = self.version
252-
if self.sampler_config is not None:
254+
if self.sampler_config:
253255
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
254256
self.idata.attrs["model_config"] = json.dumps(self.model_config)
255257
self.idata.add_groups(fit_data=self.data.to_xarray())

0 commit comments

Comments
 (0)