@@ -245,13 +245,13 @@ def fit(
245
245
self .idata .add_groups (fit_data = self .data .to_xarray ())
246
246
return self .idata
247
247
248
- def predict_posterior (
248
+ def predict (
249
249
self ,
250
250
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
251
251
extend_idata : bool = True ,
252
252
) -> xr .Dataset :
253
253
"""
254
- Generate posterior predictive samples on unseen data.
254
+ Uses model to predict on unseen data and return point prediction of all the samples
255
255
256
256
Parameters
257
257
---------
@@ -262,61 +262,66 @@ def predict_posterior(
262
262
263
263
Returns
264
264
-------
265
- returns posterior predictive samples
265
+ returns posterior mean of predictive samples
266
266
267
267
Examples
268
268
--------
269
269
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
270
270
>>> model = LinearModel(model_config, sampler_config)
271
271
>>> idata = model.fit(data)
272
272
>>> x_pred = []
273
- >>> prediction_data = pd.DataFrame({'input': x_pred})
274
- >>> pred_samples = model.predict_posterior (prediction_data)
273
+ >>> prediction_data = pd.DataFrame({'input':x_pred})
274
+ >>> pred_mean = model.predict (prediction_data)
275
275
"""
276
+ posterior_predictive_samples = self .predict_posterior (data_prediction , extend_idata )
277
+ posterior_means = posterior_predictive_samples .mean (dim = ["chain" , "draw" ], keep_attrs = True )
278
+ return posterior_means
276
279
277
- if data_prediction is not None : # set new input data
278
- self ._data_setter (data_prediction )
279
-
280
- with self .model : # sample with new input data
281
- post_pred = pm .sample_posterior_predictive (self .idata )
282
- if extend_idata :
283
- self .idata .extend (post_pred )
284
-
285
- posterior_predictive_samples = az .extract (post_pred , "posterior_predictive" , combined = False )
286
-
287
- return posterior_predictive_samples
288
-
289
- def predict (
280
+ def predict_posterior (
290
281
self ,
291
282
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
292
283
extend_idata : bool = True ,
284
+ combined : bool = False ,
293
285
) -> xr .Dataset :
294
286
"""
295
- Uses model to predict on unseen data and return point prediction of all the samples
287
+ Generate posterior predictive samples on unseen data.
296
288
297
289
Parameters
298
290
---------
299
291
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
300
292
It is the data we need to make prediction on using the model.
301
293
extend_idata : Boolean determining whether the predictions should be added to inference data object.
302
294
Defaults to True.
295
+ combined: Combine chain and draw dims into sample. Won’t work if a dim named sample already exists.
296
+ Defaults to False.
303
297
304
298
Returns
305
299
-------
306
- returns posterior mean of predictive samples
300
+ returns posterior predictive samples
307
301
308
302
Examples
309
303
--------
310
304
>>> data, model_config, sampler_config = LinearModel.create_sample_input()
311
305
>>> model = LinearModel(model_config, sampler_config)
312
306
>>> idata = model.fit(data)
313
307
>>> x_pred = []
314
- >>> prediction_data = pd.DataFrame({'input':x_pred})
315
- >>> pred_mean = model.predict (prediction_data)
308
+ >>> prediction_data = pd.DataFrame({'input': x_pred})
309
+ >>> pred_samples = model.predict_posterior (prediction_data)
316
310
"""
317
- posterior_predictive_samples = self .predict_posterior (data_prediction , extend_idata )
318
- posterior_means = posterior_predictive_samples .mean (dim = ["chain" , "draw" ], keep_attrs = True )
319
- return posterior_means
311
+
312
+ if data_prediction is not None : # set new input data
313
+ self ._data_setter (data_prediction )
314
+
315
+ with self .model : # sample with new input data
316
+ post_pred = pm .sample_posterior_predictive (self .idata )
317
+ if extend_idata :
318
+ self .idata .extend (post_pred )
319
+
320
+ posterior_predictive_samples = az .extract (
321
+ post_pred , "posterior_predictive" , combined = combined
322
+ )
323
+
324
+ return posterior_predictive_samples
320
325
321
326
@property
322
327
def id (self ) -> str :
0 commit comments