Skip to content

Commit 422ac3c

Browse files
Solve infinite looping error for scalar variable
1 parent 36645b7 commit 422ac3c

File tree

1 file changed

+10
-2
lines changed

1 file changed

+10
-2
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,10 @@ def local_replace_slice(fgraph, node):
413413
idx_flag = True
414414
start = None
415415

416-
if extract_constant(stop, only_process_constants=True) == x.type.shape[dim]:
416+
if (
417+
x.type.shape[dim] is not None
418+
and extract_constant(stop, only_process_constants=True) == x.type.shape[dim]
419+
):
417420
idx_flag = True
418421
stop = None
419422

@@ -423,8 +426,13 @@ def local_replace_slice(fgraph, node):
423426

424427
new_idxs[dim] = slice(start, stop, step)
425428

426-
assert node.outputs[0].type == x[tuple(new_idxs)].type
427429
if idx_flag is True:
430+
try:
431+
assert node.outputs[0].type == x[tuple(new_idxs)].type
432+
except AssertionError:
433+
raise AssertionError(
434+
f"The type of the returned variable {x[tuple(new_idxs)].type} did not match with the type of the original variable {node.outputs[0].type}"
435+
)
428436
return [x[tuple(new_idxs)]]
429437

430438

0 commit comments

Comments
 (0)