Skip to content

Commit 8076a26

Browse files
committed
uncomment test
1 parent 3fa3d1f commit 8076a26

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

pymc3/distributions/distribution.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@
5454
"vectorized_ppc", default=None
5555
) # type: contextvars.ContextVar[Optional[Callable]]
5656

57+
PLATFORM = sys.platform
58+
5759

5860
class _Unpickling:
5961
pass
@@ -510,17 +512,17 @@ def __init__(
510512
super().__init__(shape, dtype, testval, *args, **kwargs)
511513
self.logp = logp
512514
if type(self.logp) == types.MethodType:
513-
if sys.platform != "linux":
515+
if PLATFORM != "linux":
514516
warnings.warn(
515517
"You are passing a bound method as logp for DensityDist, this can lead to "
516-
+ "errors when sampling on platforms other than Linux. Consider using a "
517-
+ "plain function instead, or subclass Distribution."
518+
"errors when sampling on platforms other than Linux. Consider using a "
519+
"plain function instead, or subclass Distribution."
518520
)
519521
elif type(multiprocessing.get_context()) != multiprocessing.context.ForkContext:
520522
warnings.warn(
521523
"You are passing a bound method as logp for DensityDist, this can lead to "
522-
+ "errors when sampling when multiprocessing cannot rely on forking. Consider using a "
523-
+ "plain function instead, or subclass Distribution."
524+
"errors when sampling when multiprocessing cannot rely on forking. Consider using a "
525+
"plain function instead, or subclass Distribution."
524526
)
525527
self.rand = random
526528
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape

pymc3/tests/test_parallel_sampling.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -172,24 +172,25 @@ def func(x):
172172
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
173173

174174

175-
@pytest.mark.xfail(raises=ValueError)
176175
def test_spawn_densitydist_bound_method():
177176
with pm.Model() as model:
178177
mu = pm.Normal("mu", 0, 1)
179178
normal_dist = pm.Normal.dist(mu, 1)
180179
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
181-
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
180+
with pytest.raises(
181+
ValueError,
182+
match="logp for DensityDist is a bound method, leading to RecursionError while serializing",
183+
):
184+
trace = pm.sample(draws=10, tune=10, step=pm.Metropolis(), cores=2, mp_ctx="spawn")
182185

183186

184-
# cannot test this properly: monkeypatching sys.platform messes up Theano
185-
# def test_spawn_densitydist_syswarning(monkeypatch):
186-
# monkeypatch.setattr(sys, "platform", "win32")
187-
# with pm.Model() as model:
188-
# mu = pm.Normal('mu', 0, 1)
189-
# normal_dist = pm.Normal.dist(mu, 1)
190-
# with pytest.warns(UserWarning) as w:
191-
# obs = pm.DensityDist('density_dist', normal_dist.logp, observed=np.random.randn(100))
192-
# assert len(w) == 1 and "errors when sampling on platforms" in w[0].message.args[0]
187+
def test_spawn_densitydist_syswarning(monkeypatch):
188+
monkeypatch.setattr("pymc3.distributions.distribution.PLATFORM", "win32")
189+
with pm.Model() as model:
190+
mu = pm.Normal("mu", 0, 1)
191+
normal_dist = pm.Normal.dist(mu, 1)
192+
with pytest.warns(UserWarning, match="errors when sampling on platforms"):
193+
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
193194

194195

195196
def test_spawn_densitydist_mpctxwarning(monkeypatch):
@@ -198,6 +199,5 @@ def test_spawn_densitydist_mpctxwarning(monkeypatch):
198199
with pm.Model() as model:
199200
mu = pm.Normal("mu", 0, 1)
200201
normal_dist = pm.Normal.dist(mu, 1)
201-
with pytest.warns(UserWarning) as w:
202+
with pytest.warns(UserWarning, match="errors when sampling when multiprocessing"):
202203
obs = pm.DensityDist("density_dist", normal_dist.logp, observed=np.random.randn(100))
203-
assert len(w) == 1 and "errors when sampling when multiprocessing" in w[0].message.args[0]

0 commit comments

Comments
 (0)