@@ -59,6 +59,8 @@ def __init__(
59
59
"""
60
60
61
61
super ().__init__ ()
62
+ if sampler_config is None :
63
+ sampler_config = {}
62
64
self .model_config = model_config # parameters for priors etc.
63
65
self .sampler_config = sampler_config # parameters for sampling
64
66
self .data = data
@@ -193,7 +195,7 @@ def load(cls, fname: str):
193
195
if "sampler_config" in idata .attrs :
194
196
sampler_config = json .loads (idata .attrs ["sampler_config" ])
195
197
else :
196
- sampler_config = None
198
+ sampler_config = {}
197
199
model_builder = cls (
198
200
model_config = json .loads (idata .attrs ["model_config" ]),
199
201
sampler_config = sampler_config ,
@@ -239,7 +241,7 @@ def fit(
239
241
self ._data_setter (data )
240
242
241
243
with self .model :
242
- if self .sampler_config is not None :
244
+ if self .sampler_config :
243
245
self .idata = pm .sample (** self .sampler_config )
244
246
else :
245
247
self .idata = pm .sample ()
@@ -249,7 +251,7 @@ def fit(
249
251
self .idata .attrs ["id" ] = self .id
250
252
self .idata .attrs ["model_type" ] = self ._model_type
251
253
self .idata .attrs ["version" ] = self .version
252
- if self .sampler_config is not None :
254
+ if self .sampler_config :
253
255
self .idata .attrs ["sampler_config" ] = json .dumps (self .sampler_config )
254
256
self .idata .attrs ["model_config" ] = json .dumps (self .model_config )
255
257
self .idata .add_groups (fit_data = self .data .to_xarray ())
0 commit comments