Skip to content

Commit 339828d

Browse files
committed
Implemented get_first_level_conditionals to try to get rid of the added conditional_on attribute of every distribution. This function does a breadth first search on the node's logpt or transformed.logpt graph, looking for named nodes which are different from the root node, or the node's transformed, and is also not a TensorConstant or SharedVariable. Each branch was searched until the first named node was found. This way, the parent conditionals of the root searched node, which were only one step away from it in the bayesian network were returned. However, this ran into a problem with Mixture classes. These add to the logpt graph, another logpt graph from the comp_dists. This leads to the problem that the logpt's first level conditionals will also be seen as if they were first level conditional of the root. Furthermore, many copies of nodes done by the added logpt ended up being inserted into the computed conditional_on. This lead to a very strange error, in which loops appeared in the DAG, and depths started to be wrong. In particular, there were no depth 0 nodes. My view is that the explicit conditional_on attribute prevents problems like this one from happening, and so I left it as is, to discuss. Other changes done in this commit are that test_exact_step for the SMC uses draw_values on a hierarchy, and given that draw_values's behavior changed in the hierarchy situations, the exact trace values must also be adjusted. Finally test_bad_init was changed to run on one core, this way the parallel exception chaining does not change the exception type.
1 parent d43d149 commit 339828d

File tree

2 files changed

+88
-52
lines changed

2 files changed

+88
-52
lines changed

pymc3/model.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,6 +1542,52 @@ def not_shared_or_constant_variable(x):
15421542
) or (isinstance(x, (FreeRV, MultiObservedRV, TransformedRV)))
15431543

15441544

