Skip to content

Commit c950a2a

Browse files
ferrinetwiecki
authored andcommitted
refactor memoize
1 parent 553f057 commit c950a2a

File tree

5 files changed

+54
-31
lines changed

5 files changed

+54
-31
lines changed

pymc3/memoize.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,59 @@
11
import functools
22
import pickle
3-
3+
import collections
4+
from .util import biwrap
45
CACHE_REGISTRY = []
56

67

7-
def memoize(obj):
8+
@biwrap
9+
def memoize(obj, bound=False):
810
"""
911
An expensive memoizer that works with unhashables
1012
"""
11-
cache = obj.cache = {}
12-
CACHE_REGISTRY.append(cache)
13+
# this is declared not to be a bound method, so just attach new attr to obj
14+
if not bound:
15+
obj.cache = {}
16+
CACHE_REGISTRY.append(obj.cache)
1317

1418
@functools.wraps(obj)
1519
def memoizer(*args, **kwargs):
16-
# remember first argument as well, used to clear cache for particular instance
17-
key = (hashable(args[:1]), hashable(args), hashable(kwargs))
18-
20+
if not bound:
21+
key = (hashable(args), hashable(kwargs))
22+
cache = obj.cache
23+
else:
24+
# bound methods have self as first argument, remove it to compute key
25+
key = (hashable(args[1:]), hashable(kwargs))
26+
if not hasattr(args[0], '_cache'):
27+
setattr(args[0], '_cache', collections.defaultdict(dict))
28+
# do not add to cache regestry
29+
cache = getattr(args[0], '_cache')[obj.__name__]
1930
if key not in cache:
2031
cache[key] = obj(*args, **kwargs)
2132

2233
return cache[key]
2334
return memoizer
2435

2536

26-
def clear_cache():
27-
for c in CACHE_REGISTRY:
28-
c.clear()
37+
def clear_cache(obj=None):
38+
if obj is None:
39+
for c in CACHE_REGISTRY:
40+
c.clear()
41+
else:
42+
if isinstance(obj, WithMemoization):
43+
for v in getattr(obj, '_cache', {}).values():
44+
v.clear()
45+
else:
46+
obj.cache.clear()
2947

3048

3149
class WithMemoization(object):
3250
def __hash__(self):
3351
return hash(id(self))
3452

35-
def __del__(self):
36-
# regular property call with args (self, )
37-
key = hash((self, ))
38-
to_del = []
39-
for c in CACHE_REGISTRY:
40-
for k in c.keys():
41-
if k[0] == key:
42-
to_del.append((c, k))
43-
for (c, k) in to_del:
44-
del c[k]
53+
def __getstate__(self):
54+
state = self.__dict__.copy()
55+
state.pop('_cache', None)
56+
return state
4557

4658

4759
def hashable(a):

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ def isroot(self):
619619
return self.parent is None
620620

621621
@property
622-
@memoize
622+
@memoize(bound=True)
623623
def bijection(self):
624624
vars = inputvars(self.cont_vars)
625625

pymc3/util.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
2+
import functools
33
from numpy import asscalar
44

55
LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)
@@ -152,3 +152,20 @@ def get_transformed(z):
152152
if hasattr(z, 'transformed'):
153153
z = z.transformed
154154
return z
155+
156+
157+
def biwrap(wrapper):
158+
@functools.wraps(wrapper)
159+
def enhanced(*args, **kwargs):
160+
is_bound_method = hasattr(args[0], wrapper.__name__) if args else False
161+
if is_bound_method:
162+
count = 1
163+
else:
164+
count = 0
165+
if len(args) > count:
166+
newfn = wrapper(*args, **kwargs)
167+
return newfn
168+
else:
169+
newwrapper = functools.partial(wrapper, *args, **kwargs)
170+
return newwrapper
171+
return enhanced

pymc3/variational/approximations.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -539,12 +539,6 @@ def __init__(self, *args, **kwargs):
539539
def __getattr__(self, item):
540540
return getattr(self.groups[0], item)
541541

542-
def __getstate__(self):
543-
return self.__dict__.copy()
544-
545-
def __setstate__(self, state):
546-
self.__dict__.update(state)
547-
548542

549543
class MeanField(SingleGroupApproximation):
550544
__doc__ = """**Single Group Mean Field Approximation**

pymc3/variational/opvi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,10 +106,10 @@ def node_property(f):
106106
if isinstance(f, str):
107107

108108
def wrapper(fn):
109-
return property(memoize(change_flags(compute_test_value='off')(append_name(f)(fn))))
109+
return property(memoize(change_flags(compute_test_value='off')(append_name(f)(fn)), bound=True))
110110
return wrapper
111111
else:
112-
return property(memoize(change_flags(compute_test_value='off')(f)))
112+
return property(memoize(change_flags(compute_test_value='off')(f), bound=True))
113113

114114

115115
@change_flags(compute_test_value='ignore')
@@ -1487,7 +1487,7 @@ def vars_names(vs):
14871487
return found
14881488

14891489
@property
1490-
@memoize
1490+
@memoize(bound=True)
14911491
@change_flags(compute_test_value='off')
14921492
def sample_dict_fn(self):
14931493
s = tt.iscalar()

0 commit comments

Comments
 (0)