Skip to content

Commit 4c99877

Browse files
small tweaks to make mmm tests work smoother
1 parent 6f4bb25 commit 4c99877

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

pymc_experimental/model_builder.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,8 @@ def fit(
266266
)
267267
self.build_model(self.model_data, self.model_config)
268268
self._data_setter(self.model_data)
269-
270269
with self.model:
271-
self.idata = pm.sample(**self.sampler_config)
270+
self.idata = pm.sample(**self.sampler_config, **kwargs)
272271
self.idata.extend(pm.sample_prior_predictive())
273272
self.idata.extend(pm.sample_posterior_predictive(self.idata))
274273

@@ -277,7 +276,7 @@ def fit(
277276
self.idata.attrs["version"] = self.version
278277
self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config)
279278
self.idata.attrs["model_config"] = json.dumps(self.serializable_model_config)
280-
self.idata.add_groups(fit_data=self.model_data.to_xarray())
279+
self.idata.add_groups(fit_data=self.data.to_xarray())
281280
return self.idata
282281

283282
def predict(

0 commit comments

Comments
 (0)