Skip to content

Commit f026fb8

Browse files
ferrinetwiecki
authored andcommitted
fix test
1 parent c950a2a commit f026fb8

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

pymc3/memoize.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ def __getstate__(self):
5555
state.pop('_cache', None)
5656
return state
5757

58+
def __setstate__(self, state):
59+
self.__dict__.update(state)
60+
5861

5962
def hashable(a):
6063
"""

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ def __init__(self, mean=0, sd=1, name='', model=None):
569569
"""
570570
def __new__(cls, *args, **kwargs):
571571
# resolves the parent instance
572-
instance = object.__new__(cls)
572+
instance = super(Model, cls).__new__(cls)
573573
if kwargs.get('model') is not None:
574574
instance._parent = kwargs.get('model')
575575
elif cls.get_contexts():

pymc3/tests/test_variational_inference.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -612,16 +612,24 @@ def test_remove_scan_op():
612612

613613

614614
def test_clear_cache():
615+
import pickle
615616
pymc3.memoize.clear_cache()
617+
assert all(len(c) == 0 for c in pymc3.memoize.CACHE_REGISTRY)
616618
with pm.Model():
617619
pm.Normal('n', 0, 1)
618620
inference = ADVI()
619621
inference.fit(n=10)
620-
assert len(pm.variational.opvi.Approximation.logp.fget.cache) == 1
621-
del inference
622-
assert len(pm.variational.opvi.Approximation.logp.fget.cache) == 0
623-
for c in pymc3.memoize.CACHE_REGISTRY:
624-
assert len(c) == 0
622+
assert any(len(c) != 0 for c in inference.approx._cache.values())
623+
pymc3.memoize.clear_cache(inference.approx)
624+
# should not be cleared at this call
625+
assert all(len(c) == 0 for c in inference.approx._cache.values())
626+
new_a = pickle.loads(pickle.dumps(inference.approx))
627+
assert not hasattr(new_a, '_cache')
628+
inference_new = pm.KLqp(new_a)
629+
inference_new.fit(n=10)
630+
assert any(len(c) != 0 for c in inference_new.approx._cache.values())
631+
pymc3.memoize.clear_cache(inference_new.approx)
632+
assert all(len(c) == 0 for c in inference_new.approx._cache.values())
625633

626634

627635
@pytest.fixture('module')

pymc3/variational/opvi.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -755,13 +755,13 @@ def __new__(cls, group=None, vfam=None, params=None, *args, **kwargs):
755755
if vfam is not None and params is not None:
756756
raise TypeError('Cannot call Group with both `vfam` and `params` provided')
757757
elif vfam is not None:
758-
return object.__new__(cls.group_for_short_name(vfam))
758+
return super(Group, cls).__new__(cls.group_for_short_name(vfam))
759759
elif params is not None:
760-
return object.__new__(cls.group_for_params(params))
760+
return super(Group, cls).__new__(cls.group_for_params(params))
761761
else:
762762
raise TypeError('Need to call Group with either `vfam` or `params` provided')
763763
else:
764-
return object.__new__(cls)
764+
return super(Group, cls).__new__(cls)
765765

766766
def __init__(self, group,
767767
vfam=None,

0 commit comments

Comments
 (0)