Skip to content

Commit 285e93e

Browse files
committed
Adding workaround so more graphs with Blockwise(Scans) can be
vectorized
1 parent 92eef5e commit 285e93e

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

pytensor/scan/op.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1500,6 +1500,10 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
15001500
node_input_storage = [storage_map[r] for r in node.inputs]
15011501
node_output_storage = [storage_map[r] for r in node.outputs]
15021502

1503+
# HACK: Here to handle Blockwise Scans
1504+
if compute_map is None:
1505+
compute_map = {out: [False] for out in node.outputs}
1506+
15031507
# Analyse the compile inner function to determine which inputs and
15041508
# outputs are on the gpu and speed up some checks during the execution
15051509
outs_is_tensor = [

tests/scan/test_basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pytensor.compile.sharedvalue import shared
2828
from pytensor.configdefaults import config
2929
from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
30+
from pytensor.graph import vectorize_graph
3031
from pytensor.graph.basic import Apply, ancestors, equal_computations
3132
from pytensor.graph.fg import FunctionGraph
3233
from pytensor.graph.op import Op
@@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1):
11781179

11791180
utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng)
11801181

1182+
def test_blockwise_scan(self):
1183+
x = pt.tensor("x", shape=())
1184+
out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10)
1185+
x_vec = pt.tensor("x_vec", shape=(None,))
1186+
out_vec = vectorize_graph(out, {x: x_vec})
1187+
1188+
fn = function([x_vec], out_vec)
1189+
o1 = fn([1, 2, 3])
1190+
o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1)
1191+
assert np.allclose(o1, o2)
1192+
11811193
def test_connection_pattern(self):
11821194
"""Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps."""
11831195

0 commit comments

Comments
 (0)