Skip to content

Commit 7007aab

Browse files
Fix underlying bug behind DensityDist stack overflow
TL;DR: The `hashable` helper function did not appropriately deal with tuples (and the test case did not actually test the memoization). In the process of prior-predictive sampling of a model involving a DensityDist, the `_compile_theano_function` function was called with the arguments (sd__log_, []). `_compile_theano_function` has a `pymc3.memoize.memoize` decorator, which relies on `pymc3.memoize.hashable` for hashing typically unhashable objects. The `hashable` function handled tuples incorrectly, eventually causing a stack overflow error on Windows.
1 parent 7fe150f commit 7007aab

File tree

3 files changed

+69
-20
lines changed

3 files changed

+69
-20
lines changed

pymc3/memoize.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
import functools
16-
import pickle
16+
import dill
1717
import collections
1818
from .util import biwrap
1919

@@ -23,7 +23,16 @@
2323
@biwrap
2424
def memoize(obj, bound=False):
2525
"""
26-
An expensive memoizer that works with unhashables
26+
Decorator to apply memoization to expensive functions.
27+
It uses a custom `hashable` helper function to hash typically unhashable Python objects.
28+
29+
Parameters
30+
----------
31+
obj : callable
32+
the function to apply the caching to
33+
bound : bool
34+
indicates if the [obj] is a bound method (self as first argument)
35+
For bound methods, the cache is kept in a `_cache` attribute on [self].
2736
"""
2837
# this is declared not to be a bound method, so just attach new attr to obj
2938
if not bound:
@@ -40,7 +49,7 @@ def memoizer(*args, **kwargs):
4049
key = (hashable(args[1:]), hashable(kwargs))
4150
if not hasattr(args[0], "_cache"):
4251
setattr(args[0], "_cache", collections.defaultdict(dict))
43-
# do not add to cache regestry
52+
# do not add to cache registry
4453
cache = getattr(args[0], "_cache")[obj.__name__]
4554
if key not in cache:
4655
cache[key] = obj(*args, **kwargs)
@@ -75,19 +84,26 @@ def __setstate__(self, state):
7584
self.__dict__.update(state)
7685

7786

78-
def hashable(a):
87+
def hashable(a) -> int:
7988
"""
80-
Turn some unhashable objects into hashable ones.
89+
Hashes many kinds of objects, including some that are unhashable through the builtin `hash` function.
90+
Lists and tuples are hashed based on their elements.
8191
"""
8292
if isinstance(a, dict):
83-
return hashable(tuple((hashable(a1), hashable(a2)) for a1, a2 in a.items()))
93+
# first hash the keys and values with hashable
94+
# then hash the tuple of int-tuples with the builtin
95+
return hash(tuple((hashable(k), hashable(v)) for k, v in a.items()))
96+
if isinstance(a, (tuple, list)):
97+
# lists are mutable and not hashable by default
98+
# for memoization, we need the hash to depend on the items
99+
return hash(tuple(map(hashable, a)))
84100
try:
85101
return hash(a)
86102
except TypeError:
87103
pass
88104
# Not hashable >>>
89105
try:
90-
return hash(pickle.dumps(a))
106+
return hash(dill.dumps(a))
91107
except Exception:
92108
if hasattr(a, "__dict__"):
93109
return hashable(a.__dict__)

pymc3/tests/test_memo.py

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,56 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import numpy as np
15+
import pymc3 as pm
1416

15-
from pymc3.memoize import memoize
17+
from pymc3 import memoize
1618

1719

18-
def getmemo():
19-
@memoize
20-
def f(a, b=("a")):
21-
return str(a) + str(b)
20+
def test_memo():
21+
def fun(inputs, suffix="_a"):
22+
return str(inputs) + str(suffix)
23+
inputs = ["i1", "i2"]
24+
assert fun(inputs) == "['i1', 'i2']_a"
25+
assert fun(inputs, "_b") == "['i1', 'i2']_b"
2226

23-
return f
27+
funmem = memoize.memoize(fun)
28+
assert hasattr(fun, "cache")
29+
assert isinstance(fun.cache, dict)
30+
assert len(fun.cache) == 0
31+
32+
assert funmem(inputs) == "['i1', 'i2']_a"
33+
assert funmem(inputs) == "['i1', 'i2']_a"
34+
assert len(fun.cache) == 1
35+
assert funmem(inputs, "_b") == "['i1', 'i2']_b"
36+
assert funmem(inputs, "_b") == "['i1', 'i2']_b"
37+
assert len(fun.cache) == 2
2438

39+
# add items to the inputs list (the list instance remains identical !!)
40+
inputs.append("i3")
41+
assert funmem(inputs) == "['i1', 'i2', 'i3']_a"
42+
assert funmem(inputs) == "['i1', 'i2', 'i3']_a"
43+
assert len(fun.cache) == 3
2544

26-
def test_memo():
27-
f = getmemo()
2845

29-
assert f("x", ["y", "z"]) == "x['y', 'z']"
30-
assert f("x", ["a", "z"]) == "x['a', 'z']"
31-
assert f("x", ["y", "z"]) == "x['y', 'z']"
46+
def test_hashing_of_rv_tuples():
47+
obs = np.random.normal(-1, 0.1, size=10)
48+
with pm.Model() as pmodel:
49+
mu = pm.Normal("mu", 0, 1)
50+
sd = pm.Gamma("sd", 1, 2)
51+
dd = pm.DensityDist(
52+
"dd",
53+
pm.Normal.dist(mu, sd).logp,
54+
random=pm.Normal.dist(mu, sd).random,
55+
observed=obs,
56+
)
57+
print()
58+
for freerv in [mu, sd, dd] + pmodel.free_RVs:
59+
for structure in [
60+
freerv,
61+
dict(alpha=freerv, omega=None),
62+
[freerv, []],
63+
(freerv, []),
64+
]:
65+
print(f"testing hashing of: {structure}")
66+
assert isinstance(memoize.hashable(structure), int)

pymc3/tests/test_sampling.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -930,9 +930,7 @@ def test_shared(self):
930930

931931
assert gen2["y"].shape == (draws, n2)
932932

933-
@pytest.mark.skipif(sys.platform.startswith("win"), reason="Stackoverflow error on Windows")
934933
def test_density_dist(self):
935-
936934
obs = np.random.normal(-1, 0.1, size=10)
937935
with pm.Model():
938936
mu = pm.Normal("mu", 0, 1)

0 commit comments

Comments
 (0)