12
12
//
13
13
// ===----------------------------------------------------------------------===//
14
14
15
- #include " mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h"
16
15
#include " mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
17
- #include " mlir/Conversion/GPUToSPIRV/GPUToSPIRVPass.h"
18
16
#include " mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
19
17
#include " mlir/Conversion/LLVMCommon/LoweringOptions.h"
20
18
#include " mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
30
28
#include " mlir/Dialect/MemRef/Transforms/Passes.h"
31
29
#include " mlir/Dialect/SCF/IR/SCF.h"
32
30
#include " mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
33
- #include " mlir/Dialect/SPIRV/IR/SPIRVOps.h"
34
- #include " mlir/Dialect/SPIRV/Transforms/Passes.h"
35
31
#include " mlir/Dialect/Vector/IR/VectorOps.h"
36
32
#include " mlir/ExecutionEngine/JitRunner.h"
37
33
#include " mlir/Pass/Pass.h"
43
39
44
40
using namespace mlir ;
45
41
46
- namespace {
47
- struct VulkanRunnerOptions {
48
- llvm::cl::OptionCategory category{" mlir-vulkan-runner options" };
49
- llvm::cl::opt<bool > spirvWebGPUPrepare{
50
- " vulkan-runner-spirv-webgpu-prepare" ,
51
- llvm::cl::desc (" Run MLIR transforms used when targetting WebGPU" ),
52
- llvm::cl::cat (category)};
53
- };
54
- } // namespace
55
-
56
- static LogicalResult runMLIRPasses (Operation *op,
57
- VulkanRunnerOptions &options) {
42
+ static LogicalResult runMLIRPasses (Operation *op, JitRunnerOptions &) {
58
43
auto module = dyn_cast<ModuleOp>(op);
59
44
if (!module)
60
45
return op->emitOpError (" expected a 'builtin.module' op" );
61
46
PassManager passManager (module.getContext ());
62
47
if (failed (applyPassManagerCLOptions (passManager)))
63
48
return failure ();
64
49
65
- passManager.addPass (createGpuKernelOutliningPass ());
66
- passManager.addPass (memref::createFoldMemRefAliasOpsPass ());
67
-
68
- ConvertToSPIRVPassOptions convertToSPIRVOptions{};
69
- convertToSPIRVOptions.convertGPUModules = true ;
70
- passManager.addPass (createConvertToSPIRVPass (convertToSPIRVOptions));
71
- OpPassManager &modulePM = passManager.nest <spirv::ModuleOp>();
72
- modulePM.addPass (spirv::createSPIRVLowerABIAttributesPass ());
73
- modulePM.addPass (spirv::createSPIRVUpdateVCEPass ());
74
- if (options.spirvWebGPUPrepare )
75
- modulePM.addPass (spirv::createSPIRVWebGPUPreparePass ());
76
-
77
50
passManager.addPass (createConvertGpuLaunchFuncToVulkanLaunchFuncPass ());
78
51
passManager.addPass (createFinalizeMemRefToLLVMConversionPass ());
79
52
passManager.addPass (createConvertVectorToLLVMPass ());
@@ -96,15 +69,8 @@ int main(int argc, char **argv) {
96
69
llvm::InitializeNativeTarget ();
97
70
llvm::InitializeNativeTargetAsmPrinter ();
98
71
99
- // Initialize runner-specific CLI options. These will be parsed and
100
- // initialzied in `JitRunnerMain`.
101
- VulkanRunnerOptions options;
102
- auto runPassesWithOptions = [&options](Operation *op, JitRunnerOptions &) {
103
- return runMLIRPasses (op, options);
104
- };
105
-
106
72
mlir::JitRunnerConfig jitRunnerConfig;
107
- jitRunnerConfig.mlirTransformer = runPassesWithOptions ;
73
+ jitRunnerConfig.mlirTransformer = runMLIRPasses ;
108
74
109
75
mlir::DialectRegistry registry;
110
76
registry.insert <mlir::arith::ArithDialect, mlir::LLVM::LLVMDialect,
0 commit comments