1
- from theano .gof .graph import inputs
1
+ import itertools
2
+
3
+ from theano .gof .graph import ancestors
2
4
3
5
from .util import get_default_varnames
4
6
import pymc3 as pm
5
7
6
8
9
+ def powerset (iterable ):
10
+ """All *nonempty* subsets of an iterable.
11
+
12
+ From itertools docs.
13
+
14
+ powerset([1,2,3]) --> (1,) (2,) (3,) (1,2) (1,3) (2,3) (1,2,3)
15
+ """
16
+ s = list (iterable )
17
+ return itertools .chain .from_iterable (itertools .combinations (s , r ) for r in range (1 , len (s )+ 1 ))
18
+
19
+
7
20
class ModelGraph (object ):
8
21
def __init__ (self , model ):
9
22
self .model = model
@@ -21,30 +34,29 @@ def get_deterministics(self, var):
21
34
deterministics .append (v )
22
35
return deterministics
23
36
24
- def _inputs (self , var , func , blockers = None ):
25
- """Get inputs to a function that are also named PyMC3 variables"""
26
- return set ([j for j in inputs ([func ], blockers = blockers ) if j in self .var_list and j != var ])
37
+ def _ancestors (self , var , func , blockers = None ):
38
+ """Get ancestors of a function that are also named PyMC3 variables"""
39
+ return set ([j for j in ancestors ([func ], blockers = blockers ) if j in self .var_list and j != var ])
27
40
28
- def _get_inputs (self , var , func ):
29
- """Get all inputs to a function, doing some accounting for deterministics
41
+ def _get_ancestors (self , var , func ):
42
+ """Get all ancestors of a function, doing some accounting for deterministics
30
43
31
- Specifically, if a deterministic is an input, theano.gof.graph.inputs will
44
+ Specifically, if a deterministic is an input, theano.gof.graph.ancestors will
32
45
return only the inputs *to the deterministic*. However, if we pass in the
33
46
deterministic as a blocker, it will skip those nodes.
34
47
"""
35
48
deterministics = self .get_deterministics (var )
36
- upstream = self ._inputs (var , func )
37
- parents = self ._inputs (var , func , blockers = deterministics )
38
- if parents != upstream :
39
- det_map = {}
40
- for d in deterministics :
41
- d_set = {j for j in inputs ([func ], blockers = [d ])}
42
- if upstream - d_set :
43
- det_map [d ] = d_set
44
- for d , d_set in det_map .items ():
45
- if all (d_set .issubset (other ) for other in det_map .values ()):
46
- parents .add (d )
47
- return parents
49
+ upstream = self ._ancestors (var , func )
50
+
51
+ # Usual case
52
+ if upstream == self ._ancestors (var , func , blockers = upstream ):
53
+ return upstream
54
+ else : # deterministic accounting
55
+ for d in powerset (upstream ):
56
+ blocked = self ._ancestors (var , func , blockers = d )
57
+ if set (d ) == blocked :
58
+ return d
59
+ raise RuntimeError ('Could not traverse graph. Consider raising an issue with developers.' )
48
60
49
61
def _filter_parents (self , var , parents ):
50
62
"""Get direct parents of a var, as strings"""
@@ -70,7 +82,7 @@ def get_parents(self, var):
70
82
else :
71
83
func = var
72
84
73
- parents = self ._get_inputs (var , func )
85
+ parents = self ._get_ancestors (var , func )
74
86
return self ._filter_parents (var , parents )
75
87
76
88
def make_compute_graph (self ):
0 commit comments