Skip to content

Commit 5e08cbe

Browse files
committed
Filter non-parameter inputs of RandomVariables in model_graph
This removes visual dependencies between observed data and likelihood, due to flow of shape information
1 parent bb15dbc commit 5e08cbe

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

pymc/model_graph.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,25 @@ def get_parent_names(self, var: TensorVariable) -> Set[VarName]:
4242
if var.owner is None or var.owner.inputs is None:
4343
return set()
4444

45+
def _filter_non_parameter_inputs(var):
46+
node = var.owner
47+
if isinstance(node.op, RandomVariable):
48+
# Filter out rng, dtype and size parameters or RandomVariable nodes
49+
return node.inputs[3:]
50+
else:
51+
# Otherwise return all inputs
52+
return node.inputs
53+
4554
def _expand(x):
4655
if x.name:
4756
return [x]
4857
if isinstance(x.owner, Apply):
49-
return reversed(x.owner.inputs)
58+
return reversed(_filter_non_parameter_inputs(x))
5059
return []
5160

5261
parents = {
5362
get_var_name(x)
54-
for x in walk(nodes=var.owner.inputs, expand=_expand)
63+
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
5564
# Only consider nodes that are in the named model variables.
5665
if x.name and x.name in self._all_var_names
5766
}

0 commit comments

Comments
 (0)