Skip to content

Commit 432bcda

Browse files
authored
Merge pull request #7 from michaelosthege/windows-test
Fix the stackoverflow problem
2 parents 73ca256 + e94b47e commit 432bcda

File tree

6 files changed

+89
-29
lines changed

6 files changed

+89
-29
lines changed

pymc3/distributions/posterior_predictive.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
)
4343
from ..exceptions import IncorrectArgumentsError
4444
from ..vartypes import theano_constant
45-
from ..util import dataset_to_point_dict, chains_and_samples, get_var_name
45+
from ..util import dataset_to_point_list, chains_and_samples, get_var_name
4646

4747
# Failing tests:
4848
# test_mixture_random_shape::test_mixture_random_shape
@@ -209,10 +209,10 @@ def fast_sample_posterior_predictive(
209209

210210
if isinstance(trace, InferenceData):
211211
nchains, ndraws = chains_and_samples(trace)
212-
trace = dataset_to_point_dict(trace.posterior)
212+
trace = dataset_to_point_list(trace.posterior)
213213
elif isinstance(trace, Dataset):
214214
nchains, ndraws = chains_and_samples(trace)
215-
trace = dataset_to_point_dict(trace)
215+
trace = dataset_to_point_list(trace)
216216
elif isinstance(trace, MultiTrace):
217217
nchains = trace.nchains
218218
ndraws = len(trace)

pymc3/memoize.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16-
import pickle
16+
import dill
1717
import collections
1818
from .util import biwrap
1919

@@ -23,7 +23,16 @@
2323
@biwrap
2424
def memoize(obj, bound=False):
2525
"""
26-
An expensive memoizer that works with unhashables
26+
Decorator to apply memoization to expensive functions.
27+
It uses a custom `hashable` helper function to hash typically unhashable Python objects.
28+
29+
Parameters
30+
----------
31+
obj : callable
32+
the function to apply the caching to
33+
bound : bool
34+
indicates if the [obj] is a bound method (self as first argument)
35+
For bound methods, the cache is kept in a `_cache` attribute on [self].
2736
"""
2837
# this is declared not to be a bound method, so just attach new attr to obj
2938
if not bound:
@@ -40,7 +49,7 @@ def memoizer(*args, **kwargs):
4049
key = (hashable(args[1:]), hashable(kwargs))
4150
if not hasattr(args[0], "_cache"):
4251
setattr(args[0], "_cache", collections.defaultdict(dict))
43-
# do not add to cache regestry
52+
# do not add to cache registry
4453
cache = getattr(args[0], "_cache")[obj.__name__]
4554
if key not in cache:
4655
cache[key] = obj(*args, **kwargs)
@@ -75,19 +84,26 @@ def __setstate__(self, state):
7584
self.__dict__.update(state)
7685

7786

78-
def hashable(a):
87+
def hashable(a) -> int:
7988
"""
80-
Turn some unhashable objects into hashable ones.
89+
Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function.
90+
Lists and tuples are hashed based on their elements.
8191
"""
8292
if isinstance(a, dict):
83-
return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items()))
93+
# first hash the keys and values with hashable
94+
# then hash the tuple of int-tuples with the builtin
95+
return hash(tuple((hashable(k), hashable(v)) for k, v in a.items()))
96+
if isinstance(a, (tuple, list)):
97+
# lists are mutable and not hashable by default
98+
# for memoization, we need the hash to depend on the items
99+
return hash(tuple(map(hashable, a)))
84100
try:
85101
return hash(a)
86102
except TypeError:
87103
pass
88104
# Not hashable >>>
89105
try:
90-
return hash(pickle.dumps(a))
106+
return hash(dill.dumps(a))
91107
except Exception:
92108
if hasattr(a, "__dict__"):
93109
return hashable(a.__dict__)

