@@ -58,77 +58,6 @@ static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
58
58
: SmallVector<Value, 4 >(ivs.begin (), ivs.end ());
59
59
}
60
60
61
- // / Creates a number of ranges equal to the number of dimensions in the `map`.
62
- // / The returned ranges correspond to the loop ranges, in the proper order, for
63
- // / which new loops will be created.
64
- // / The function supports only maps that are invertible and have results of type
65
- // / DimExpr or (DimExpr + DimExpr - SymbolExpr floordiv ConstExpr).
66
- // / It expects a non-inverted, concatenated map and last values in
67
- // / allViewSizes will be applied to the symbols in the map if it contains any.
68
- static SmallVector<Range, 4 > emitLoopRanges (OpBuilder &b, Location loc,
69
- AffineMap map,
70
- ValueRange viewSizes) {
71
- unsigned numDims = map.getNumDims (), numRes = map.getNumResults ();
72
- unsigned numSym = map.getNumSymbols ();
73
- assert (viewSizes.size () == numRes + numSym &&
74
- " viewSizes must contain sizes of all views and values for symbols" );
75
- SmallVector<Range, 4 > res (numDims);
76
- for (unsigned idx = 0 ; idx < numRes; ++idx) {
77
- auto result = map.getResult (idx);
78
- if (auto d = result.dyn_cast <AffineDimExpr>()) {
79
- if (res[d.getPosition ()].offset )
80
- continue ;
81
- res[d.getPosition ()] =
82
- Range{std_constant_index (0 ), viewSizes[idx], std_constant_index (1 )};
83
- }
84
-
85
- // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
86
- // then the bounds are:
87
- // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
88
- // where size(n) is applied to the symbol s.
89
- // This is done statically now.
90
- if (auto binOp = result.dyn_cast <AffineBinaryOpExpr>()) {
91
- auto lhs = binOp.getLHS ().dyn_cast <AffineBinaryOpExpr>();
92
- auto rhs = binOp.getRHS ().dyn_cast <AffineBinaryOpExpr>();
93
- if (!lhs || !rhs || binOp.getKind () != AffineExprKind::Add ||
94
- lhs.getKind () != AffineExprKind::Add ||
95
- rhs.getKind () != mlir::AffineExprKind::Mul)
96
- continue ;
97
-
98
- auto m = lhs.getLHS ().dyn_cast <AffineDimExpr>();
99
- auto n = lhs.getRHS ().dyn_cast <AffineDimExpr>();
100
- auto fDiv = rhs.getLHS ().dyn_cast <AffineBinaryOpExpr>();
101
- auto minusOne = rhs.getRHS ().dyn_cast <AffineConstantExpr>();
102
- if (!m || !n || !fDiv || !minusOne ||
103
- fDiv .getKind () != AffineExprKind::FloorDiv ||
104
- fDiv .getLHS ().getKind () != AffineExprKind::SymbolId ||
105
- fDiv .getRHS ().getKind () != AffineExprKind::Constant)
106
- continue ;
107
-
108
- auto s = fDiv .getLHS ().dyn_cast <AffineSymbolExpr>();
109
- if (minusOne.getValue () != -1 )
110
- continue ;
111
-
112
- int mPos = m.getPosition ();
113
- AffineExpr one = getAffineConstantExpr (1 , s.getContext ());
114
- AffineExpr sizeOfM = getAffineSymbolExpr (numSym, s.getContext ());
115
- // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
116
- AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
117
- AffineMap fromMap = AffineMap::get (numDims, numSym + 1 , fDiv );
118
- AffineMap toMap = AffineMap::get (numDims, numSym + 1 , upperOffsetExpr);
119
- SmallVector<Value, 8 > values (viewSizes.begin (),
120
- viewSizes.begin () + numDims);
121
- values.insert (values.end (), viewSizes.begin () + numRes, viewSizes.end ());
122
- values.push_back (viewSizes[mPos ]);
123
- // Construction of the lower bound (s floordiv 2).
124
- Value from = applyMapToValues (b, loc, fromMap, values).front ();
125
- Value to = applyMapToValues (b, loc, toMap, values).front ();
126
- res[mPos ] = Range{from, to, std_constant_index (1 )};
127
- }
128
- }
129
- return res;
130
- }
131
-
132
61
template <typename IndexedValueType, typename OpType>
133
62
static void inlineRegionAndEmitStore (OpType op, ArrayRef<Value> indexedValues,
134
63
ArrayRef<SmallVector<Value, 8 >> indexing,
@@ -708,6 +637,70 @@ static Optional<LinalgLoops> linalgOpToLoopsImplSwitch(Operation *op,
708
637
llvm_unreachable (" Unexpected op in linalgOpToLoopsImpl" );
709
638
}
710
639
640
+ SmallVector<Range, 4 > mlir::linalg::emitLoopRanges (OpBuilder &b, Location loc,
641
+ AffineMap map,
642
+ ValueRange viewSizes) {
643
+ unsigned numDims = map.getNumDims (), numRes = map.getNumResults ();
644
+ unsigned numSym = map.getNumSymbols ();
645
+ assert (viewSizes.size () == numRes + numSym &&
646
+ " viewSizes must contain sizes of all views and values for symbols" );
647
+ SmallVector<Range, 4 > res (numDims);
648
+ for (unsigned idx = 0 ; idx < numRes; ++idx) {
649
+ auto result = map.getResult (idx);
650
+ if (auto d = result.dyn_cast <AffineDimExpr>()) {
651
+ if (res[d.getPosition ()].offset )
652
+ continue ;
653
+ res[d.getPosition ()] =
654
+ Range{std_constant_index (0 ), viewSizes[idx], std_constant_index (1 )};
655
+ }
656
+
657
+ // If the access pattern is of form (m, n)[s] -> (m + n - s floordiv 2),
658
+ // then the bounds are:
659
+ // (s floordiv 2) <= m <= (size(m) + s floordiv 2 - s + 1).
660
+ // where size(n) is applied to the symbol s.
661
+ // This is done statically now.
662
+ if (auto binOp = result.dyn_cast <AffineBinaryOpExpr>()) {
663
+ auto lhs = binOp.getLHS ().dyn_cast <AffineBinaryOpExpr>();
664
+ auto rhs = binOp.getRHS ().dyn_cast <AffineBinaryOpExpr>();
665
+ if (!lhs || !rhs || binOp.getKind () != AffineExprKind::Add ||
666
+ lhs.getKind () != AffineExprKind::Add ||
667
+ rhs.getKind () != mlir::AffineExprKind::Mul)
668
+ continue ;
669
+
670
+ auto m = lhs.getLHS ().dyn_cast <AffineDimExpr>();
671
+ auto n = lhs.getRHS ().dyn_cast <AffineDimExpr>();
672
+ auto fDiv = rhs.getLHS ().dyn_cast <AffineBinaryOpExpr>();
673
+ auto minusOne = rhs.getRHS ().dyn_cast <AffineConstantExpr>();
674
+ if (!m || !n || !fDiv || !minusOne ||
675
+ fDiv .getKind () != AffineExprKind::FloorDiv ||
676
+ fDiv .getLHS ().getKind () != AffineExprKind::SymbolId ||
677
+ fDiv .getRHS ().getKind () != AffineExprKind::Constant)
678
+ continue ;
679
+
680
+ auto s = fDiv .getLHS ().dyn_cast <AffineSymbolExpr>();
681
+ if (minusOne.getValue () != -1 )
682
+ continue ;
683
+
684
+ int mPos = m.getPosition ();
685
+ AffineExpr one = getAffineConstantExpr (1 , s.getContext ());
686
+ AffineExpr sizeOfM = getAffineSymbolExpr (numSym, s.getContext ());
687
+ // Construction of upper bound (size(m) + s floordiv 2 - s + 1).
688
+ AffineExpr upperOffsetExpr = sizeOfM + fDiv + one - s;
689
+ AffineMap fromMap = AffineMap::get (numDims, numSym + 1 , fDiv );
690
+ AffineMap toMap = AffineMap::get (numDims, numSym + 1 , upperOffsetExpr);
691
+ SmallVector<Value, 8 > values (viewSizes.begin (),
692
+ viewSizes.begin () + numDims);
693
+ values.insert (values.end (), viewSizes.begin () + numRes, viewSizes.end ());
694
+ values.push_back (viewSizes[mPos ]);
695
+ // Construction of the lower bound (s floordiv 2).
696
+ Value from = applyMapToValues (b, loc, fromMap, values).front ();
697
+ Value to = applyMapToValues (b, loc, toMap, values).front ();
698
+ res[mPos ] = Range{from, to, std_constant_index (1 )};
699
+ }
700
+ }
701
+ return res;
702
+ }
703
+
711
704
// / Emits a loop nest with the proper body for `op`.
712
705
template <typename LoopTy>
713
706
Optional<LinalgLoops> mlir::linalg::linalgLowerOpToLoops (OpBuilder &builder,
0 commit comments