Skip to content

Commit d835151

Browse files
committed
remove special case serialization for DensityDist.logp
1 parent 2c466bf commit d835151

File tree

2 files changed

+2
-24
lines changed

2 files changed

+2
-24
lines changed

pymc3/distributions/distribution.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import aesara
2525
import aesara.tensor as at
26-
import dill
2726

2827
from aesara.tensor.random.op import RandomVariable
2928
from aesara.tensor.random.var import RandomStateSharedVariable
@@ -533,26 +532,5 @@ def __init__(
533532
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
534533
self.check_shape_in_random = check_shape_in_random
535534

536-
def __getstate__(self):
537-
# We use dill to serialize the logp function, as this is almost
538-
# always defined in the notebook and won't be pickled correctly.
539-
# Fix https://github.com/pymc-devs/pymc3/issues/3844
540-
try:
541-
logp = dill.dumps(self.logp)
542-
except RecursionError as err:
543-
if type(self.logp) == types.MethodType:
544-
raise ValueError(
545-
"logp for DensityDist is a bound method, leading to RecursionError while serializing"
546-
) from err
547-
else:
548-
raise err
549-
vals = self.__dict__.copy()
550-
vals["logp"] = logp
551-
return vals
552-
553-
def __setstate__(self, vals):
554-
vals["logp"] = dill.loads(vals["logp"])
555-
self.__dict__ = vals
556-
557535
def _distr_parameters_for_repr(self):
558536
return []

pymc3/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing import Dict, List, Tuple, Union
2020

2121
import arviz
22-
import dill
22+
import cloudpickle
2323
import numpy as np
2424
import xarray
2525

@@ -347,7 +347,7 @@ def hashable(a=None) -> int:
347347
pass
348348
# Not hashable >>>
349349
try:
350-
return hash(dill.dumps(a))
350+
return hash(cloudpickle.dumps(a))
351351
except Exception:
352352
if hasattr(a, "__dict__"):
353353
return hashable(a.__dict__)

0 commit comments

Comments
 (0)