Skip to content

Commit 37b0387

Browse files
committed
Change order of arguments in get_related_valued_nodes
1 parent 5352798 commit 37b0387

File tree

5 files changed

+6
-6
lines changed

5 files changed

+6
-6
lines changed

pymc/logprob/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def conditional_logp(
503503
if not isinstance(node.op, MeasurableOp):
504504
continue
505505

506-
valued_nodes = get_related_valued_nodes(node, fgraph)
506+
valued_nodes = get_related_valued_nodes(fgraph, node)
507507

508508
if not valued_nodes:
509509
continue

pymc/logprob/mixture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -468,7 +468,7 @@ def split_valued_ifelse(fgraph, node):
468468
# Single outputs IfElse
469469
return None
470470

471-
valued_output_nodes = get_related_valued_nodes(node, fgraph)
471+
valued_output_nodes = get_related_valued_nodes(fgraph, node)
472472
if not valued_output_nodes:
473473
return None
474474

pymc/logprob/scan.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def find_measurable_scans(fgraph, node):
421421
# Find outputs of scan that are directly valued.
422422
# These must be mapping outputs, such as `outputs_info = [None]` (i.e, no recurrence nit_sot outputs)
423423
direct_valued_outputs = [
424-
valued_node.inputs[0] for valued_node in get_related_valued_nodes(node, fgraph)
424+
valued_node.inputs[0] for valued_node in get_related_valued_nodes(fgraph, node)
425425
]
426426
if not all(valued_out in scan_args.outer_out_nit_sot for valued_out in direct_valued_outputs):
427427
return None
@@ -434,7 +434,7 @@ def find_measurable_scans(fgraph, node):
434434
client.outputs[0]
435435
for out in node.outputs
436436
for client, _ in fgraph.clients[out]
437-
if (isinstance(client.op, Subtensor) and get_related_valued_nodes(client, fgraph))
437+
if (isinstance(client.op, Subtensor) and get_related_valued_nodes(fgraph, client))
438438
]
439439
indirect_valued_outputs = [out.owner.inputs[0] for out in sliced_valued_outputs]
440440
if not all(

pymc/logprob/transform_value.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def transform_values(fgraph: FunctionGraph, node: Apply) -> list[Apply] | None:
147147
return None
148148

149149
rv_node = node.inputs[0].owner
150-
valued_nodes = get_related_valued_nodes(rv_node, fgraph)
150+
valued_nodes = get_related_valued_nodes(fgraph, rv_node)
151151
rvs = [valued_var.inputs[0] for valued_var in valued_nodes]
152152
values = [valued_var.inputs[1] for valued_var in valued_nodes]
153153
transforms = [values_to_transforms.get(value, None) for value in values]

pymc/logprob/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,7 @@ def find_negated_var(var):
320320
return None
321321

322322

323-
def get_related_valued_nodes(node: Apply, fgraph: FunctionGraph) -> list[Apply]:
323+
def get_related_valued_nodes(fgraph: FunctionGraph, node: Apply) -> list[Apply]:
324324
"""Get all ValuedVars related to the same RV node.
325325
326326
Returns

0 commit comments

Comments
 (0)