diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py
index 066dbfd26d..734722ae49 100644
--- a/pymc3/distributions/distribution.py
+++ b/pymc3/distributions/distribution.py
@@ -12,11 +12,15 @@
 #   See the License for the specific language governing permissions and
 #   limitations under the License.
 
+import multiprocessing
 import numbers
 import contextvars
 import dill
 import inspect
+import sys
+import types
 from typing import TYPE_CHECKING
+import warnings
 
 if TYPE_CHECKING:
     from typing import Optional, Callable
@@ -505,6 +509,21 @@ def __init__(
             dtype = theano.config.floatX
         super().__init__(shape, dtype, testval, *args, **kwargs)
         self.logp = logp
+        # A bound-method logp may fail to serialize (see __getstate__); that only
+        # matters when sampler worker processes are not created by forking.
+        if isinstance(self.logp, types.MethodType):
+            if sys.platform != "linux":
+                warnings.warn(
+                    "You are passing a bound method as logp for DensityDist, this can lead to "
+                    + "errors when sampling on platforms other than Linux. Consider using a "
+                    + "plain function instead, or subclass Distribution."
+                )
+            elif not isinstance(multiprocessing.get_context(), multiprocessing.context.ForkContext):
+                warnings.warn(
+                    "You are passing a bound method as logp for DensityDist, this can lead to "
+                    + "errors when sampling when multiprocessing cannot rely on forking. Consider using a "
+                    + "plain function instead, or subclass Distribution."
+                )
         self.rand = random
         self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
         self.check_shape_in_random = check_shape_in_random
@@ -513,7 +532,16 @@ def __getstate__(self):
         # We use dill to serialize the logp function, as this is almost
         # always defined in the notebook and won't be pickled correctly.
         # Fix https://github.com/pymc-devs/pymc3/issues/3844
-        logp = dill.dumps(self.logp)
+        try:
+            logp = dill.dumps(self.logp)
+        except RecursionError as err:
+            # Only translate the error when a bound method is the likely cause.
+            if isinstance(self.logp, types.MethodType):
+                raise ValueError(
+                    "logp for DensityDist is a bound method, leading to RecursionError while serializing"
+                ) from err
+            else:
+                raise
         vals = self.__dict__.copy()
         vals["logp"] = logp
         return vals
diff --git a/pymc3/tests/test_distributions_random.py b/pymc3/tests/test_distributions_random.py
index 84bd8bc117..a789674095 100644
--- a/pymc3/tests/test_distributions_random.py
+++ b/pymc3/tests/test_distributions_random.py
@@ -1171,7 +1171,7 @@ def test_density_dist_with_random_sampleable(self, shape):
                 shape=shape,
                 random=normal_dist.random,
             )
-            trace = pm.sample(100)
+            trace = pm.sample(100, cores=1)
 
         samples = 500
         size = 100
@@ -1194,7 +1194,7 @@ def test_density_dist_with_random_sampleable_failure(self, shape):
                 random=normal_dist.random,
                 wrap_random_with_dist_shape=False,
             )
-            trace = pm.sample(100)
+            trace = pm.sample(100, cores=1)
 
         samples = 500
         with pytest.raises(RuntimeError):
@@ -1217,7 +1217,7 @@ def test_density_dist_with_random_sampleable_hidden_error(self, shape):
                 wrap_random_with_dist_shape=False,
                 check_shape_in_random=False,
             )
-            trace = pm.sample(100)
+            trace = pm.sample(100, cores=1)
 
         samples = 500
         ppc = pm.sample_posterior_predictive(trace, samples=samples, model=model)
@@ -1240,7 +1240,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success(self):
                 random=rvs,
                 wrap_random_with_dist_shape=False,
             )
-            trace = pm.sample(100)
+            trace = pm.sample(100, cores=1)
 
         samples = 500
         size = 100
@@ -1260,7 +1260,7 @@ def test_density_dist_with_random_sampleable_handcrafted_success_fast(self):
                 random=rvs,
                 wrap_random_with_dist_shape=False,
             )
-            trace = pm.sample(100)
+            trace = pm.sample(100, cores=1)
 
         samples = 500
         size = 100
@@ -1273,7 +1273,7 @@ def test_density_dist_without_random_not_sampleable(self):
             mu = pm.Normal("mu", 0, 1)
             normal_dist = pm.Normal.dist(mu, 1)
             pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
-            trace = pm.sample(100)
+            trace = pm.sample(100, cores=1)
 
         samples = 500
         with pytest.raises(ValueError):
diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py
index b5de8332cc..f8063663e8 100644
--- a/pymc3/tests/test_parallel_sampling.py
+++ b/pymc3/tests/test_parallel_sampling.py
@@ -159,3 +159,45 @@ def test_iterator():
     with sampler:
         for draw in sampler:
             pass
+
+
+def test_spawn_densitydist_function():
+    with pm.Model() as model:
+        mu = pm.Normal("mu", 0, 1)
+
+        def func(x):
+            return -2 * (x ** 2).sum()
+
+        obs = pm.DensityDist("density_dist", func, observed=np.random.randn(100))
+        trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
+
+
+@pytest.mark.xfail(raises=ValueError)
+def test_spawn_densitydist_bound_method():
+    with pm.Model() as model:
+        mu = pm.Normal("mu", 0, 1)
+        normal_dist = pm.Normal.dist(mu, 1)
+        obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
+        trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
+
+
+# cannot test this properly: monkeypatching sys.platform messes up Theano
+# def test_spawn_densitydist_syswarning(monkeypatch):
+#     monkeypatch.setattr(sys, "platform", "win32")
+#     with pm.Model() as model:
+#         mu = pm.Normal('mu', 0, 1)
+#         normal_dist = pm.Normal.dist(mu, 1)
+#         with pytest.warns(UserWarning) as w:
+#             obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
+#         assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]
+
+
+def test_spawn_densitydist_mpctxwarning(monkeypatch):
+    ctx = multiprocessing.get_context("spawn")
+    monkeypatch.setattr(multiprocessing, "get_context", lambda: ctx)
+    with pm.Model() as model:
+        mu = pm.Normal("mu", 0, 1)
+        normal_dist = pm.Normal.dist(mu, 1)
+        with pytest.warns(UserWarning) as w:
+            obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
+        assert len(w) == 1 and "errors when sampling when multiprocessing" in w[0].message.args[0]