Skip to content

Commit 4e33b32

Browse files
ColCarrolltwiecki
authored andcommitted
Update to more accurate way of calculating ancestors
1 parent fd4c71d commit 4e33b32

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

pymc3/model_graph.py

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1-
from theano.gof.graph import inputs
1+
import itertools
2+
3+
from theano.gof.graph import ancestors
24

35
from .util import get_default_varnames
46
import pymc3 as pm
57

68

9+
def powerset(iterable):
10+
"""All *nonempty* subsets of an iterable.
11+
12+
From itertools docs.
13+
14+
powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
15+
"""
16+
s = list(iterable)
17+
return itertools.chain.from_iterable(itertools.combinations(s, r) for r in range(1, len(s)+1))
18+
19+
720
class ModelGraph(object):
821
def __init__(self, model):
922
self.model = model
@@ -21,30 +34,29 @@ def get_deterministics(self, var):
2134
deterministics.append(v)
2235
return deterministics
2336

24-
def _inputs(self, var, func, blockers=None):
25-
"""Get inputs to a function that are also named PyMC3 variables"""
26-
return set([j for j in inputs([func], blockers=blockers) if j in self.var_list and j != var])
37+
def _ancestors(self, var, func, blockers=None):
38+
"""Get ancestors of a function that are also named PyMC3 variables"""
39+
return set([j for j in ancestors([func], blockers=blockers) if j in self.var_list and j != var])
2740

28-
def _get_inputs(self, var, func):
29-
"""Get all inputs to a function, doing some accounting for deterministics
41+
def _get_ancestors(self, var, func):
42+
"""Get all ancestors of a function, doing some accounting for deterministics
3043
31-
Specifically, if a deterministic is an input, theano.gof.graph.inputs will
44+
Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
3245
return only the inputs *to the deterministic*. However, if we pass in the
3346
deterministic as a blocker, it will skip those nodes.
3447
"""
3548
deterministics = self.get_deterministics(var)
36-
upstream = self._inputs(var, func)
37-
parents = self._inputs(var, func, blockers=deterministics)
38-
if parents != upstream:
39-
det_map = {}
40-
for d in deterministics:
41-
d_set = {j for j in inputs([func], blockers=[d])}
42-
if upstream - d_set:
43-
det_map[d] = d_set
44-
for d, d_set in det_map.items():
45-
if all(d_set.issubset(other) for other in det_map.values()):
46-
parents.add(d)
47-
return parents
49+
upstream = self._ancestors(var, func)
50+
51+
# Usual case
52+
if upstream == self._ancestors(var, func, blockers=upstream):
53+
return upstream
54+
else: # deterministic accounting
55+
for d in powerset(upstream):
56+
blocked = self._ancestors(var, func, blockers=d)
57+
if set(d) == blocked:
58+
return d
59+
raise RuntimeError('Could not traverse graph. Consider raising an issue with developers.')
4860

4961
def _filter_parents(self, var, parents):
5062
"""Get direct parents of a var, as strings"""
@@ -70,7 +82,7 @@ def get_parents(self, var):
7082
else:
7183
func = var
7284

73-
parents = self._get_inputs(var, func)
85+
parents = self._get_ancestors(var, func)
7486
return self._filter_parents(var, parents)
7587

7688
def make_compute_graph(self):

0 commit comments

Comments
 (0)