Skip to content

Commit 1642e73

Browse files
authored
Save memory in posterior predictive (#3556)
Build the numpy ndarrays for the posterior predictive array lazily, but skip the intermediate step of building a very big python array.
1 parent 3a2a765 commit 1642e73

File tree

2 files changed

+62
-7
lines changed

2 files changed

+62
-7
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
- Parallelization of population steppers (`DEMetropolis`) is now set via the `cores` argument. ([#3559](https://github.com/pymc-devs/pymc3/pull/3559))
1616
- SMC: stabilize covariance matrix [3573](https://github.com/pymc-devs/pymc3/pull/3573)
1717
- SMC is no longer a step method of `pm.sample` now it should be called using `pm.sample_smc` [3579](https://github.com/pymc-devs/pymc3/pull/3579)
18+
- `sample_posterior_predictive` now preallocates the memory required for its output to improve memory usage. Addresses problems raised in this [discourse thread](https://discourse.pymc.io/t/memory-error-with-posterior-predictive-sample/2891/4).
1819

1920
## PyMC3 3.7 (May 29 2019)
2021

pymc3/sampling.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional, TYPE_CHECKING, cast
22
if TYPE_CHECKING:
3-
from typing import Any
3+
from typing import Any, Tuple
44
from typing import Iterable as TIterable
55
from collections.abc import Iterable
66
from collections import defaultdict
@@ -1010,6 +1010,60 @@ def stop_tuning(step):
10101010
step.stop_tuning()
10111011
return step
10121012

1013+
class _DefaultTrace():
1014+
'''
1015+
This class is a utility for collecting a number of samples
1016+
into a dictionary. Name comes from its similarity to `defaultdict` --
1017+
entries are lazily created.
1018+
1019+
Parameters
1020+
----------
1021+
samples : int
1022+
The number of samples that will be collected, per variable,
1023+
into the trace.
1024+
1025+
Attributes
1026+
----------
1027+
trace_dict : Dict[str, np.ndarray]
1028+
A dictionary constituting a trace. Should be extracted
1029+
after a procedure has filled the `_DefaultTrace` using the
1030+
`insert()` method
1031+
'''
1032+
trace_dict = {} # type: Dict[str, np.ndarray]
1033+
_len = None # type: int
1034+
def __init__(self, samples):
1035+
self._len = samples
1036+
self.trace_dict = {}
1037+
1038+
def insert(self, k: str, v, idx: int):
1039+
'''
1040+
Insert `v` as the value of the `idx`th sample for the variable `k`.
1041+
1042+
Parameters
1043+
----------
1044+
k : str
1045+
Name of the variable.
1046+
v : anything that can go into a numpy array (including a numpy array)
1047+
The value of the `idx`th sample from variable `k`
1048+
ids : int
1049+
The index of the sample we are inserting into the trace.
1050+
'''
1051+
if hasattr(v, 'shape'):
1052+
value_shape = tuple(v.shape) # type: Tuple[int, ...]
1053+
else:
1054+
value_shape = ()
1055+
1056+
# initialize if necessary
1057+
if k not in self.trace_dict:
1058+
array_shape = (self._len,) + value_shape
1059+
self.trace_dict[k] = np.full(array_shape, np.nan)
1060+
1061+
# do the actual insertion
1062+
if value_shape == ():
1063+
self.trace_dict[k][idx] = v
1064+
else:
1065+
self.trace_dict[k][idx,:] = v
1066+
10131067

10141068
def sample_posterior_predictive(trace,
10151069
samples: Optional[int]=None,
@@ -1097,10 +1151,11 @@ def sample_posterior_predictive(trace,
10971151

10981152
indices = np.arange(samples)
10991153

1154+
11001155
if progressbar:
11011156
indices = tqdm(indices, total=samples)
11021157

1103-
ppc_trace = defaultdict(list) # type: Dict[str, List[Any]]
1158+
ppc_trace_t = _DefaultTrace(samples)
11041159
try:
11051160
for idx in indices:
11061161
if nchain > 1:
@@ -1111,7 +1166,7 @@ def sample_posterior_predictive(trace,
11111166

11121167
values = draw_values(vars, point=param, size=size)
11131168
for k, v in zip(vars, values):
1114-
ppc_trace[k.name].append(v)
1169+
ppc_trace_t.insert(k.name, v, idx)
11151170

11161171
except KeyboardInterrupt:
11171172
pass
@@ -1120,13 +1175,12 @@ def sample_posterior_predictive(trace,
11201175
if progressbar:
11211176
indices.close()
11221177

1178+
ppc_trace = ppc_trace_t.trace_dict
11231179
if keep_size:
11241180
for k, ary in ppc_trace.items():
1125-
ary = np.asarray(ary)
11261181
ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
1127-
return ppc_trace
1128-
else:
1129-
return {k: np.asarray(v) for k, v in ppc_trace.items()}
1182+
1183+
return ppc_trace
11301184

11311185

11321186
def sample_ppc(*args, **kwargs):

0 commit comments

Comments
 (0)