Skip to content

Commit a7d52c7

Browse files
committed
start porting tests in ArviZ
1 parent ed76e65 commit a7d52c7

File tree

5 files changed

+706
-61
lines changed

5 files changed

+706
-61
lines changed

pymc3/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,12 @@ def __set_compiler_flags():
4040

4141
from pymc3 import gp, ode, sampling
4242
from pymc3.aesaraf import *
43-
from pymc3.backends import load_trace, save_trace
43+
from pymc3.backends import (
44+
load_trace,
45+
predictions_to_inference_data,
46+
save_trace,
47+
to_inference_data,
48+
)
4449
from pymc3.backends.tracetab import *
4550
from pymc3.blocking import *
4651
from pymc3.data import *

pymc3/backends/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
Saved backends can be loaded using `arviz.from_netcdf`
6161
6262
"""
63+
from pymc3.backends.arviz import predictions_to_inference_data, to_inference_data
6364
from pymc3.backends.ndarray import (
6465
NDArray,
6566
load_trace,

pymc3/backends/arviz.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,14 @@
1717
import numpy as np
1818
import xarray as xr
1919

20-
from aesara.gof.graph import ancestors
20+
from aesara.graph.basic import ancestors
2121
from aesara.tensor.var import TensorVariable
2222
from arviz import InferenceData, concat, rcParams
2323
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
2424

2525
import pymc3
2626

2727
from pymc3.model import modelcontext
28-
from pymc3.sampling import _DefaultTrace
2928
from pymc3.util import get_default_varnames
3029

3130
if TYPE_CHECKING:
@@ -42,6 +41,61 @@
4241
Var = Any # pylint: disable=invalid-name
4342

4443

44+
class _DefaultTrace:
45+
"""
46+
Utility for collecting samples into a dictionary.
47+
48+
Name comes from its similarity to ``defaultdict``:
49+
entries are lazily created.
50+
51+
Parameters
52+
----------
53+
samples : int
54+
The number of samples that will be collected, per variable,
55+
into the trace.
56+
57+
Attributes
58+
----------
59+
trace_dict : Dict[str, np.ndarray]
60+
A dictionary constituting a trace. Should be extracted
61+
after a procedure has filled the `_DefaultTrace` using the
62+
`insert()` method
63+
"""
64+
65+
trace_dict: Dict[str, np.ndarray] = {}
66+
_len: Optional[int] = None
67+
68+
def __init__(self, samples: int):
69+
self._len = samples
70+
self.trace_dict = {}
71+
72+
def insert(self, k: str, v, idx: int):
73+
"""
74+
Insert `v` as the value of the `idx`th sample for the variable `k`.
75+
76+
Parameters
77+
----------
78+
k: str
79+
Name of the variable.
80+
v: anything that can go into a numpy array (including a numpy array)
81+
The value of the `idx`th sample from variable `k`
82+
ids: int
83+
The index of the sample we are inserting into the trace.
84+
"""
85+
value_shape = np.shape(v)
86+
87+
# initialize if necessary
88+
if k not in self.trace_dict:
89+
array_shape = (self._len,) + value_shape
90+
self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)
91+
92+
# do the actual insertion
93+
if value_shape == ():
94+
self.trace_dict[k][idx] = v
95+
else:
96+
self.trace_dict[k][idx, :] = v
97+
98+
4599
class InferenceDataConverter: # pylint: disable=too-many-instance-attributes
46100
"""Encapsulate InferenceData specific logic."""
47101

pymc3/sampling.py

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727

2828
import aesara
2929
import aesara.gradient as tg
30-
import arviz
3130
import numpy as np
3231
import packaging
3332
import xarray
@@ -38,6 +37,7 @@
3837
import pymc3 as pm
3938

4039
from pymc3.aesaraf import inputvars
40+
from pymc3.backends.arviz import _DefaultTrace
4141
from pymc3.backends.base import BaseTrace, MultiTrace
4242
from pymc3.backends.ndarray import NDArray
4343
from pymc3.blocking import DictToArrayBijection
@@ -344,7 +344,7 @@ def sample(
344344
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
345345
Defaults to `False`, but we'll switch to `True` in an upcoming release.
346346
idata_kwargs : dict, optional
347-
Keyword arguments for :func:`arviz:arviz.from_pymc3`
347+
Keyword arguments for :func:`pymc3.to_inference_data`
348348
mp_ctx : multiprocessing.context.BaseContent
349349
A multiprocessing context for parallel sampling. See multiprocessing
350350
documentation for details.
@@ -639,7 +639,7 @@ def sample(
639639
ikwargs = dict(model=model, save_warmup=not discard_tuned_samples, log_likelihood=False)
640640
if idata_kwargs:
641641
ikwargs.update(idata_kwargs)
642-
idata = arviz.from_pymc3(trace, **ikwargs)
642+
idata = pm.to_inference_data(trace, **ikwargs)
643643

644644
if compute_convergence_checks:
645645
if draws - tune < 100:
@@ -1546,61 +1546,6 @@ def stop_tuning(step):
15461546
return step
15471547

15481548

1549-
class _DefaultTrace:
1550-
"""
1551-
Utility for collecting samples into a dictionary.
1552-
1553-
Name comes from its similarity to ``defaultdict``:
1554-
entries are lazily created.
1555-
1556-
Parameters
1557-
----------
1558-
samples : int
1559-
The number of samples that will be collected, per variable,
1560-
into the trace.
1561-
1562-
Attributes
1563-
----------
1564-
trace_dict : Dict[str, np.ndarray]
1565-
A dictionary constituting a trace. Should be extracted
1566-
after a procedure has filled the `_DefaultTrace` using the
1567-
`insert()` method
1568-
"""
1569-
1570-
trace_dict: Dict[str, np.ndarray] = {}
1571-
_len: Optional[int] = None
1572-
1573-
def __init__(self, samples: int):
1574-
self._len = samples
1575-
self.trace_dict = {}
1576-
1577-
def insert(self, k: str, v, idx: int):
1578-
"""
1579-
Insert `v` as the value of the `idx`th sample for the variable `k`.
1580-
1581-
Parameters
1582-
----------
1583-
k: str
1584-
Name of the variable.
1585-
v: anything that can go into a numpy array (including a numpy array)
1586-
The value of the `idx`th sample from variable `k`
1587-
ids: int
1588-
The index of the sample we are inserting into the trace.
1589-
"""
1590-
value_shape = np.shape(v)
1591-
1592-
# initialize if necessary
1593-
if k not in self.trace_dict:
1594-
array_shape = (self._len,) + value_shape
1595-
self.trace_dict[k] = np.empty(array_shape, dtype=np.array(v).dtype)
1596-
1597-
# do the actual insertion
1598-
if value_shape == ():
1599-
self.trace_dict[k][idx] = v
1600-
else:
1601-
self.trace_dict[k][idx, :] = v
1602-
1603-
16041549
def sample_posterior_predictive(
16051550
trace,
16061551
samples: Optional[int] = None,

0 commit comments

Comments
 (0)