@@ -3,7 +3,13 @@
 from functools import partial, singledispatch
 from typing import Optional, Union, cast, overload
 
-from pytensor.graph.basic import Apply, Constant, Variable, truncated_graph_inputs
+from pytensor.graph.basic import (
+    Apply,
+    Constant,
+    Variable,
+    io_toposort,
+    truncated_graph_inputs,
+)
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 
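The newly imported io_toposort(inputs, outputs) returns the Apply nodes between the given inputs and outputs in topological order, so every node appears after the nodes that produce its inputs. A minimal sketch of the traversal follows; the example graph is illustrative, not part of this patch:

import pytensor.tensor as pt
from pytensor.graph.basic import io_toposort

x = pt.vector("x")
y = pt.exp(x)
z = y + y.sum()

# Each node is yielded only after the nodes producing its inputs,
# which is the property the rewritten loop below relies on.
for node in io_toposort([x], [z]):
    print(node.op)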
|
@@ -295,19 +301,14 @@ def vectorize_graph(
     inputs = truncated_graph_inputs(seq_outputs, ancestors_to_include=replace.keys())
     new_inputs = [replace.get(inp, inp) for inp in inputs]
 
-    def transform(var: Variable) -> Variable:
-        if var in inputs:
-            return new_inputs[inputs.index(var)]
+    vect_vars = dict(zip(inputs, new_inputs))
+    for node in io_toposort(inputs, seq_outputs):
+        vect_inputs = [vect_vars.get(inp, inp) for inp in node.inputs]
+        vect_node = vectorize_node(node, *vect_inputs)
+        for output, vect_output in zip(node.outputs, vect_node.outputs):
+            vect_vars[output] = vect_output
 
-        node = var.owner
-        batched_inputs = [transform(inp) for inp in node.inputs]
-        batched_node = vectorize_node(node, *batched_inputs)
-        batched_var = batched_node.outputs[var.owner.outputs.index(var)]
-
-        return cast(Variable, batched_var)
-
-    # TODO: MergeOptimization or node caching?
-    seq_vect_outputs = [transform(out) for out in seq_outputs]
+    seq_vect_outputs = [vect_vars[out] for out in seq_outputs]
 
     if isinstance(outputs, Sequence):
         return seq_vect_outputs
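With this change the graph is rewritten iteratively instead of recursively: vect_vars caches the vectorized counterpart of every variable, so shared subgraphs are processed once (addressing the removed caching TODO) and deep graphs no longer risk hitting Python's recursion limit. A hedged usage sketch of vectorize_graph, assuming the public API from pytensor.graph.replace; the softmax graph is illustrative:

import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

x = pt.vector("x")
out = pt.exp(x) / pt.exp(x).sum()  # softmax of a single vector

# Swapping the vector input for a matrix vectorizes every node
# between x and out via vectorize_node.
batch_x = pt.matrix("batch_x")
batch_out = vectorize_graph(out, replace={x: batch_x})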