Skip to content

Commit 7fe150f

Browse files
rename dataset_to_points_dict with DeprecationWarning
1 parent 73ca256 commit 7fe150f

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

pymc3/distributions/posterior_predictive.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from ..exceptions import IncorrectArgumentsError
4444
from ..vartypes import theano_constant
45-
from ..util import dataset_to_point_dict, chains_and_samples, get_var_name
45+
from ..util import dataset_to_point_list, chains_and_samples, get_var_name
4646

4747
# Failing tests:
4848
# test_mixture_random_shape::test_mixture_random_shape
@@ -209,10 +209,10 @@ def fast_sample_posterior_predictive(
209209

210210
if isinstance(trace, InferenceData):
211211
nchains, ndraws = chains_and_samples(trace)
212-
trace = dataset_to_point_dict(trace.posterior)
212+
trace = dataset_to_point_list(trace.posterior)
213213
elif isinstance(trace, Dataset):
214214
nchains, ndraws = chains_and_samples(trace)
215-
trace = dataset_to_point_dict(trace)
215+
trace = dataset_to_point_list(trace)
216216
elif isinstance(trace, MultiTrace):
217217
nchains = trace.nchains
218218
ndraws = len(trace)

pymc3/sampling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
get_untransformed_name,
5757
is_transformed_name,
5858
get_default_varnames,
59-
dataset_to_point_dict,
59+
dataset_to_point_list,
6060
chains_and_samples,
6161
)
6262
from .vartypes import discrete_types
@@ -1642,9 +1642,9 @@ def sample_posterior_predictive(
16421642

16431643
_trace: Union[MultiTrace, PointList]
16441644
if isinstance(trace, InferenceData):
1645-
_trace = dataset_to_point_dict(trace.posterior)
1645+
_trace = dataset_to_point_list(trace.posterior)
16461646
elif isinstance(trace, xarray.Dataset):
1647-
_trace = dataset_to_point_dict(trace)
1647+
_trace = dataset_to_point_list(trace)
16481648
else:
16491649
_trace = trace
16501650

@@ -1780,10 +1780,10 @@ def sample_posterior_predictive_w(
17801780
n_samples = [
17811781
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
17821782
]
1783-
traces = [dataset_to_point_dict(trace.posterior) for trace in traces]
1783+
traces = [dataset_to_point_list(trace.posterior) for trace in traces]
17841784
elif isinstance(traces[0], xarray.Dataset):
17851785
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
1786-
traces = [dataset_to_point_dict(trace) for trace in traces]
1786+
traces = [dataset_to_point_list(trace) for trace in traces]
17871787
else:
17881788
n_samples = [len(i) * i.nchains for i in traces]
17891789

pymc3/util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import functools
1717
from typing import List, Dict, Tuple, Union
18+
import warnings
1819

1920
import numpy as np
2021
import xarray
@@ -258,6 +259,14 @@ def enhanced(*args, **kwargs):
258259
# FIXME: this function is poorly named, because it returns a LIST of
259260
# points, not a dictionary of points.
260261
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
262+
warnings.warn(
263+
"dataset_to_point_dict was renamed to dataset_to_point_list and will be removed!.",
264+
DeprecationWarning
265+
)
266+
return dataset_to_point_list(ds)
267+
268+
269+
def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
261270
# grab posterior samples for each variable
262271
_samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()}
263272
# make dicts

0 commit comments

Comments
 (0)