1
1
from collections import deque
2
- from typing import Iterator , Optional , MutableSet
2
+ from typing import Dict , Iterator , Set , Optional
3
+
4
+ VarName = str
3
5
4
6
from theano .gof .graph import stack_search
5
7
from theano .compile import SharedVariable
6
8
from theano .tensor import Tensor
7
9
8
10
from .util import get_default_varnames
11
+ from .model import ObservedRV
9
12
import pymc3 as pm
10
13
11
- # this is a placeholder for a better characterization of the type
12
- # of variables in a model.
13
- RV = Tensor
14
-
15
14
16
15
class ModelGraph :
17
16
def __init__ (self , model ):
@@ -30,16 +29,16 @@ def get_deterministics(self, var):
30
29
deterministics .append (v )
31
30
return deterministics
32
31
33
- def _get_ancestors (self , var , func ) -> MutableSet [ RV ]:
32
+ def _get_ancestors (self , var : Tensor , func ) -> Set [ Tensor ]:
34
33
"""Get all ancestors of a function, doing some accounting for deterministics.
35
34
"""
36
35
37
36
# this contains all of the variables in the model EXCEPT var...
38
37
vars = set (self .var_list )
39
38
vars .remove (var )
40
39
41
- blockers = set ()
42
- retval = set ()
40
+ blockers = set () # type: Set[Tensor]
41
+ retval = set () # type: Set[Tensor]
43
42
def _expand (node ) -> Optional [Iterator [Tensor ]]:
44
43
if node in blockers :
45
44
return None
@@ -58,9 +57,9 @@ def _expand(node) -> Optional[Iterator[Tensor]]:
58
57
mode = 'bfs' )
59
58
return retval
60
59
61
- def _filter_parents (self , var , parents ):
60
+ def _filter_parents (self , var , parents ) -> Set [ VarName ] :
62
61
"""Get direct parents of a var, as strings"""
63
- keep = set ()
62
+ keep = set () # type: Set[VarName]
64
63
for p in parents :
65
64
if p == var :
66
65
continue
@@ -73,7 +72,7 @@ def _filter_parents(self, var, parents):
73
72
raise AssertionError ('Do not know what to do with {}' .format (str (p )))
74
73
return keep
75
74
76
- def get_parents (self , var ) :
75
+ def get_parents (self , var : Tensor ) -> Set [ VarName ] :
77
76
"""Get the named nodes that are direct inputs to the var"""
78
77
if hasattr (var , 'transformed' ):
79
78
func = var .transformed .logpt
@@ -85,11 +84,26 @@ def get_parents(self, var):
85
84
parents = self ._get_ancestors (var , func )
86
85
return self ._filter_parents (var , parents )
87
86
88
- def make_compute_graph (self ):
87
+ def make_compute_graph (self ) -> Dict [ str , Set [ VarName ]] :
89
88
"""Get map of var_name -> set(input var names) for the model"""
90
- input_map = {}
89
+ input_map = {} # type: Dict[str, Set[VarName]]
90
+ def update_input_map (key : str , val : Set [VarName ]):
91
+ if key in input_map :
92
+ input_map [key ] = input_map [key ].union (val )
93
+ else :
94
+ input_map [key ] = val
95
+
91
96
for var_name in self .var_names :
92
- input_map [var_name ] = self .get_parents (self .model [var_name ])
97
+ var = self .model [var_name ]
98
+ update_input_map (var_name , self .get_parents (var ))
99
+ if isinstance (var , ObservedRV ):
100
+ try :
101
+ obs_name = var .observations .name
102
+ if obs_name :
103
+ input_map [var_name ] = input_map [var_name ].difference (set ([obs_name ]))
104
+ update_input_map (obs_name , set ([var_name ]))
105
+ except AttributeError :
106
+ pass
93
107
return input_map
94
108
95
109
def _make_node (self , var_name , graph ):
@@ -101,12 +115,19 @@ def _make_node(self, var_name, graph):
101
115
if isinstance (v , pm .model .ObservedRV ):
102
116
attrs ['style' ] = 'filled'
103
117
118
+ # make Data be roundtangle, instead of rectangle
104
119
if isinstance (v , SharedVariable ):
105
- attrs ['style' ] = 'filled'
120
+ attrs ['style' ] = 'rounded, filled'
106
121
107
122
# Get name for node
108
- if hasattr (v , 'distribution' ):
123
+ if v in self .model .potentials :
124
+ distribution = 'Potential'
125
+ attrs ['shape' ] = 'octagon'
126
+ elif hasattr (v , 'distribution' ):
109
127
distribution = v .distribution .__class__ .__name__
128
+ elif isinstance (v , SharedVariable ):
129
+ distribution = 'Data'
130
+ attrs ['shape' ] = 'box'
110
131
else :
111
132
distribution = 'Deterministic'
112
133
attrs ['shape' ] = 'box'
0 commit comments