pymc3/sampling.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
get_untransformed_name,
5757
is_transformed_name,
5858
get_default_varnames,
59-
dataset_to_point_dict,
59+
dataset_to_point_list,
6060
chains_and_samples,
6161
)
6262
from .vartypes import discrete_types
@@ -1642,9 +1642,9 @@ def sample_posterior_predictive(
16421642

16431643
_trace: Union[MultiTrace, PointList]
16441644
if isinstance(trace, InferenceData):
1645-
_trace = dataset_to_point_dict(trace.posterior)
1645+
_trace = dataset_to_point_list(trace.posterior)
16461646
elif isinstance(trace, xarray.Dataset):
1647-
_trace = dataset_to_point_dict(trace)
1647+
_trace = dataset_to_point_list(trace)
16481648
else:
16491649
_trace = trace
16501650

@@ -1780,10 +1780,10 @@ def sample_posterior_predictive_w(
17801780
n_samples = [
17811781
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
17821782
]
1783-
traces = [dataset_to_point_dict(trace.posterior) for trace in traces]
1783+
traces = [dataset_to_point_list(trace.posterior) for trace in traces]
17841784
elif isinstance(traces[0], xarray.Dataset):
17851785
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
1786-
traces = [dataset_to_point_dict(trace) for trace in traces]
1786+
traces = [dataset_to_point_list(trace) for trace in traces]
17871787
else:
17881788
n_samples = [len(i) * i.nchains for i in traces]
17891789

pymc3/tests/test_memo.py

Lines changed: 49 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,59 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
15+
import pymc3 as pm
1416

15-
from pymc3.memoize import memoize
17+
from pymc3 import memoize
1618

1719

18-
def getmemo():
19-
@memoize
20-
def f(a, b=("a")):
21-
return str(a) + str(b)
20+
def test_memo():
21+
def fun(inputs, suffix="_a"):
22+
return str(inputs) + str(suffix)
2223

23-
return f
24+
inputs = ["i1", "i2"]
25+
assert fun(inputs) == "['i1', 'i2']_a"
26+
assert fun(inputs, "_b") == "['i1', 'i2']_b"
2427

28+
funmem = memoize.memoize(fun)
29+
assert hasattr(fun, "cache")
30+
assert isinstance(fun.cache, dict)
31+
assert len(fun.cache) == 0
32+
33+
# call the memoized function with a list input
34+
# and check the size of the cache!
35+
assert funmem(inputs) == "['i1', 'i2']_a"
36+
assert funmem(inputs) == "['i1', 'i2']_a"
37+
assert len(fun.cache) == 1
38+
assert funmem(inputs, "_b") == "['i1', 'i2']_b"
39+
assert funmem(inputs, "_b") == "['i1', 'i2']_b"
40+
assert len(fun.cache) == 2
41+
42+
# add items to the inputs list (the list instance remains identical !!)
43+
inputs.append("i3")
44+
assert funmem(inputs) == "['i1', 'i2', 'i3']_a"
45+
assert funmem(inputs) == "['i1', 'i2', 'i3']_a"
46+
assert len(fun.cache) == 3
2547

26-
def test_memo():
27-
f = getmemo()
2848

29-
assert f("x", ["y", "z"]) == "x['y', 'z']"
30-
assert f("x", ["a", "z"]) == "x['a', 'z']"
31-
assert f("x", ["y", "z"]) == "x['y', 'z']"
49+
def test_hashing_of_rv_tuples():
50+
obs = np.random.normal(-1, 0.1, size=10)
51+
with pm.Model() as pmodel:
52+
mu = pm.Normal("mu", 0, 1)
53+
sd = pm.Gamma("sd", 1, 2)
54+
dd = pm.DensityDist(
55+
"dd",
56+
pm.Normal.dist(mu, sd).logp,
57+
random=pm.Normal.dist(mu, sd).random,
58+
observed=obs,
59+
)
60+
print()
61+
for freerv in [mu, sd, dd] + pmodel.free_RVs:
62+
for structure in [
63+
freerv,
64+
dict(alpha=freerv, omega=None),
65+
[freerv, []],
66+
(freerv, []),
67+
]:
68+
print(f"testing hashing of: {structure}")
69+
assert isinstance(memoize.hashable(structure), int)

pymc3/tests/test_sampling.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from typing import Tuple
1818
import numpy as np
1919
import unittest.mock as mock
20-
import sys
2120

2221
import numpy.testing as npt
2322
import arviz as az
@@ -930,9 +929,7 @@ def test_shared(self):
930929

931930
assert gen2["y"].shape == (draws, n2)
932931

933-
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Stackoverflow error on Windows")
934932
def test_density_dist(self):
935-
936933
obs = np.random.normal(-1, 0.1, size=10)
937934
with pm.Model():
938935
mu = pm.Normal("mu", 0, 1)

pymc3/util.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import re
1616
import functools
1717
from typing import List, Dict, Tuple, Union
18+
import warnings
1819

1920
import numpy as np
2021
import xarray
@@ -258,6 +259,14 @@ def enhanced(*args, **kwargs):
258259
# FIXME: this function is poorly named, because it returns a LIST of
259260
# points, not a dictionary of points.
260261
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
262+
warnings.warn(
263+
"dataset_to_point_dict was renamed to dataset_to_point_list and will be removed!.",
264+
DeprecationWarning,
265+
)
266+
return dataset_to_point_list(ds)
267+
268+
269+
def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
261270
# grab posterior samples for each variable
262271
_samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()}
263272
# make dicts

0 commit comments

Comments
 (0)