File tree Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Expand file tree Collapse file tree 1 file changed +11
-2
lines changed Original file line number Diff line number Diff line change @@ -42,16 +42,25 @@ def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
42
42
if var .owner is None or var .owner .inputs is None :
43
43
return set ()
44
44
45
+ def _filter_non_parameter_inputs (var ):
46
+ node = var .owner
47
+ if isinstance (node .op , RandomVariable ):
48
+ # Filter out rng, dtype and size parameters or RandomVariable nodes
49
+ return node .inputs [3 :]
50
+ else :
51
+ # Otherwise return all inputs
52
+ return node .inputs
53
+
45
54
def _expand (x ):
46
55
if x .name :
47
56
return [x ]
48
57
if isinstance (x .owner , Apply ):
49
- return reversed (x . owner . inputs )
58
+ return reversed (_filter_non_parameter_inputs ( x ) )
50
59
return []
51
60
52
61
parents = {
53
62
get_var_name (x )
54
- for x in walk (nodes = var . owner . inputs , expand = _expand )
63
+ for x in walk (nodes = _filter_non_parameter_inputs ( var ) , expand = _expand )
55
64
# Only consider nodes that are in the named model variables.
56
65
if x .name and x .name in self ._all_var_names
57
66
}
You can’t perform that action at this time.
0 commit comments