23
23
import numpy as np
24
24
import pandas as pd
25
25
import pymc as pm
26
+ import xarray as xr
26
27
from pymc .util import RandomState
27
28
28
29
@@ -306,7 +307,7 @@ def predict(
306
307
self ,
307
308
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
308
309
extend_idata : bool = True ,
309
- ) -> dict :
310
+ ) -> xr . Dataset :
310
311
"""
311
312
Uses model to predict on unseen data and return point prediction of all the samples
312
313
@@ -319,7 +320,7 @@ def predict(
319
320
320
321
Returns
321
322
-------
322
- returns dictionary of sample's mean of posterior predict.
323
+ returns posterior mean of predictive samples
323
324
324
325
Examples
325
326
--------
@@ -328,42 +329,33 @@ def predict(
328
329
>>> idata = model.fit(data)
329
330
>>> x_pred = []
330
331
>>> prediction_data = pd.DataFrame({'input':x_pred})
331
- # point predict
332
332
>>> pred_mean = model.predict(prediction_data)
333
333
"""
334
-
335
- if data_prediction is not None : # set new input data
336
- self ._data_setter (data_prediction )
337
-
338
- with self .model : # sample with new input data
339
- post_pred = pm .sample_posterior_predictive (self .idata )
340
- if extend_idata :
341
- self .idata .extend (post_pred )
342
- # reshape output
343
- post_pred = self ._extract_samples (post_pred )
344
- for key in post_pred :
345
- post_pred [key ] = post_pred [key ].mean (axis = 0 )
346
-
347
- return post_pred
334
+ posterior_predictive_samples = self .predict_posterior (data_prediction , extend_idata )
335
+ posterior_means = posterior_predictive_samples .mean (dim = ["chain" , "draw" ], keep_attrs = True )
336
+ return posterior_means
348
337
349
338
def predict_posterior (
350
339
self ,
351
340
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
352
341
extend_idata : bool = True ,
353
- ) -> Dict [str , np .array ]:
342
+ combined : bool = False ,
343
+ ) -> xr .Dataset :
354
344
"""
355
- Uses model to predict samples on unseen data.
345
+ Generate posterior predictive samples on unseen data.
356
346
357
347
Parameters
358
348
---------
359
349
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
360
350
It is the data we need to make prediction on using the model.
361
351
extend_idata : Boolean determining whether the predictions should be added to inference data object.
362
352
Defaults to True.
353
+ combined: Combine chain and draw dims into sample. Won’t work if a dim named sample already exists.
354
+ Defaults to False.
363
355
364
356
Returns
365
357
-------
366
- returns dictionary of sample's posterior predict.
358
+ returns posterior predictive samples
367
359
368
360
Examples
369
361
--------
@@ -372,8 +364,7 @@ def predict_posterior(
372
364
>>> idata = model.fit(data)
373
365
>>> x_pred = []
374
366
>>> prediction_data = pd.DataFrame({'input': x_pred})
375
- # samples
376
- >>> pred_mean = model.predict_posterior(prediction_data)
367
+ >>> pred_samples = model.predict_posterior(prediction_data)
377
368
"""
378
369
379
370
if data_prediction is not None : # set new input data
@@ -384,30 +375,11 @@ def predict_posterior(
384
375
if extend_idata :
385
376
self .idata .extend (post_pred )
386
377
387
- # reshape output
388
- post_pred = self ._extract_samples (post_pred )
389
-
390
- return post_pred
391
-
392
- @staticmethod
393
- def _extract_samples (post_pred : az .data .inference_data .InferenceData ) -> Dict [str , np .array ]:
394
- """
395
- This method can be used to extract samples from posterior predict.
396
-
397
- Parameters
398
- ----------
399
- post_pred: arviz InferenceData object
400
-
401
- Returns
402
- -------
403
- Dictionary of numpy arrays from InferenceData object
404
- """
405
-
406
- post_pred_dict = dict ()
407
- for key in post_pred .posterior_predictive :
408
- post_pred_dict [key ] = post_pred .posterior_predictive [key ].to_numpy ()[0 ]
378
+ posterior_predictive_samples = az .extract (
379
+ post_pred , "posterior_predictive" , combined = combined
380
+ )
409
381
410
- return post_pred_dict
382
+ return posterior_predictive_samples
411
383
412
384
@property
413
385
@abstractmethod
0 commit comments