23
23
import numpy as np
24
24
import pandas as pd
25
25
import pymc as pm
26
+ import xarray as xr
26
27
from pymc .util import RandomState
27
28
28
29
@@ -286,7 +287,7 @@ def predict(
286
287
self ,
287
288
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
288
289
extend_idata : bool = True ,
289
- ) -> dict :
290
+ ) -> xr . Dataset :
290
291
"""
291
292
Uses model to predict on unseen data and return point prediction of all the samples
292
293
@@ -299,7 +300,7 @@ def predict(
299
300
300
301
Returns
301
302
-------
302
- returns dictionary of sample's mean of posterior predict.
303
+ returns posterior mean of predictive samples
303
304
304
305
Examples
305
306
--------
@@ -308,42 +309,33 @@ def predict(
308
309
>>> idata = model.fit(data)
309
310
>>> x_pred = []
310
311
>>> prediction_data = pd.DataFrame({'input':x_pred})
311
- # point predict
312
312
>>> pred_mean = model.predict(prediction_data)
313
313
"""
314
-
315
- if data_prediction is not None : # set new input data
316
- self ._data_setter (data_prediction )
317
-
318
- with self .model : # sample with new input data
319
- post_pred = pm .sample_posterior_predictive (self .idata )
320
- if extend_idata :
321
- self .idata .extend (post_pred )
322
- # reshape output
323
- post_pred = self ._extract_samples (post_pred )
324
- for key in post_pred :
325
- post_pred [key ] = post_pred [key ].mean (axis = 0 )
326
-
327
- return post_pred
314
+ posterior_predictive_samples = self .predict_posterior (data_prediction , extend_idata )
315
+ posterior_means = posterior_predictive_samples .mean (dim = ["chain" , "draw" ], keep_attrs = True )
316
+ return posterior_means
328
317
329
318
def predict_posterior (
330
319
self ,
331
320
data_prediction : Dict [str , Union [np .ndarray , pd .DataFrame , pd .Series ]] = None ,
332
321
extend_idata : bool = True ,
333
- ) -> Dict [str , np .array ]:
322
+ combined : bool = False ,
323
+ ) -> xr .Dataset :
334
324
"""
335
- Uses model to predict samples on unseen data.
325
+ Generate posterior predictive samples on unseen data.
336
326
337
327
Parameters
338
328
---------
339
329
data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series
340
330
It is the data we need to make prediction on using the model.
341
331
extend_idata : Boolean determining whether the predictions should be added to inference data object.
342
332
Defaults to True.
333
+ combined: Combine chain and draw dims into sample. Won’t work if a dim named sample already exists.
334
+ Defaults to False.
343
335
344
336
Returns
345
337
-------
346
- returns dictionary of sample's posterior predict.
338
+ returns posterior predictive samples
347
339
348
340
Examples
349
341
--------
@@ -352,8 +344,7 @@ def predict_posterior(
352
344
>>> idata = model.fit(data)
353
345
>>> x_pred = []
354
346
>>> prediction_data = pd.DataFrame({'input': x_pred})
355
- # samples
356
- >>> pred_mean = model.predict_posterior(prediction_data)
347
+ >>> pred_samples = model.predict_posterior(prediction_data)
357
348
"""
358
349
359
350
if data_prediction is not None : # set new input data
@@ -364,30 +355,11 @@ def predict_posterior(
364
355
if extend_idata :
365
356
self .idata .extend (post_pred )
366
357
367
- # reshape output
368
- post_pred = self ._extract_samples (post_pred )
369
-
370
- return post_pred
371
-
372
- @staticmethod
373
- def _extract_samples (post_pred : az .data .inference_data .InferenceData ) -> Dict [str , np .array ]:
374
- """
375
- This method can be used to extract samples from posterior predict.
376
-
377
- Parameters
378
- ----------
379
- post_pred: arviz InferenceData object
380
-
381
- Returns
382
- -------
383
- Dictionary of numpy arrays from InferenceData object
384
- """
385
-
386
- post_pred_dict = dict ()
387
- for key in post_pred .posterior_predictive :
388
- post_pred_dict [key ] = post_pred .posterior_predictive [key ].to_numpy ()[0 ]
358
+ posterior_predictive_samples = az .extract (
359
+ post_pred , "posterior_predictive" , combined = combined
360
+ )
389
361
390
- return post_pred_dict
362
+ return posterior_predictive_samples
391
363
392
364
@property
393
365
@abstractmethod
0 commit comments