Skip to content

Commit 2caa005

Browse files
burnpancktwiecki
authored andcommitted
Fix #895: Bug pickling model instances (#1560)
* added regression test for #895 * fixes #895
1 parent c0cc253 commit 2caa005

File tree

2 files changed

+27
-4
lines changed

2 files changed

+27
-4
lines changed

pymc3/distributions/distribution.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,14 @@
1010
__all__ = ['DensityDist', 'Distribution', 'Continuous',
1111
'Discrete', 'NoDistribution', 'TensorType', 'draw_values']
1212

13+
class _Unpickling(object):
14+
pass
1315

1416
class Distribution(object):
1517
"""Statistical distribution"""
1618
def __new__(cls, name, *args, **kwargs):
19+
if name is _Unpickling:
20+
return object.__new__(cls) # for pickle
1721
try:
1822
model = Model.get_context()
1923
except TypeError:
@@ -25,13 +29,11 @@ def __new__(cls, name, *args, **kwargs):
2529
data = kwargs.pop('observed', None)
2630
dist = cls.dist(*args, **kwargs)
2731
return model.Var(name, dist, data)
28-
elif name is None:
29-
return object.__new__(cls) # for pickle
3032
else:
31-
raise TypeError("needed name or None but got: %s" % name)
33+
raise TypeError("Name needs to be a string but got: %s" % name)
3234

3335
def __getnewargs__(self):
34-
return None,
36+
return _Unpickling,
3537

3638
@classmethod
3739
def dist(cls, *args, **kwargs):

pymc3/tests/test_pickling.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import unittest
2+
import pickle
3+
import traceback
4+
from .models import simple_model
5+
6+
7+
class TestPickling(unittest.TestCase):
8+
def setUp(self):
9+
_, self.model, _ = simple_model()
10+
11+
def test_model_roundtrip(self):
12+
m = self.model
13+
for proto in range(pickle.HIGHEST_PROTOCOL+1):
14+
try:
15+
s = pickle.dumps(m, proto)
16+
n = pickle.loads(s)
17+
except Exception as ex:
18+
raise AssertionError(
19+
"Exception while trying roundtrip with pickle protocol %d:\n"%proto +
20+
''.join(traceback.format_exc())
21+
)

0 commit comments

Comments
 (0)