@@ -531,7 +531,7 @@ def numba_funcify_CAReduce(op, node, **kwargs):
531
531
532
532
533
533
@numba_funcify .register (DimShuffle )
534
- def numba_funcify_DimShuffle (op , ** kwargs ):
534
+ def numba_funcify_DimShuffle (op , node , ** kwargs ):
535
535
shuffle = tuple (op .shuffle )
536
536
transposition = tuple (op .transposition )
537
537
augment = tuple (op .augment )
@@ -560,16 +560,26 @@ def transpose(x):
560
560
# To avoid this compile-time error, we omit the expression altogether.
561
561
if len (shuffle ) > 0 :
562
562
563
- @numba_basic .numba_njit
564
- def find_shape (array_shape ):
565
- shape = shape_template
566
- j = 0
567
- for i in range (ndim_new_shape ):
568
- if i not in augment :
569
- length = array_shape [j ]
570
- shape = numba_basic .tuple_setitem (shape , i , length )
571
- j = j + 1
572
- return shape
563
+ # Use the statically known shape if available
564
+ if all (length is not None for length in node .outputs [0 ].type .shape ):
565
+ shape = node .outputs [0 ].type .shape
566
+
567
+ @numba_basic .numba_njit
568
+ def find_shape (array_shape ):
569
+ return shape
570
+
571
+ else :
572
+
573
+ @numba_basic .numba_njit
574
+ def find_shape (array_shape ):
575
+ shape = shape_template
576
+ j = 0
577
+ for i in range (ndim_new_shape ):
578
+ if i not in augment :
579
+ length = array_shape [j ]
580
+ shape = numba_basic .tuple_setitem (shape , i , length )
581
+ j = j + 1
582
+ return shape
573
583
574
584
else :
575
585
0 commit comments