Skip to content

Commit d9fe197

Browse files
committed
Use static shape info in numba DimShuffle
1 parent db7ae4d commit d9fe197

File tree

1 file changed

+21
-11
lines changed

1 file changed

+21
-11
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
531531

532532

533533
@numba_funcify.register(DimShuffle)
534-
def numba_funcify_DimShuffle(op, **kwargs):
534+
def numba_funcify_DimShuffle(op, node, **kwargs):
535535
shuffle = tuple(op.shuffle)
536536
transposition = tuple(op.transposition)
537537
augment = tuple(op.augment)
@@ -560,16 +560,26 @@ def transpose(x):
560560
# To avoid this compile-time error, we omit the expression altogether.
561561
if len(shuffle) > 0:
562562

563-
@numba_basic.numba_njit
564-
def find_shape(array_shape):
565-
shape = shape_template
566-
j = 0
567-
for i in range(ndim_new_shape):
568-
if i not in augment:
569-
length = array_shape[j]
570-
shape = numba_basic.tuple_setitem(shape, i, length)
571-
j = j + 1
572-
return shape
563+
# Use the statically known shape if available
564+
if all(length is not None for length in node.outputs[0].type.shape):
565+
shape = node.outputs[0].type.shape
566+
567+
@numba_basic.numba_njit
568+
def find_shape(array_shape):
569+
return shape
570+
571+
else:
572+
573+
@numba_basic.numba_njit
574+
def find_shape(array_shape):
575+
shape = shape_template
576+
j = 0
577+
for i in range(ndim_new_shape):
578+
if i not in augment:
579+
length = array_shape[j]
580+
shape = numba_basic.tuple_setitem(shape, i, length)
581+
j = j + 1
582+
return shape
573583

574584
else:
575585

0 commit comments

Comments
 (0)