
Commit f5f7e2a

amirBishamrami authored and committed
[mlir][tosa] Constant optimizations for reduce operations
Replace a reduce operation whose input argument is a constant tensor with the precomputed constant result. Because the argument of the reduce operation is a constant tensor with a single user, the reduced tensor can be computed at compile time, and the operation is replaced with a constant tensor that uses less memory. This optimization is implemented for: tosa.reduce_sum, tosa.reduce_prod, tosa.reduce_any, tosa.reduce_all, tosa.reduce_max, tosa.reduce_min.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D154832
1 parent a7612e2 commit f5f7e2a
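
To make the effect concrete, here is a minimal standalone C++ sketch of the idea (an illustration, not code from this patch): when the input of a reduce operation is a compile-time constant, the whole reduction can be evaluated ahead of time and the op replaced by the smaller result tensor. The helper name reduceSumAxis1 and the row-major layout are assumptions made for this example only.

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring what folding tosa.reduce_sum (axis = 1) on a
// constant 2-D input computes: one sum per row, producing a rows x 1 tensor.
std::vector<int64_t> reduceSumAxis1(const std::vector<int64_t> &data,
                                    int64_t rows, int64_t cols) {
  std::vector<int64_t> result(rows, 0);
  for (int64_t r = 0; r < rows; ++r)
    for (int64_t c = 0; c < cols; ++c)
      result[r] += data[r * cols + c]; // row-major flat indexing
  return result;
}

int main() {
  // A constant 2x3 tensor; reducing along axis 1 yields a 2x1 tensor.
  std::vector<int64_t> tensor = {1, 2, 3, 4, 5, 6};
  for (int64_t v : reduceSumAxis1(tensor, 2, 3))
    std::cout << v << "\n"; // prints 6 and 15
}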

5 files changed: +636 -1 lines changed


mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 32 additions & 1 deletion
@@ -1273,6 +1273,11 @@ def Tosa_ReduceAllOp : Tosa_InferTensorTypeOp<"reduce_all"> {
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+
+    /// Return the AND result between two integer operands
+    static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
+      return leftOperand & rightOperand;
+    }
   }];
 }

@@ -1301,6 +1306,11 @@ def Tosa_ReduceAnyOp : Tosa_InferTensorTypeOp<"reduce_any"> {
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+
+    /// Return the OR result between two integer operands
+    static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
+      return leftOperand | rightOperand;
+    }
   }];
 }

@@ -1329,6 +1339,12 @@ def Tosa_ReduceMaxOp : Tosa_InferTensorTypeOp<"reduce_max"> {
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+
+    /// Return the max of the two integer operands
+    static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
+      const llvm::APInt subtractRes = leftOperand - rightOperand;
+      return (!subtractRes.isNegative()) ? leftOperand : rightOperand;
+    }
   }];
 }

@@ -1357,6 +1373,12 @@ def Tosa_ReduceMinOp : Tosa_InferTensorTypeOp<"reduce_min"> {
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+
+    /// Return the min of the two integer operands
+    static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
+      const llvm::APInt subtractRes = leftOperand - rightOperand;
+      return (!subtractRes.isNegative()) ? rightOperand : leftOperand;
+    }
   }];
 }

@@ -1385,6 +1407,11 @@ def Tosa_ReduceProdOp : Tosa_InferTensorTypeOp<"reduce_prod"> {
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+
+    /// Return the prod of the two integer operands
+    static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
+      return leftOperand * rightOperand;
+    }
   }];
 }

@@ -1406,13 +1433,17 @@ def Tosa_ReduceSumOp : Tosa_InferTensorTypeOp<"reduce_sum"> {
   let results = (outs
     Tosa_Tensor:$output
   );
-
   let hasFolder = 1;

   let extraClassDeclaration = [{
     /// Returns true when two result types are compatible for this op;
     /// Method used by InferTypeOpInterface.
     static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
+
+    /// Return the sum of the two integer operands
+    static inline APInt calcOneElement(APInt leftOperand, APInt rightOperand) {
+      return leftOperand + rightOperand;
+    }
   }];
 }
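
Each of these ops now exposes the same static calcOneElement hook, which is what lets a single templated routine fold any of the reductions (see TosaFolders.cpp below). A minimal sketch of that static-hook pattern, where FakeMaxOp and foldElements are made up for illustration and only stand in for the generated op classes:

#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"

// Made-up stand-in for a generated op class; its rule matches ReduceMaxOp's
// calcOneElement above (keep the operand whose difference is non-negative).
struct FakeMaxOp {
  static inline llvm::APInt calcOneElement(llvm::APInt leftOperand,
                                           llvm::APInt rightOperand) {
    return (leftOperand - rightOperand).isNegative() ? rightOperand
                                                     : leftOperand;
  }
};

// Generic left-fold over a list of elements using whatever rule OpTy defines.
template <typename OpTy>
llvm::APInt foldElements(llvm::ArrayRef<llvm::APInt> elems) {
  llvm::APInt acc = elems.front();
  for (const llvm::APInt &e : elems.drop_front())
    acc = OpTy::calcOneElement(acc, e);
  return acc;
}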

mlir/include/mlir/Dialect/Tosa/Transforms/Passes.h

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ void populateTosaFoldConstantReciprocalPatterns(MLIRContext *ctx,
                                                 RewritePatternSet &patterns);
 void populateTosaFoldConstantTransposePatterns(MLIRContext *ctx,
                                                RewritePatternSet &patterns);
+void populateTosaConstantReduction(MLIRContext *ctx,
+                                   RewritePatternSet &patterns);

 std::unique_ptr<Pass> createTosaLayerwiseConstantFoldPass();
 std::unique_ptr<Pass> createTosaInferShapesPass();
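
For callers that want only this folding, the new entry point can also be plugged into a rewrite pattern set and driven with the greedy rewriter, the same way the pass below does. A hedged sketch (the wrapper name foldConstantReductions is made up; the MLIR APIs used are the standard ones visible elsewhere in this diff):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Apply only the constant-reduction patterns to a single function.
mlir::LogicalResult foldConstantReductions(mlir::func::FuncOp func) {
  mlir::MLIRContext *ctx = func.getContext();
  mlir::RewritePatternSet patterns(ctx);
  mlir::tosa::populateTosaConstantReduction(ctx, patterns);
  return mlir::applyPatternsAndFoldGreedily(func, std::move(patterns));
}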

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 123 additions & 0 deletions
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//

 #include <functional>
+#include <numeric>

 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Dialect/Tosa/Transforms/Passes.h"

@@ -289,8 +290,130 @@ struct TosaFoldConstantReciprocal : public OpRewritePattern<ReciprocalOp> {
   }
 };

+/// Getting the axes position of the element which is located
+/// in the tensor at the counter index
+
+llvm::SmallVector<int64_t>
+getPositionFromIndex(int64_t index, llvm::ArrayRef<int64_t> tensorShape) {
+  int64_t remaining = index;
+  llvm::SmallVector<int64_t> position(tensorShape.size(), 0);
+  for (int64_t i = tensorShape.size() - 1; i >= 0; --i) {
+    position[i] = remaining % tensorShape[i];
+    remaining /= tensorShape[i];
+  }
+  return position;
+}
+
+/// Getting the index of the element which is located at the
+/// axes position in the tensor
+
+int64_t getIndexFromPosition(llvm::ArrayRef<int64_t> position,
+                             llvm::ArrayRef<int64_t> tensorShape) {
+  int64_t index = 0;
+  int64_t multiplierTmp = 1;
+  for (int64_t i = position.size() - 1; i >= 0; --i) {
+    index += position[i] * multiplierTmp;
+    multiplierTmp *= tensorShape[i];
+  }
+  return index;
+}
+
+template <typename OperationType>
+llvm::APInt calculateReducedValue(const mlir::ElementsAttr &oldTensorAttr,
+                                  llvm::ArrayRef<int64_t> oldShape,
+                                  int64_t reductionAxis,
+                                  int64_t reductionIndex) {
+
+  llvm::SmallVector<int64_t> newShape(oldShape);
+  newShape[reductionAxis] = 1;
+  /// Let's calculate the position of the index
+  llvm::SmallVector<int64_t> position =
+      getPositionFromIndex(reductionIndex, newShape);
+  auto oldTensor = oldTensorAttr.getValues<llvm::APInt>();
+  /// Starting from the first positon along the reduction axis
+  position[reductionAxis] = 0;
+  int64_t indexAtOldTensor = getIndexFromPosition(position, oldShape);
+  llvm::APInt reducedValue = oldTensor[indexAtOldTensor];
+
+  for (int64_t reductionAxisVal = 1; reductionAxisVal < oldShape[reductionAxis];
+       ++reductionAxisVal) {
+
+    int64_t stride = std::accumulate(oldShape.begin() + reductionAxis + 1,
+                                     oldShape.end(), 1, std::multiplies<int>());
+    int64_t index = indexAtOldTensor + stride * reductionAxisVal;
+    reducedValue =
+        OperationType::calcOneElement(reducedValue, oldTensor[index]);
+  }
+  return reducedValue;
+}
+
+template <typename OperationType>
+struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
+
+  using OpRewritePattern<OperationType>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(OperationType op,
+                                PatternRewriter &rewriter) const override {
+    Value inputOp = op.getInput();
+    auto constOp = inputOp.getDefiningOp<tosa::ConstOp>();
+
+    if (!constOp)
+      return rewriter.notifyMatchFailure(
+          op, "reduce input must be const operation");
+
+    if (!inputOp.hasOneUse())
+      return rewriter.notifyMatchFailure(
+          op, "input operation has more than one user");
+
+    auto resultType = cast<ShapedType>(op.getOutput().getType());
+
+    if (!resultType.hasStaticShape())
+      return rewriter.notifyMatchFailure(op, "result type shape is not static");
+
+    auto reductionAxis = op.getAxis();
+    const auto denseElementsAttr = constOp.getValue();
+    const auto shapedOldElementsValues =
+        denseElementsAttr.getType().cast<ShapedType>();
+
+    if (!llvm::isa<IntegerType>(shapedOldElementsValues.getElementType()))
+      return rewriter.notifyMatchFailure(
+          op, "reduce input currently supported with integer type");
+
+    auto oldShape = shapedOldElementsValues.getShape();
+    auto newShape = resultType.getShape();
+
+    auto newNumOfElements = std::accumulate(newShape.begin(), newShape.end(), 1,
+                                            std::multiplies<int>());
+    llvm::SmallVector<APInt> newReducedTensor(newNumOfElements);
+
+    for (int64_t reductionIndex = 0; reductionIndex < newNumOfElements;
+         ++reductionIndex) {
+
+      /// Let's reduce all the elements along this reduction axis
+      newReducedTensor[reductionIndex] = calculateReducedValue<OperationType>(
+          denseElementsAttr, oldShape, reductionAxis, reductionIndex);
+    }
+
+    auto rankedTensorType = cast<RankedTensorType>(resultType);
+    auto denseAttr =
+        mlir::DenseElementsAttr::get(rankedTensorType, newReducedTensor);
+    rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, rankedTensorType, denseAttr);
+    return success();
+  }
+};
+
 } // namespace

+void mlir::tosa::populateTosaConstantReduction(MLIRContext *ctx,
+                                               RewritePatternSet &patterns) {
+  patterns.add<ReduceConstantOptimization<ReduceAllOp>>(ctx);
+  patterns.add<ReduceConstantOptimization<ReduceAnyOp>>(ctx);
+  patterns.add<ReduceConstantOptimization<ReduceMaxOp>>(ctx);
+  patterns.add<ReduceConstantOptimization<ReduceMinOp>>(ctx);
+  patterns.add<ReduceConstantOptimization<ReduceProdOp>>(ctx);
+  patterns.add<ReduceConstantOptimization<ReduceSumOp>>(ctx);
+}
+
 void mlir::tosa::populateTosaFoldConstantTransposePatterns(
     MLIRContext *ctx, RewritePatternSet &patterns) {
   patterns.add<TosaFoldConstantTranspose>(ctx);
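
A quick worked example of the index arithmetic above (a standalone sketch, not part of the patch): for an input of shape [2, 3] reduced along axis 1, output index 1 maps to position [1, 0] in the shape-[2, 1] result; translated back into the input shape, the run starts at flat index 3 and advances by the axis stride 1, visiting elements 3, 4 and 5. The snippet hard-codes the 2-D row-major index formula for brevity:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Row-major flat index -> per-dimension position (same math as getPositionFromIndex).
std::vector<int64_t> positionFromIndex(int64_t index,
                                       const std::vector<int64_t> &shape) {
  std::vector<int64_t> pos(shape.size(), 0);
  for (int64_t i = shape.size() - 1; i >= 0; --i) {
    pos[i] = index % shape[i];
    index /= shape[i];
  }
  return pos;
}

int main() {
  // Input shape [2, 3], reduce along axis 1 -> output shape [2, 1].
  std::vector<int64_t> oldShape = {2, 3}, newShape = {2, 1};
  int64_t axis = 1, reductionIndex = 1; // second element of the output
  std::vector<int64_t> pos = positionFromIndex(reductionIndex, newShape);
  pos[axis] = 0; // start of the run along the reduction axis
  // Flat index of that position in the input (row-major, 2-D case).
  int64_t start = pos[0] * oldShape[1] + pos[1];
  // Stride between consecutive elements along the reduction axis.
  int64_t stride = std::accumulate(oldShape.begin() + axis + 1, oldShape.end(),
                                   int64_t{1}, std::multiplies<int64_t>());
  for (int64_t k = 0; k < oldShape[axis]; ++k)
    std::cout << start + k * stride << " "; // prints: 3 4 5
}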

mlir/lib/Dialect/Tosa/Transforms/TosaLayerwiseConstantFoldPass.cpp

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ struct TosaLayerwiseConstantFoldPass

     mlir::tosa::populateTosaFoldConstantReciprocalPatterns(ctx, patterns);
     mlir::tosa::populateTosaFoldConstantTransposePatterns(ctx, patterns);
+    mlir::tosa::populateTosaConstantReduction(ctx, patterns);
     populateTosaOpsCanonicalizationPatterns(ctx, patterns);

     if (applyPatternsAndFoldGreedily(func, std::move(patterns)).failed())
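
With the patterns registered here, the fold runs whenever the layerwise constant-fold pass runs. A minimal sketch of invoking that pass programmatically (illustrative only; foldTosaConstants is a made-up wrapper, while the pass-manager APIs are the standard MLIR ones):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"

// Schedule the TOSA layerwise constant-fold pass (which now also performs the
// reduce constant optimization) on every function in an existing module.
mlir::LogicalResult foldTosaConstants(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addNestedPass<mlir::func::FuncOp>(
      mlir::tosa::createTosaLayerwiseConstantFoldPass());
  return pm.run(module);
}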
