@@ -650,3 +650,133 @@ func @no_fold_pad_fill_value_mismatch() -> tensor<412x276xf32> {
  } : tensor<400x273xf32> to tensor<412x276xf32>
  return %pad : tensor<412x276xf32>
}
+
+ // -----
+
+ // The tests below verify that static type information is propagated through all the operands of a generic op:
+ // 1. An input of the generic op has static info and has no tensor.cast source.
+ // 2. An input of the generic op has static info and comes from a tensor.cast operation.
+ // 3. An output of the generic op has static info and comes from a tensor.cast operation.
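+ // In all three cases the canonicalization is expected to rewrite the generic op so that it
+ // operates on the fully static tensor<2x3x4xf32> type, inserting tensor.cast ops on the
+ // dynamically shaped operands as needed (see the CHECK lines in each test).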
+ #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ // CHECK-LABEL: func @static_input_without_cast
+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+ func @static_input_without_cast(%arg0 : tensor<2x3x4xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %c2 = arith.constant 2 : index
+   %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+   %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+   %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+   %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+   %4 = linalg.generic {
+     indexing_maps = [#map, #map, #map],
+     iterator_types = ["parallel", "parallel", "parallel"]
+   } ins(%arg0, %arg1 : tensor<2x3x4xf32>, tensor<?x?x?xf32>)
+     outs(%3 : tensor<?x?x?xf32>) {
+   ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+     %9 = arith.addf %arg2, %arg3 : f32
+     linalg.yield %9 : f32
+   } -> (tensor<?x?x?xf32>)
+   %5 = tensor.cast %4 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   return %5 : tensor<2x3x4xf32>
+   // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+   // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+   // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+ }
+
+ // -----
+
+ #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ // CHECK-LABEL: func @static_input_with_cast
+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+ func @static_input_with_cast(%arg0 : tensor<2x3x4xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<2x3x4xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %c2 = arith.constant 2 : index
+   %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+   %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+   %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+   %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+   %4 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+   %5 = linalg.generic {
+     indexing_maps = [#map, #map, #map],
+     iterator_types = ["parallel", "parallel", "parallel"]
+   } ins(%arg0, %4 : tensor<2x3x4xf32>, tensor<2x?x?xf32>)
+     outs(%3 : tensor<?x?x?xf32>) {
+   ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+     %9 = arith.addf %arg2, %arg3 : f32
+     linalg.yield %9 : f32
+   } -> (tensor<?x?x?xf32>)
+   %6 = tensor.cast %5 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   return %6 : tensor<2x3x4xf32>
+   // CHECK: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+   // CHECK-SAME: ins(%[[ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+   // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+ }
+
+ // -----
+
+ #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ // CHECK-LABEL: func @static_output_with_cast
+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<?x?x?xf32>, %[[ARG1:.*]]: tensor<?x?x?xf32>, %[[ARG2:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ func @static_output_with_cast(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2 : tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %c2 = arith.constant 2 : index
+   %0 = tensor.dim %arg2, %c0 : tensor<2x3x4xf32>
+   %1 = tensor.dim %arg2, %c1 : tensor<2x3x4xf32>
+   %2 = tensor.dim %arg2, %c2 : tensor<2x3x4xf32>
+   %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+   %4 = tensor.cast %3 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   %5 = tensor.cast %arg1 : tensor<?x?x?xf32> to tensor<2x?x?xf32>
+   %6 = linalg.generic {
+     indexing_maps = [#map, #map, #map],
+     iterator_types = ["parallel", "parallel", "parallel"]
+   } ins(%arg0, %5 : tensor<?x?x?xf32>, tensor<2x?x?xf32>)
+     outs(%4 : tensor<2x3x4xf32>) {
+   ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
+     %9 = arith.addf %arg3, %arg4 : f32
+     linalg.yield %9 : f32
+   } -> (tensor<2x3x4xf32>)
+   return %6 : tensor<2x3x4xf32>
+   // CHECK: %[[CAST_ARG0:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   // CHECK-NEXT: %[[CAST_ARG1:.*]] = tensor.cast %[[ARG1]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   // CHECK-NEXT: %[[GENERIC_OP:.*]] = linalg.generic
+   // CHECK-SAME: ins(%[[CAST_ARG0]], %[[CAST_ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+   // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+ }
+
+ // -----
+
+ // This test checks the folding of a tensor.cast operation when the source value of the
+ // cast has more static information than the destination value.
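+ // Both input casts below only erase static information (tensor<2x3x4xf32> to tensor<2x?x?xf32>),
+ // so the canonicalization is expected to look through them and feed the original static
+ // tensors straight into the generic op, as the CHECK lines verify.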
+ #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+ // CHECK-LABEL: func @cast_source
+ // CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3x4xf32>, %[[ARG1:.*]]: tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+ func @cast_source(%arg0 : tensor<2x3x4xf32>, %arg1 : tensor<2x3x4xf32>) -> tensor<2x3x4xf32> {
+   %c0 = arith.constant 0 : index
+   %c1 = arith.constant 1 : index
+   %c2 = arith.constant 2 : index
+   %0 = tensor.dim %arg0, %c0 : tensor<2x3x4xf32>
+   %1 = tensor.dim %arg0, %c1 : tensor<2x3x4xf32>
+   %2 = tensor.dim %arg0, %c2 : tensor<2x3x4xf32>
+   %3 = linalg.init_tensor [%0, %1, %2] : tensor<?x?x?xf32>
+   %4 = tensor.cast %arg0 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+   %5 = tensor.cast %arg1 : tensor<2x3x4xf32> to tensor<2x?x?xf32>
+   %6 = linalg.generic {
+     indexing_maps = [#map, #map, #map],
+     iterator_types = ["parallel", "parallel", "parallel"]
+   } ins(%4, %5 : tensor<2x?x?xf32>, tensor<2x?x?xf32>)
+     outs(%3 : tensor<?x?x?xf32>) {
+   ^bb0(%arg2 : f32, %arg3 : f32, %arg4 : f32):
+     %9 = arith.addf %arg2, %arg3 : f32
+     linalg.yield %9 : f32
+   } -> (tensor<?x?x?xf32>)
+   %7 = tensor.cast %6 : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+   return %7 : tensor<2x3x4xf32>
+   // CHECK: %[[GENERIC_OP:.*]] = linalg.generic
+   // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<2x3x4xf32>, tensor<2x3x4xf32>)
+   // CHECK-SAME: outs({{.*}} : tensor<2x3x4xf32>)
+ }