Skip to content

Commit 02faac7

Browse files
author
Max Joseph
committed
Add/test combined arg, revert method order
1 parent 3eb5b35 commit 02faac7

File tree

2 files changed

+35
-29
lines changed

2 files changed

+35
-29
lines changed

pymc_experimental/model_builder.py

Lines changed: 30 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -245,13 +245,13 @@ def fit(
245245
self.idata.add_groups(fit_data=self.data.to_xarray())
246246
return self.idata
247247

248-
def predict_posterior(
248+
def predict(
249249
self,
250250
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
251251
extend_idata: bool = True,
252252
) -> xr.Dataset:
253253
"""
254-
Generate posterior predictive samples on unseen data.
254+
Use the model to predict on unseen data and return point predictions (the posterior mean of the predictive samples).
255255
256256
Parameters
257257
---------
@@ -262,61 +262,66 @@ def predict_posterior(
262262
263263
Returns
264264
-------
265-
returns posterior predictive samples
265+
returns posterior mean of predictive samples
266266
267267
Examples
268268
--------
269269
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
270270
>>> model = LinearModel(model_config, sampler_config)
271271
>>> idata = model.fit(data)
272272
>>> x_pred = []
273-
>>> prediction_data = pd.DataFrame({'input': x_pred})
274-
>>> pred_samples = model.predict_posterior(prediction_data)
273+
>>> prediction_data = pd.DataFrame({'input':x_pred})
274+
>>> pred_mean = model.predict(prediction_data)
275275
"""
276+
posterior_predictive_samples = self.predict_posterior(data_prediction, extend_idata)
277+
posterior_means = posterior_predictive_samples.mean(dim=["chain", "draw"], keep_attrs=True)
278+
return posterior_means
276279

277-
if data_prediction is not None: # set new input data
278-
self._data_setter(data_prediction)
279-
280-
with self.model: # sample with new input data
281-
post_pred = pm.sample_posterior_predictive(self.idata)
282-
if extend_idata:
283-
self.idata.extend(post_pred)
284-
285-
posterior_predictive_samples = az.extract(post_pred, "posterior_predictive", combined=False)
286-
287-
return posterior_predictive_samples
288-
289-
def predict(
280+
def predict_posterior(
290281
self,
291282
data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
292283
extend_idata: bool = True,
284+
combined: bool = False,
293285
) -> xr.Dataset:
294286
"""
295-
Uses model to predict on unseen data and return point prediction of all the samples
287+
Generate posterior predictive samples on unseen data.
296288
297289
Parameters
298290
---------
299291
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
300292
It is the data we need to make prediction on using the model.
301293
extend_idata : Boolean determining whether the predictions should be added to inference data object.
302294
Defaults to True.
295+
combined : Boolean determining whether to combine the chain and draw dims into a single sample dim. Won't work if a dim named sample already exists.
296+
Defaults to False.
303297
304298
Returns
305299
-------
306-
returns posterior mean of predictive samples
300+
returns posterior predictive samples
307301
308302
Examples
309303
--------
310304
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
311305
>>> model = LinearModel(model_config, sampler_config)
312306
>>> idata = model.fit(data)
313307
>>> x_pred = []
314-
>>> prediction_data = pd.DataFrame({'input':x_pred})
315-
>>> pred_mean = model.predict(prediction_data)
308+
>>> prediction_data = pd.DataFrame({'input': x_pred})
309+
>>> pred_samples = model.predict_posterior(prediction_data)
316310
"""
317-
posterior_predictive_samples = self.predict_posterior(data_prediction, extend_idata)
318-
posterior_means = posterior_predictive_samples.mean(dim=["chain", "draw"], keep_attrs=True)
319-
return posterior_means
311+
312+
if data_prediction is not None: # set new input data
313+
self._data_setter(data_prediction)
314+
315+
with self.model: # sample with new input data
316+
post_pred = pm.sample_posterior_predictive(self.idata)
317+
if extend_idata:
318+
self.idata.extend(post_pred)
319+
320+
posterior_predictive_samples = az.extract(
321+
post_pred, "posterior_predictive", combined=combined
322+
)
323+
324+
return posterior_predictive_samples
320325

321326
@property
322327
def id(self) -> str:

pymc_experimental/tests/test_model_builder.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -151,16 +151,17 @@ def test_predict():
151151
assert np.issubdtype(pred["y_model"].dtype, np.floating)
152152

153153

154-
def test_predict_posterior():
154+
@pytest.mark.parametrize("combined", [True, False])
155+
def test_predict_posterior(combined):
155156
model = test_ModelBuilder.initial_build_and_fit()
156157
n_pred = 100
157158
x_pred = np.random.uniform(low=0, high=1, size=n_pred)
158159
prediction_data = pd.DataFrame({"input": x_pred})
159-
pred = model.predict_posterior(prediction_data)
160+
pred = model.predict_posterior(prediction_data, combined=combined)
160161
chains = model.idata.sample_stats.dims["chain"]
161162
draws = model.idata.sample_stats.dims["draw"]
162-
assert "y_model" in pred
163-
assert pred["y_model"].shape == (chains, draws, n_pred)
163+
expected_shape = (n_pred, chains * draws) if combined else (chains, draws, n_pred)
164+
assert pred["y_model"].shape == expected_shape
164165
assert np.issubdtype(pred["y_model"].dtype, np.floating)
165166

166167

0 commit comments

Comments
 (0)