Skip to content

Commit 4740bb2

Browse files
committed
update tests for cloudpickle
1 parent 8be0162 commit 4740bb2

File tree

6 files changed

+24
-34
lines changed

6 files changed

+24
-34
lines changed

pymc3/tests/test_distributions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3109,9 +3109,9 @@ def func(x):
31093109
y = pm.DensityDist("y", func)
31103110
pm.sample(draws=5, tune=1, mp_ctx="spawn")
31113111

3112-
import pickle
3112+
import cloudpickle
31133113

3114-
pickle.loads(pickle.dumps(y))
3114+
cloudpickle.loads(cloudpickle.dumps(y))
31153115

31163116

31173117
def test_distinct_rvs():

pymc3/tests/test_minibatches.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
# limitations under the License.
1414

1515
import itertools
16-
import pickle
1716

1817
import aesara
18+
import cloudpickle
1919
import numpy as np
2020
import pytest
2121

@@ -132,10 +132,10 @@ def gen():
132132

133133
def test_pickling(self, datagen):
134134
gen = generator(datagen)
135-
pickle.loads(pickle.dumps(gen))
135+
cloudpickle.loads(cloudpickle.dumps(gen))
136136
bad_gen = generator(integers())
137-
with pytest.raises(Exception):
138-
pickle.dumps(bad_gen)
137+
with pytest.raises(TypeError, match="cannot pickle 'generator' object"):
138+
cloudpickle.dumps(bad_gen)
139139

140140
def test_gen_cloning_with_shape_change(self, datagen):
141141
gen = generator(datagen)

pymc3/tests/test_model.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
import pickle
1514
import unittest
1615

1716
from functools import reduce
1817

1918
import aesara
2019
import aesara.sparse as sparse
2120
import aesara.tensor as at
21+
import cloudpickle
2222
import numpy as np
2323
import numpy.ma as ma
2424
import numpy.testing as npt
@@ -407,9 +407,7 @@ def test_model_pickle(tmpdir):
407407
x = pm.Normal("x")
408408
pm.Normal("y", observed=1)
409409

410-
file_path = tmpdir.join("model.p")
411-
with open(file_path, "wb") as buff:
412-
pickle.dump(model, buff)
410+
cloudpickle.loads(cloudpickle.dumps(model))
413411

414412

415413
def test_model_pickle_deterministic(tmpdir):
@@ -420,9 +418,7 @@ def test_model_pickle_deterministic(tmpdir):
420418
pm.Deterministic("w", x / z)
421419
pm.Normal("y", observed=1)
422420

423-
file_path = tmpdir.join("model.p")
424-
with open(file_path, "wb") as buff:
425-
pickle.dump(model, buff)
421+
cloudpickle.loads(cloudpickle.dumps(model))
426422

427423

428424
def test_model_vars():

pymc3/tests/test_parallel_sampling.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,6 @@ def _crash_remote_process(a, master_pid):
7171
return 2 * np.array(a)
7272

7373

74-
def test_dill():
75-
with pm.Model():
76-
pm.Normal("x")
77-
pm.sample(tune=1, draws=1, chains=2, cores=2, pickle_backend="dill", mp_ctx="spawn")
78-
79-
8074
def test_remote_pipe_closed():
8175
master_pid = os.getpid()
8276
with pm.Model():
@@ -112,7 +106,6 @@ def test_abort():
112106
mp_ctx=ctx,
113107
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
114108
step_method_pickled=None,
115-
pickle_backend="pickle",
116109
)
117110
proc.start()
118111
while True:
@@ -147,7 +140,6 @@ def test_explicit_sample():
147140
mp_ctx=ctx,
148141
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
149142
step_method_pickled=None,
150-
pickle_backend="pickle",
151143
)
152144
proc.start()
153145
while True:

pymc3/tests/test_pickling.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import pickle
1616
import traceback
1717

18+
import cloudpickle
19+
1820
from pymc3.tests.models import simple_model
1921

2022

@@ -26,8 +28,8 @@ def test_model_roundtrip(self):
2628
m = self.model
2729
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2830
try:
29-
s = pickle.dumps(m, proto)
30-
pickle.loads(s)
31+
s = cloudpickle.dumps(m, proto)
32+
cloudpickle.loads(s)
3133
except Exception:
3234
raise AssertionError(
3335
"Exception while trying roundtrip with pickle protocol %d:\n" % proto

pymc3/tests/test_variational_inference.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def test_remove_scan_op():
757757

758758

759759
def test_clear_cache():
760-
import pickle
760+
import cloudpickle
761761

762762
with pm.Model():
763763
pm.Normal("n", 0, 1)
@@ -767,7 +767,7 @@ def test_clear_cache():
767767
inference.approx._cache.clear()
768768
# should not be cleared at this call
769769
assert all(len(c) == 0 for c in inference.approx._cache.values())
770-
new_a = pickle.loads(pickle.dumps(inference.approx))
770+
new_a = cloudpickle.loads(cloudpickle.dumps(inference.approx))
771771
assert not hasattr(new_a, "_cache")
772772
inference_new = pm.KLqp(new_a)
773773
inference_new.fit(n=10)
@@ -871,26 +871,26 @@ def test_rowwise_approx(three_var_model, parametric_grouped_approxes):
871871

872872

873873
def test_pickle_approx(three_var_approx):
874-
import pickle
874+
import cloudpickle
875875

876-
dump = pickle.dumps(three_var_approx)
877-
new = pickle.loads(dump)
876+
dump = cloudpickle.dumps(three_var_approx)
877+
new = cloudpickle.loads(dump)
878878
assert new.sample(1)
879879

880880

881881
def test_pickle_single_group(three_var_approx_single_group_mf):
882-
import pickle
882+
import cloudpickle
883883

884-
dump = pickle.dumps(three_var_approx_single_group_mf)
885-
new = pickle.loads(dump)
884+
dump = cloudpickle.dumps(three_var_approx_single_group_mf)
885+
new = cloudpickle.loads(dump)
886886
assert new.sample(1)
887887

888888

889889
def test_pickle_approx_aevb(three_var_aevb_approx):
890-
import pickle
890+
import cloudpickle
891891

892-
dump = pickle.dumps(three_var_aevb_approx)
893-
new = pickle.loads(dump)
892+
dump = cloudpickle.dumps(three_var_aevb_approx)
893+
new = cloudpickle.loads(dump)
894894
assert new.sample(1000)
895895

896896

0 commit comments

Comments
 (0)