diff --git a/pymc3/distributions/distribution.py b/pymc3/distributions/distribution.py index 734722ae49..c8c895c188 100644 --- a/pymc3/distributions/distribution.py +++ b/pymc3/distributions/distribution.py @@ -54,6 +54,8 @@ "vectorized_ppc", default=None ) # type: contextvars.ContextVar[Optional[Callable]] +PLATFORM = sys.platform + class _Unpickling: pass @@ -510,17 +512,17 @@ def __init__( super().__init__(shape, dtype, testval, *args, **kwargs) self.logp = logp if type(self.logp) == types.MethodType: - if sys.platform != "linux": + if PLATFORM != "linux": warnings.warn( "You are passing a bound method as logp for DensityDist, this can lead to " - + "errors when sampling on platforms other than Linux. Consider using a " - + "plain function instead, or subclass Distribution." + "errors when sampling on platforms other than Linux. Consider using a " + "plain function instead, or subclass Distribution." ) elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext: warnings.warn( "You are passing a bound method as logp for DensityDist, this can lead to " - + "errors when sampling when multiprocessing cannot rely on forking. Consider using a " - + "plain function instead, or subclass Distribution." + "errors when sampling when multiprocessing cannot rely on forking. Consider using a " + "plain function instead, or subclass Distribution." ) self.rand = random self.wrap_random_with_dist_shape = wrap_random_with_dist_shape diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py index f8063663e8..21715d6591 100644 --- a/pymc3/tests/test_parallel_sampling.py +++ b/pymc3/tests/test_parallel_sampling.py @@ -172,24 +172,23 @@ def func(x): trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") -@pytest.mark.xfail(raises=ValueError) def test_spawn_densitydist_bound_method(): with pm.Model() as model: mu = pm.Normal("mu", 0, 1) normal_dist = pm.Normal.dist(mu, 1) obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100)) - trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") + msg = "logp for DensityDist is a bound method, leading to RecursionError while serializing" + with pytest.raises(ValueError, match=msg): + trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn") -# cannot test this properly: monkeypatching sys.platform messes up Theano -# def test_spawn_densitydist_syswarning(monkeypatch): -# monkeypatch.setattr(sys, "platform", "win32") -# with pm.Model() as model: -# mu = pm.Normal('mu', 0, 1) -# normal_dist = pm.Normal.dist(mu, 1) -# with pytest.warns(UserWarning) as w: -# obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100)) -# assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0] +def test_spawn_densitydist_syswarning(monkeypatch): + monkeypatch.setattr("pymc3.distributions.distribution.PLATFORM", "win32") + with pm.Model() as model: + mu = pm.Normal("mu", 0, 1) + normal_dist = pm.Normal.dist(mu, 1) + with pytest.warns(UserWarning, match="errors when sampling on platforms"): + obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100)) def test_spawn_densitydist_mpctxwarning(monkeypatch): @@ -198,6 +197,5 @@ def test_spawn_densitydist_mpctxwarning(monkeypatch): with pm.Model() as model: mu = pm.Normal("mu", 0, 1) normal_dist = pm.Normal.dist(mu, 1) - with pytest.warns(UserWarning) as w: + with pytest.warns(UserWarning, match="errors when sampling when multiprocessing"): obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100)) - assert len(w) == 1 and "errors when sampling when multiprocessing" in w[0].message.args[0]