Skip to content

Commit 6e6352f

Browse files
authored
[TOSA] Add TosaToMLProgram conversion (llvm#69787)
This patch adds a new pass to lower TOSA StatefulOps to corresponding ML Program ops (https://mlir.llvm.org/docs/Dialects/MLProgramOps/). Signed-off-by: Jerry Ge <[email protected]>
1 parent beb121f commit 6e6352f

File tree

9 files changed

+226
-0
lines changed

9 files changed

+226
-0
lines changed

mlir/include/mlir/Conversion/Passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "mlir/Conversion/TensorToSPIRV/TensorToSPIRVPass.h"
6161
#include "mlir/Conversion/TosaToArith/TosaToArith.h"
6262
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
63+
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
6364
#include "mlir/Conversion/TosaToSCF/TosaToSCF.h"
6465
#include "mlir/Conversion/TosaToTensor/TosaToTensor.h"
6566
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

+13
Original file line numberDiff line numberDiff line change
@@ -1093,6 +1093,19 @@ def TosaToLinalgNamed
10931093
let constructor = "tosa::createTosaToLinalgNamed()";
10941094
}
10951095

1096+
//===----------------------------------------------------------------------===//
1097+
// TosaToMLProgram
1098+
//===----------------------------------------------------------------------===//
1099+
1100+
def TosaToMLProgram : Pass<"tosa-to-mlprogram", "ModuleOp"> {
1101+
let summary = "Lower TOSA to the MLProgram dialect";
1102+
let dependentDialects = ["ml_program::MLProgramDialect"];
1103+
let description = [{
1104+
Pass that converts TOSA's variable operator operations to the equivalent
1105+
MLProgram operations.
1106+
}];
1107+
}
1108+
10961109
//===----------------------------------------------------------------------===//
10971110
// TosaToSCF
10981111
//===----------------------------------------------------------------------===//
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//===-- TosaToMLProgram.h - TOSA to MLProgram dialect lowerings-*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file declares the passes for the TOSA to MLProgram Dialect conversion.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
14+
#define MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H
15+
16+
#include "mlir/Pass/Pass.h"
17+
#include "mlir/Transforms/DialectConversion.h"
18+
19+
namespace mlir {
20+
21+
#define GEN_PASS_DECL_TOSATOMLPROGRAM
22+
23+
namespace tosa {
24+
25+
void populateTosaToMLProgramConversionPatterns(RewritePatternSet *patterns);
26+
27+
} // namespace tosa
28+
} // namespace mlir
29+
30+
#endif // MLIR_CONVERSION_TOSATOMLPROGRAM_TOSATOMLPROGRAM_H

mlir/lib/Conversion/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ add_subdirectory(TensorToLinalg)
5050
add_subdirectory(TensorToSPIRV)
5151
add_subdirectory(TosaToArith)
5252
add_subdirectory(TosaToLinalg)
53+
add_subdirectory(TosaToMLProgram)
5354
add_subdirectory(TosaToSCF)
5455
add_subdirectory(TosaToTensor)
5556
add_subdirectory(UBToLLVM)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(MLIRTosaToMLProgram
2+
TosaToMLProgram.cpp
3+
TosaToMLProgramPass.cpp
4+
5+
ADDITIONAL_HEADER_DIRS
6+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Tosa
7+
${MLIR_MAIN_INCLUDE_DIR}/mlir/IR
8+
9+
DEPENDS
10+
MLIRConversionPassIncGen
11+
12+
LINK_LIBS PUBLIC
13+
MLIRIR
14+
MLIRMLProgramDialect
15+
MLIRPass
16+
MLIRTosaDialect
17+
MLIRTosaTransforms
18+
MLIRSupport
19+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
//===- TosaToMLProgram.cpp - Lowering Tosa to MLProgram Dialect------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// These rewriters lower from the TOSA dialect to the MLProgram dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
14+
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
15+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16+
#include "mlir/IR/IRMapping.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
19+
using namespace mlir;
20+
using namespace tosa;
21+
namespace {
22+
23+
class VariableOpConverter : public OpRewritePattern<tosa::VariableOp> {
24+
public:
25+
using OpRewritePattern<tosa::VariableOp>::OpRewritePattern;
26+
27+
LogicalResult matchAndRewrite(tosa::VariableOp op,
28+
PatternRewriter &rewriter) const final {
29+
auto newVariable = rewriter.create<mlir::ml_program::GlobalOp>(
30+
op.getLoc(), op.getName(), op.getType(), /*is_mutable=*/true,
31+
op.getInitialValueAttr(), /*sym_visibility=*/nullptr);
32+
newVariable.setPrivate();
33+
rewriter.replaceOp(op, newVariable);
34+
return success();
35+
}
36+
};
37+
38+
class VariableWriteOpConverter
39+
: public OpRewritePattern<tosa::VariableWriteOp> {
40+
public:
41+
using OpRewritePattern<tosa::VariableWriteOp>::OpRewritePattern;
42+
43+
LogicalResult matchAndRewrite(tosa::VariableWriteOp op,
44+
PatternRewriter &rewriter) const final {
45+
auto globalSymbolRef =
46+
SymbolRefAttr::get(rewriter.getContext(), op.getName());
47+
auto newVariableWrite = rewriter.create<ml_program::GlobalStoreOp>(
48+
op.getLoc(), globalSymbolRef, op.getValue());
49+
rewriter.replaceOp(op, newVariableWrite);
50+
return success();
51+
}
52+
};
53+
54+
class VariableReadOpConverter : public OpRewritePattern<tosa::VariableReadOp> {
55+
public:
56+
using OpRewritePattern<tosa::VariableReadOp>::OpRewritePattern;
57+
58+
LogicalResult matchAndRewrite(tosa::VariableReadOp op,
59+
PatternRewriter &rewriter) const final {
60+
auto globalSymbolRef =
61+
SymbolRefAttr::get(rewriter.getContext(), op.getName());
62+
auto newVariableRead = rewriter.create<ml_program::GlobalLoadOp>(
63+
op.getLoc(), op.getType(), globalSymbolRef);
64+
rewriter.replaceOp(op, newVariableRead);
65+
66+
return success();
67+
}
68+
};
69+
70+
} // namespace
71+
72+
void mlir::tosa::populateTosaToMLProgramConversionPatterns(
73+
RewritePatternSet *patterns) {
74+
patterns->add<VariableOpConverter, VariableWriteOpConverter,
75+
VariableReadOpConverter>(patterns->getContext());
76+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
//===- TosaToMLProgramPass.cpp - Lowering Tosa to MLProgram Dialect--------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This transformation pass legalizes the TOSA dialect to the MLProgram dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Conversion/TosaToMLProgram/TosaToMLProgram.h"
14+
#include "mlir/Dialect/MLProgram/IR/MLProgram.h"
15+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
16+
#include "mlir/Dialect/Tosa/Transforms/Passes.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Pass/PassManager.h"
19+
#include "mlir/Transforms/DialectConversion.h"
20+
21+
namespace mlir {
22+
#define GEN_PASS_DEF_TOSATOMLPROGRAM
23+
#include "mlir/Conversion/Passes.h.inc"
24+
} // namespace mlir
25+
26+
using namespace mlir;
27+
using namespace tosa;
28+
29+
namespace {
30+
struct TosaToMLProgram : public impl::TosaToMLProgramBase<TosaToMLProgram> {
31+
public:
32+
void runOnOperation() override {
33+
auto *context = &getContext();
34+
auto moduleOp = getOperation();
35+
36+
RewritePatternSet patterns(context);
37+
ConversionTarget target(*context);
38+
target.addIllegalOp<tosa::VariableOp, tosa::VariableReadOp,
39+
tosa::VariableWriteOp>();
40+
target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
41+
42+
mlir::tosa::populateTosaToMLProgramConversionPatterns(&patterns);
43+
44+
if (failed(applyPartialConversion(moduleOp, target, std::move(patterns))))
45+
signalPassFailure();
46+
}
47+
};
48+
} // namespace
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: mlir-opt --tosa-to-mlprogram %s -o -| FileCheck %s
2+
3+
module {
4+
// CHECK: ml_program.global private mutable @var_x(dense<7.000000e+00> : tensor<1xf32>) : tensor<1xf32>
5+
tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32>
6+
func.func @test_stateful_ops(%arg0: tensor<1xf32>) -> (tensor<1xf32>) {
7+
// CHECK: ml_program.global_store @var_x = %arg0 : tensor<1xf32>
8+
tosa.variable.write @var_x, %arg0 : tensor<1xf32>
9+
// CHECK: %[[LOAD:.+]] = ml_program.global_load @var_x : tensor<1xf32>
10+
%0 = tosa.variable.read @var_x : tensor<1xf32>
11+
return %0 : tensor<1xf32>
12+
}
13+
}

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

+25
Original file line numberDiff line numberDiff line change
@@ -3817,6 +3817,7 @@ cc_library(
38173817
":TensorToSPIRV",
38183818
":TosaToArith",
38193819
":TosaToLinalg",
3820+
":TosaToMLProgram",
38203821
":TosaToSCF",
38213822
":TosaToTensor",
38223823
":UBToLLVM",
@@ -11212,6 +11213,30 @@ cc_library(
1121211213
],
1121311214
)
1121411215

11216+
cc_library(
11217+
name = "TosaToMLProgram",
11218+
srcs = glob([
11219+
"lib/Conversion/TosaToMLProgram/*.cpp",
11220+
"lib/Conversion/TosaToMLProgram/*.h",
11221+
]),
11222+
hdrs = glob([
11223+
"include/mlir/Conversion/TosaToMLProgram/*.h",
11224+
]),
11225+
includes = [
11226+
"include",
11227+
"lib/Conversion/TosaToMLProgram",
11228+
],
11229+
deps = [
11230+
":ConversionPassIncGen",
11231+
":FuncDialect",
11232+
":IR",
11233+
":Pass",
11234+
":MLProgramDialect",
11235+
":TosaDialect",
11236+
":Transforms",
11237+
],
11238+
)
11239+
1121511240
cc_library(
1121611241
name = "TosaToSCF",
1121711242
srcs = glob([

0 commit comments

Comments
 (0)