Skip to content

Commit 68e53eb

Browse files
burnpancktwiecki
authored andcommitted
BUG placed context stack inside thread-local data space (#1555)
* placed context stack inside thread-local data space * fixed thread-local context manager, and added a regression test * fixed docstring of #1552 regression test
1 parent 8a6d87e commit 68e53eb

File tree

2 files changed

+56
-4
lines changed

2 files changed

+56
-4
lines changed

pymc3/model.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import threading
2+
13
import numpy as np
24
import theano
35
import theano.tensor as tt
@@ -95,6 +97,7 @@ class Context(object):
9597
"""Functionality for objects that put themselves in a context using
9698
the `with` statement.
9799
"""
100+
contexts = threading.local()
98101

99102
def __enter__(self):
100103
type(self).get_contexts().append(self)
@@ -105,10 +108,11 @@ def __exit__(self, typ, value, traceback):
105108

106109
@classmethod
107110
def get_contexts(cls):
108-
if not hasattr(cls, "contexts"):
109-
cls.contexts = []
110-
111-
return cls.contexts
111+
# no race-condition here, cls.contexts is a thread-local object
112+
# be sure not to override contexts in a subclass however!
113+
if not hasattr(cls.contexts, 'stack'):
114+
cls.contexts.stack = []
115+
return cls.contexts.stack
112116

113117
@classmethod
114118
def get_context(cls):

pymc3/tests/test_modelcontext.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import threading
2+
import unittest
3+
4+
from pymc3 import Model, Normal
5+
6+
7+
class TestModelContext(unittest.TestCase):
8+
def test_thread_safety(self):
9+
""" Regression test for issue #1552: Thread safety of model context manager
10+
11+
This test creates two threads that attempt to construct two
12+
unrelated models at the same time.
13+
For repeatable testing, the two threads are syncronised such
14+
that thread A enters the context manager first, then B,
15+
then A attempts to declare a variable while B is still in the context manager.
16+
"""
17+
aInCtxt,bInCtxt,aDone = [threading.Event() for k in range(3)]
18+
modelA = Model()
19+
modelB = Model()
20+
def make_model_a():
21+
with modelA:
22+
aInCtxt.set()
23+
bInCtxt.wait()
24+
a = Normal('a',0,1)
25+
aDone.set()
26+
def make_model_b():
27+
aInCtxt.wait()
28+
with modelB:
29+
bInCtxt.set()
30+
aDone.wait()
31+
b = Normal('b', 0, 1)
32+
threadA = threading.Thread(target=make_model_a)
33+
threadB = threading.Thread(target=make_model_b)
34+
threadA.start()
35+
threadB.start()
36+
threadA.join()
37+
threadB.join()
38+
# now let's see which model got which variable
39+
# previous to #1555, the variables would be swapped:
40+
# - B enters it's model context after A, but before a is declared -> a goes into B
41+
# - A leaves it's model context before B attempts to declare b. A's context manager
42+
# takes B from the stack, such that b ends up in model A
43+
self.assertEqual(
44+
(
45+
list(modelA.named_vars),
46+
list(modelB.named_vars),
47+
), (['a'],['b'])
48+
)

0 commit comments

Comments
 (0)