Skip to content

Commit 4fe8671

Browse files
Replace custom memoize module with cachetools
1 parent 2901888 commit 4fe8671

File tree

11 files changed

+154
-210
lines changed

11 files changed

+154
-210
lines changed

pymc3/distributions/distribution.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@
3737
from aesara.graph.basic import Constant
3838
from aesara.tensor.type import TensorType as AesaraTensorType
3939
from aesara.tensor.var import TensorVariable
40+
from cachetools import LRUCache, cached
4041

4142
from pymc3.distributions.shape_utils import (
4243
broadcast_dist_samples_shape,
4344
get_broadcastable_dist_samples,
4445
to_tuple,
4546
)
46-
from pymc3.memoize import memoize
4747
from pymc3.model import (
4848
ContextMeta,
4949
FreeRV,
@@ -52,7 +52,7 @@
5252
ObservedRV,
5353
build_named_node_tree,
5454
)
55-
from pymc3.util import get_repr_for_variable, get_var_name
55+
from pymc3.util import get_repr_for_variable, get_var_name, hash_key
5656
from pymc3.vartypes import string_types
5757

5858
__all__ = [
@@ -841,7 +841,7 @@ def draw_values(params, point=None, size=None):
841841
return [evaluated[j] for j in params] # set the order back
842842

843843

844-
@memoize
844+
@cached(LRUCache(128), key=hash_key)
845845
def _compile_aesara_function(param, vars, givens=None):
846846
"""Compile aesara function for a given parameter and input variables.
847847

pymc3/memoize.py

Lines changed: 0 additions & 113 deletions
This file was deleted.

pymc3/model.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from aesara.graph.basic import Apply, Variable
3333
from aesara.tensor.type import TensorType as AesaraTensorType
3434
from aesara.tensor.var import TensorVariable
35+
from cachetools import LRUCache, cachedmethod
3536
from pandas import Series
3637

3738
import pymc3 as pm
@@ -40,8 +41,7 @@
4041
from pymc3.blocking import ArrayOrdering, DictToArrayBijection
4142
from pymc3.exceptions import ImputationWarning
4243
from pymc3.math import flatten_list
43-
from pymc3.memoize import WithMemoization, memoize
44-
from pymc3.util import get_transformed_name, get_var_name
44+
from pymc3.util import WithMemoization, get_transformed_name, get_var_name, hash_key
4545
from pymc3.vartypes import continuous_types, discrete_types, isgenerator, typefilter
4646

4747
__all__ = [
@@ -946,7 +946,9 @@ def isroot(self):
946946
return self.parent is None
947947

948948
@property # type: ignore
949-
@memoize(bound=True)
949+
@cachedmethod(
950+
lambda self: self.__dict__.setdefault("_bijection_cache", LRUCache(128)), key=hash_key
951+
)
950952
def bijection(self):
951953
vars = inputvars(self.vars)
952954

pymc3/tests/test_memo.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

pymc3/tests/test_util.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515
import numpy as np
1616
import pytest
1717

18+
from cachetools import cached
1819
from numpy.testing import assert_almost_equal
1920

2021
import pymc3 as pm
2122

2223
from pymc3.distributions.transforms import Transform
2324
from pymc3.tests.helpers import SeededTest
25+
from pymc3.util import hash_key, hashable, locally_cachedmethod
2426

2527

2628
class TestTransformName:
@@ -167,3 +169,53 @@ def test_dtype_error(self):
167169
raise pm.exceptions.DtypeError("With types.", actual=int, expected=str)
168170
assert "int" in exinfo.value.args[0] and "str" in exinfo.value.args[0]
169171
pass
172+
173+
174+
def test_hashing_of_rv_tuples():
175+
obs = np.random.normal(-1, 0.1, size=10)
176+
with pm.Model() as pmodel:
177+
mu = pm.Normal("mu", 0, 1)
178+
sd = pm.Gamma("sd", 1, 2)
179+
dd = pm.DensityDist(
180+
"dd",
181+
pm.Normal.dist(mu, sd).logp,
182+
random=pm.Normal.dist(mu, sd).random,
183+
observed=obs,
184+
)
185+
for freerv in [mu, sd, dd] + pmodel.free_RVs:
186+
for structure in [
187+
freerv,
188+
{"alpha": freerv, "omega": None},
189+
[freerv, []],
190+
(freerv, []),
191+
]:
192+
assert isinstance(hashable(structure), int)
193+
194+
195+
def test_hash_key():
196+
class Bad1:
197+
def __hash__(self):
198+
return 329
199+
200+
class Bad2:
201+
def __hash__(self):
202+
return 329
203+
204+
b1 = Bad1()
205+
b2 = Bad2()
206+
207+
assert b1 != b2
208+
209+
@cached({}, key=hash_key)
210+
def some_func(x):
211+
return x
212+
213+
assert some_func(b1) != some_func(b2)
214+
215+
class TestClass:
216+
@locally_cachedmethod
217+
def some_method(self, x):
218+
return x
219+
220+
tc = TestClass()
221+
assert tc.some_method(b1) != tc.some_method(b2)

pymc3/tests/test_variational_inference.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import pytest
2323

2424
import pymc3 as pm
25-
import pymc3.memoize
2625
import pymc3.util
2726

2827
from pymc3.aesaraf import intX
@@ -757,22 +756,20 @@ def test_remove_scan_op():
757756
def test_clear_cache():
758757
import pickle
759758

760-
pymc3.memoize.clear_cache()
761-
assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY)
762759
with pm.Model():
763760
pm.Normal("n", 0, 1)
764761
inference = ADVI()
765762
inference.fit(n=10)
766763
assert any(len(c) != 0 for c in inference.approx._cache.values())
767-
pymc3.memoize.clear_cache(inference.approx)
764+
inference.approx._cache.clear()
768765
# should not be cleared at this call
769766
assert all(len(c) == 0 for c in inference.approx._cache.values())
770767
new_a = pickle.loads(pickle.dumps(inference.approx))
771768
assert not hasattr(new_a, "_cache")
772769
inference_new = pm.KLqp(new_a)
773770
inference_new.fit(n=10)
774771
assert any(len(c) != 0 for c in inference_new.approx._cache.values())
775-
pymc3.memoize.clear_cache(inference_new.approx)
772+
inference_new.approx._cache.clear()
776773
assert all(len(c) == 0 for c in inference_new.approx._cache.values())
777774

778775

0 commit comments

Comments
 (0)