Skip to content

Commit e08ad0c

Browse files
authored
Use dill to serialize logp functions in DensityDist (#4053)
* Use dill to serialize logp functions in DensityDist * Update testenv based on yml file on travis * Explicitly test pickling and unpickling of DensityDist * Improve release notes * Use conda activate in create testenv
1 parent aaafa8d commit e08ad0c

File tree

7 files changed

+36
-3
lines changed

7 files changed

+36
-3
lines changed

RELEASE-NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
### Maintenance
66
- Mentioned the way to do any random walk with `theano.tensor.cumsum()` in `GaussianRandomWalk` docstrings (see [#4048](https://github.com/pymc-devs/pymc3/pull/4048)).
77
- Fixed numerical instability in ExGaussian's logp by preventing `logpow` from returning `-inf` (see [#4050](https://github.com/pymc-devs/pymc3/pull/4050)).
8+
- Use dill to serialize user defined logp functions in `DensityDist`. The previous serialization code fails if it is used in notebooks on Windows and Mac. `dill` is now a required dependency. (see [#3844](https://github.com/pymc-devs/pymc3/issues/3844)).
89

910
### Documentation
1011

environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ dependencies:
3737
- dataclasses # python_version < 3.7
3838
- contextvars # python_version < 3.7
3939
- mkl-service
40+
- dill
4041
- libblas=*=*mkl
4142
- pip:
4243
- black_nbconvert
43-
- dill

pymc3/distributions/distribution.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import numbers
1616
import contextvars
17+
import dill
1718
from typing import TYPE_CHECKING
1819
if TYPE_CHECKING:
1920
from typing import Optional, Callable
@@ -419,6 +420,19 @@ def __init__(
419420
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
420421
self.check_shape_in_random = check_shape_in_random
421422

423+
def __getstate__(self):
424+
# We use dill to serialize the logp function, as this is almost
425+
# always defined in the notebook and won't be pickled correctly.
426+
# Fix https://github.com/pymc-devs/pymc3/issues/3844
427+
logp = dill.dumps(self.logp)
428+
vals = self.__dict__.copy()
429+
vals['logp'] = logp
430+
return vals
431+
432+
def __setstate__(self, vals):
433+
vals['logp'] = dill.loads(vals['logp'])
434+
self.__dict__ = vals
435+
422436
def random(self, point=None, size=None, **kwargs):
423437
if self.rand is not None:
424438
not_broadcast_kwargs = dict(point=point)

pymc3/tests/test_distributions.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@
7979

8080
from ..distributions import continuous
8181
from pymc3.theanof import floatX
82+
import pymc3 as pm
8283
from numpy import array, inf, log, exp
8384
from numpy.testing import assert_almost_equal, assert_allclose, assert_equal
8485
import numpy.random as nr
@@ -1872,3 +1873,16 @@ def test_issue_3051(self, dims, dist_cls, kwargs):
18721873
assert isinstance(actual_a, np.ndarray)
18731874
assert actual_a.shape == (X.shape[0],)
18741875
pass
1876+
1877+
1878+
def test_serialize_density_dist():
1879+
def func(x):
1880+
return -2 * (x ** 2).sum()
1881+
1882+
with pm.Model():
1883+
pm.Normal('x')
1884+
y = pm.DensityDist('y', func)
1885+
pm.sample(draws=5, tune=1, mp_ctx="spawn")
1886+
1887+
import pickle
1888+
pickle.loads(pickle.dumps(y))

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ sphinx-autobuild==0.7.1
1717
sphinx>=1.5.5
1818
watermark
1919
parameterized
20-
dill
20+
dill

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,4 @@ h5py>=2.7.0
99
typing-extensions>=3.7.4
1010
dataclasses; python_version < '3.7'
1111
contextvars; python_version < '3.7'
12+
dill

scripts/create_testenv.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,18 @@ command -v conda >/dev/null 2>&1 || {
2222
ENVNAME="${ENVNAME:-testenv}" # if no ENVNAME is specified, use testenv
2323

2424
if [ -z ${GLOBAL} ]; then
25+
source $(dirname $(dirname $(which conda)))/etc/profile.d/conda.sh
2526
if conda env list | grep -q ${ENVNAME}; then
2627
echo "Environment ${ENVNAME} already exists, keeping up to date"
28+
conda activate ${ENVNAME}
29+
mamba env update -f environment-dev.yml
2730
else
2831
conda config --add channels conda-forge
2932
conda config --set channel_priority strict
3033
conda install -c conda-forge mamba --yes
3134
mamba env create -f environment-dev.yml
35+
conda activate ${ENVNAME}
3236
fi
33-
source activate ${ENVNAME}
3437
fi
3538

3639
# Install editable using the setup.py

0 commit comments

Comments
 (0)