File tree Expand file tree Collapse file tree 3 files changed +17
-8
lines changed Expand file tree Collapse file tree 3 files changed +17
-8
lines changed Original file line number Diff line number Diff line change 42
42
)
43
43
from ..exceptions import IncorrectArgumentsError
44
44
from ..vartypes import theano_constant
45
- from ..util import dataset_to_point_dict , chains_and_samples , get_var_name
45
+ from ..util import dataset_to_point_list , chains_and_samples , get_var_name
46
46
47
47
# Failing tests:
48
48
# test_mixture_random_shape::test_mixture_random_shape
@@ -209,10 +209,10 @@ def fast_sample_posterior_predictive(
209
209
210
210
if isinstance (trace , InferenceData ):
211
211
nchains , ndraws = chains_and_samples (trace )
212
- trace = dataset_to_point_dict (trace .posterior )
212
+ trace = dataset_to_point_list (trace .posterior )
213
213
elif isinstance (trace , Dataset ):
214
214
nchains , ndraws = chains_and_samples (trace )
215
- trace = dataset_to_point_dict (trace )
215
+ trace = dataset_to_point_list (trace )
216
216
elif isinstance (trace , MultiTrace ):
217
217
nchains = trace .nchains
218
218
ndraws = len (trace )
Original file line number Diff line number Diff line change 56
56
get_untransformed_name ,
57
57
is_transformed_name ,
58
58
get_default_varnames ,
59
- dataset_to_point_dict ,
59
+ dataset_to_point_list ,
60
60
chains_and_samples ,
61
61
)
62
62
from .vartypes import discrete_types
@@ -1642,9 +1642,9 @@ def sample_posterior_predictive(
1642
1642
1643
1643
_trace : Union [MultiTrace , PointList ]
1644
1644
if isinstance (trace , InferenceData ):
1645
- _trace = dataset_to_point_dict (trace .posterior )
1645
+ _trace = dataset_to_point_list (trace .posterior )
1646
1646
elif isinstance (trace , xarray .Dataset ):
1647
- _trace = dataset_to_point_dict (trace )
1647
+ _trace = dataset_to_point_list (trace )
1648
1648
else :
1649
1649
_trace = trace
1650
1650
@@ -1780,10 +1780,10 @@ def sample_posterior_predictive_w(
1780
1780
n_samples = [
1781
1781
trace .posterior .sizes ["chain" ] * trace .posterior .sizes ["draw" ] for trace in traces
1782
1782
]
1783
- traces = [dataset_to_point_dict (trace .posterior ) for trace in traces ]
1783
+ traces = [dataset_to_point_list (trace .posterior ) for trace in traces ]
1784
1784
elif isinstance (traces [0 ], xarray .Dataset ):
1785
1785
n_samples = [trace .sizes ["chain" ] * trace .sizes ["draw" ] for trace in traces ]
1786
- traces = [dataset_to_point_dict (trace ) for trace in traces ]
1786
+ traces = [dataset_to_point_list (trace ) for trace in traces ]
1787
1787
else :
1788
1788
n_samples = [len (i ) * i .nchains for i in traces ]
1789
1789
Original file line number Diff line number Diff line change 15
15
import re
16
16
import functools
17
17
from typing import List , Dict , Tuple , Union
18
+ import warnings
18
19
19
20
import numpy as np
20
21
import xarray
@@ -258,6 +259,14 @@ def enhanced(*args, **kwargs):
258
259
# FIXME: this function is poorly named, because it returns a LIST of
259
260
# points, not a dictionary of points.
260
261
def dataset_to_point_dict (ds : xarray .Dataset ) -> List [Dict [str , np .ndarray ]]:
262
+ warnings .warn (
263
+ "dataset_to_point_dict was renamed to dataset_to_point_list and will be removed!." ,
264
+ DeprecationWarning
265
+ )
266
+ return dataset_to_point_list (ds )
267
+
268
+
269
+ def dataset_to_point_list (ds : xarray .Dataset ) -> List [Dict [str , np .ndarray ]]:
261
270
# grab posterior samples for each variable
262
271
_samples : Dict [str , np .ndarray ] = {vn : ds [vn ].values for vn in ds .keys ()}
263
272
# make dicts
You can’t perform that action at this time.
0 commit comments