1545+
def get_first_level_conditionals(root):
1546+
"""Performs a breadth first search on the supplied root node's logpt or
1547+
transformed logpt graph searching for named input nodes, which are
1548+
different from the supplied root. Each explored branch will stop when
1549+
either when it ends or when it finds its first named node.
1550+
1551+
Parameters
1552+
----------
1553+
root: theano.Variable (mandatory)
1554+
The node from which to get the transformed.logpt or logpt and perform
1555+
the search. If root does not have either of these attributes, the
1556+
function returns None.
1557+
1558+
Returns
1559+
-------
1560+
conditional_on : set, with named nodes that are not theano.Constant nor
1561+
SharedVariable. The input `root` is conditionally dependent on these nodes
1562+
and is one step away from them in the bayesian network that specifies the
1563+
relationships, hence the name `get_first_level_conditionals`.
1564+
"""
1565+
transformed = getattr(root, 'transformed', None)
1566+
try:
1567+
cond = transformed.logpt
1568+
except AttributeError:
1569+
cond = getattr(root, 'logpt', None)
1570+
if cond is None:
1571+
return None
1572+
conditional_on = set()
1573+
queue = copy(getattr(cond.owner, 'inputs', []))
1574+
while queue:
1575+
parent = queue.pop(0)
1576+
if (parent is not None and getattr(parent, 'name', None) is not None
1577+
and not_shared_or_constant_variable(parent)):
1578+
# We don't include as a conditional relation either logpt depending
1579+
# on root or on transformed because they are both deterministic
1580+
# relations
1581+
if parent == root and parent == transformed:
1582+
conditional_on.add(parent)
1583+
else:
1584+
parent_owner = getattr(parent, 'owner', None)
1585+
queue.extend(getattr(parent_owner, 'inputs', []))
1586+
if not conditional_on:
1587+
return None
1588+
return conditional_on
1589+
1590+
15451591
class DependenceDAG(object):
15461592
"""
15471593
`DependenceDAG` instances represent the directed acyclic graph (DAG) that
@@ -1866,6 +1912,7 @@ def add(self, node, force=False, return_added_node=False,
18661912
self.depth[node] = 0
18671913

18681914
# Try to get the conditional parents of node and add them
1915+
# cond = get_first_level_conditionals(node)
18691916
try:
18701917
cond = node.distribution.conditional_on
18711918
except AttributeError:

pymc3/tests/test_step.py

Lines changed: 41 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -105,56 +105,46 @@ class TestStepMethods(object): # yield test doesn't work subclassing object
105105
1.58740483, 1.67905741, 0.77744868, 0.15050587, 0.15050587,
106106
0.73979127, 0.15445515, 0.13134717, 0.85068974, 0.85068974,
107107
0.6974799 , 0.16170472, 0.86405959, 0.86405959, -0.22032854]),
108-
SMC: np.array([ 5.10950205e-02, 1.09811720e+00, 1.78330202e-01, 6.85938766e-01,
109-
1.42354476e-01, -1.59630758e+00, 1.57176810e+00, -4.01398917e-01,
110-
1.14567871e+00, 1.14954938e+00, 4.94399840e-01, 1.16253017e+00,
111-
1.17432244e+00, 7.79195162e-01, 1.29017945e+00, 2.53722905e-01,
112-
5.38589898e-01, 3.52121216e-01, 1.35795966e+00, 1.02086933e-01,
113-
1.58845251e+00, 6.76852927e-01, -1.04716592e-02, -1.01613324e-01,
114-
1.37680965e+00, 7.40036542e-01, 2.89069320e-01, 1.48153741e+00,
115-
9.58156958e-01, 5.73623782e-02, 7.68850721e-01, 3.68643390e-01,
116-
1.47645964e+00, 2.32596780e-01, -1.85008158e-01, 3.71335958e-01,
117-
2.68600102e+00, -4.89504443e-01, 6.54265561e-02, 3.80455349e-01,
118-
1.17875338e+00, 2.30233324e-01, 6.90960231e-01, 8.81668685e-01,
119-
-2.19754340e-01, 1.27686862e-01, 3.28444250e-01, 1.34820635e-01,
120-
5.29725257e-01, 1.43783915e+00, -1.64754264e-01, 7.41446719e-01,
121-
-1.17733186e+00, 6.01215658e-02, 1.82638158e-01, -2.23232214e-02,
122-
-1.79877583e-02, 8.37949150e-01, 4.41964955e-01, -8.66524743e-01,
123-
4.90738093e-01, 2.42056488e-01, 4.67699626e-01, 2.91075351e-01,
124-
1.49541153e+00, 8.30730845e-01, 1.03956404e+00, -5.16162910e-01,
125-
2.84338859e-01, 1.72305888e+00, 9.52445566e-01, 1.48831718e+00,
126-
8.03455325e-01, 1.48840970e+00, 6.98122664e-01, 3.30187139e-01,
127-
7.88029712e-01, 9.31510828e-01, 1.01326878e+00, 2.26637755e-01,
128-
1.70703646e-01, -8.54429841e-01, 2.97254590e-01, -2.77843274e-01,
129-
-2.25544207e-01, 1.98862826e-02, 5.05953885e-01, 4.98203941e-01,
130-
1.20897382e+00, -6.32958669e-05, -7.22425896e-01, 1.60930869e+00,
131-
-5.02773645e-01, 2.46405678e+00, 9.16039706e-01, 1.14146060e+00,
132-
-1.95781984e-01, -2.44653942e-01, 2.67851290e-01, 2.37462012e-01,
133-
6.71471950e-01, 1.18319765e+00, 1.29146530e+00, -3.14177753e-01,
134-
-1.31041215e-02, 1.05029405e+00, 1.31202399e+00, 7.40532839e-02,
135-
9.15510041e-01, 7.71054604e-01, 9.83483263e-01, 9.03032142e-01,
136-
9.14191160e-01, 9.32285366e-01, 1.13937607e+00, -4.29155928e-01,
137-
3.44609229e-02, -5.46423555e-02, 1.34625982e+00, -1.28287047e-01,
138-
-1.55214879e-02, 3.25294234e-01, 1.06120585e+00, -5.09891282e-01,
139-
1.25789335e+00, 1.01808348e+00, -9.92590713e-01, 1.72832932e+00,
140-
1.12232980e+00, 8.54801892e-01, 1.41534752e+00, 3.50798405e-01,
141-
3.69381623e-01, 1.48608411e+00, -1.15506310e-02, 1.57066360e+00,
142-
2.00747378e-01, 4.47219763e-01, 5.57720524e-01, -7.74295353e-02,
143-
1.79192501e+00, 7.66510475e-01, 1.38852488e+00, -4.06055122e-01,
144-
2.73203156e-01, 3.61014687e-01, 1.23574043e+00, 1.64565746e-01,
145-
-9.89896480e-02, 9.26130265e-02, 1.06440134e+00, -1.55890408e-01,
146-
4.47131846e-01, -7.59186008e-01, -1.50881256e+00, -2.13928005e-01,
147-
-4.19160151e-01, 1.75815544e+00, 7.45423008e-01, 6.94781506e-01,
148-
1.58596346e+00, 1.75508724e+00, 4.56070434e-01, 2.94128709e-02,
149-
1.17703970e+00, -9.90230827e-02, 8.42796845e-01, 1.79154944e+00,
150-
5.92779197e-01, 2.73562285e-01, 1.61597907e+00, 1.23514403e+00,
151-
4.86261080e-01, -3.10434934e-01, 5.57873722e-01, 6.50365217e-01,
152-
-3.41009850e-01, 9.26851109e-01, 8.28936486e-01, 9.16180689e-02,
153-
1.30226405e+00, 3.73945789e-01, 6.04560122e-02, 6.00698708e-01,
154-
9.68764731e-02, 1.41904148e+00, 6.94182961e-03, 3.17504138e-01,
155-
5.90956041e-01, -5.78113887e-01, 5.26615565e-01, -4.19715252e-01,
156-
8.92891364e-01, 1.30207363e-01, 4.19899637e-01, 7.10275704e-01,
157-
9.27418179e-02, 1.85758044e+00, 4.76988907e-01, -1.36341398e-01]),
108+
SMC: np.array([ 0.40152748, -0.1440789 , 1.87105436, 1.65027354, 0.78140894,
109+
-0.33437271, 0.55987446, 1.05976848, 0.52126327, 0.5295624 ,
110+
-0.7120724 , 0.39250673, 0.92590897, 0.776836 , 0.30528805,
111+
1.32178809, 1.30972392, 0.77107019, 1.11885364, 0.59633151,
112+
0.63584096, -0.29117982, 0.97372731, 1.06270256, 0.87424729,
113+
0.49249202, -0.55942483, -0.17608982, 0.47118016, 1.0026767 ,
114+
1.42476886, 1.16505966, 0.71572226, 1.14267914, -0.27628211,
115+
0.66712824, 0.58322462, 0.28193361, 0.30175522, -0.11615552,
116+
-0.02127047, 0.01085484, 1.21229396, 0.50109798, 0.2046552 ,
117+
0.95648093, 0.26673391, -0.703456 , 1.23223409, -0.87686456,
118+
1.45480993, 1.04172093, 1.73512969, 1.00835375, 0.56551883,
119+
0.43457948, 1.85267864, 0.51961398, 0.20641743, 0.70484816,
120+
1.04491792, -0.70236338, 1.47248532, 0.57438209, -0.15590465,
121+
0.51528505, 1.49158593, 0.02418851, -0.04563402, 1.50712686,
122+
1.01211014, -0.1058956 , 1.91153929, 1.09281243, 0.78028316,
123+
0.08148316, 0.3989925 , 0.30230531, 1.59469562, -0.53948736,
124+
-0.35653048, 0.44440402, 1.02983002, 0.05184227, 0.78152799,
125+
0.99204159, 0.44148902, -0.12657838, 0.97114256, 0.67963455,
126+
1.33757129, 0.71977859, 0.09706076, -0.13609892, -0.39969385,
127+
0.04687582, 0.053386 , 0.33382962, -0.36082645, 0.86597207,
128+
0.09824643, -0.85212079, 0.54518473, -0.26622955, 0.71836765,
129+
0.81359943, 1.39550066, 0.25118273, 1.03965837, -0.65995684,
130+
-0.25522586, 2.12497766, 0.69534904, 0.74613619, -0.10312994,
131+
1.3244944 , -0.036056 , 0.90976629, 0.49647046, 0.80779428,
132+
0.18921903, -0.18365952, 0.56968353, -0.8232526 , -0.88612154,
133+
-0.47326386, 0.18939692, 0.2298177 , 0.65693251, 1.08908496,
134+
1.04748985, 0.53615771, -0.4611776 , 1.12076823, -0.79971572,
135+
1.78908277, 1.32673932, 1.43691077, 0.2564599 , 0.08480867,
136+
0.26340606, -0.86864626, 1.05716355, 0.18611255, 0.44701292,
137+
-0.06966819, 0.3325726 , 0.94594745, -0.0904025 , 0.14349182,
138+
0.83638941, 0.57657934, 0.9549692 , -0.18496471, 0.87838048,
139+
0.66938294, 0.54401984, 0.47804147, 0.32545637, -0.82626784,
140+
0.93390148, 0.39170683, -0.22244643, 0.36576256, 0.62426937,
141+
-0.16594267, 1.55050592, 0.60508809, -1.28925325, 1.1470063 ,
142+
0.71030941, 1.20896922, 1.23267962, 0.67278808, 0.5846423 ,
143+
-0.09343583, -0.28323718, 0.87891542, 0.54779014, 0.17131075,
144+
1.02287448, 0.61819842, 1.28724788, 0.641085 , 1.48324063,
145+
-1.68770188, 0.03750369, 0.47352403, 0.22929128, 0.637757 ,
146+
0.61735636, 0.17260147, 1.10929764, -0.33766643, 0.27064342,
147+
-0.54594464, -1.23229206, -0.18328842, -0.78636148, 1.38189874]),
158148
}
159149

160150
def setup_class(self):
@@ -202,7 +192,6 @@ def check_trace(self, step_method):
202192
trace = sample(0, tune=n_steps,
203193
discard_tuned_samples=False,
204194
step=step_method(), random_seed=1, chains=1)
205-
206195
assert_array_almost_equal(
207196
trace['x'],
208197
self.master_samples[step_method],
@@ -428,7 +417,7 @@ def test_bad_init(self):
428417
with Model():
429418
HalfNormal('a', sd=1, testval=-1, transform=None)
430419
with pytest.raises(ValueError) as error:
431-
sample(init=None)
420+
sample(init=None, cores=1)
432421
error.match('Bad initial')
433422

434423
def test_linalg(self, caplog):

0 commit comments

Comments
 (0)