Skip to content

Commit 189ba03

Browse files
committed
Faster vectorize by walking sorted nodes
1 parent e2d0751 commit 189ba03

File tree

1 file changed: 14 additions (+), 13 deletions (−)

pytensor/graph/replace.py

Lines changed: 14 additions & 13 deletions
@@ -3,7 +3,13 @@
 from functools import partial, singledispatch
 from typing import Optional, Union, cast, overload

-from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
+from pytensor.graph.basic import (
+    Apply,
+    Constant,
+    Variable,
+    io_toposort,
+    truncated_graph_inputs,
+)
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
915

@@ -295,19 +301,14 @@ def vectorize_graph(
     inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
     new_inputs = [replace.get(inp, inp) for inp in inputs]

-    def transform(var: Variable) -> Variable:
-        if var in inputs:
-            return new_inputs[inputs.index(var)]
-
-        node = var.owner
-        batched_inputs = [transform(inp) for inp in node.inputs]
-        batched_node = vectorize_node(node, *batched_inputs)
-        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-        return cast(Variable, batched_var)
-
-    # TODO: MergeOptimization or node caching?
-    seq_vect_outputs = [transform(out) for out in seq_outputs]
+    vect_vars = dict(zip(inputs, new_inputs))
+    for node in io_toposort(inputs, seq_outputs):
+        vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
+        vect_node = vectorize_node(node, *vect_inputs)
+        for output, vect_output in zip(node.outputs, vect_node.outputs):
+            vect_vars[output] = vect_output
+
+    seq_vect_outputs = [vect_vars[out] for out in seq_outputs]

     if isinstance(outputs, Sequence):
         return seq_vect_outputs

Comments (0)