Skip to content

Commit 8f0c014

Browse files
[mlir][sparse] add parallelization options to mini pipeline (llvm#104233)
1 parent 1293ab3 commit 8f0c014

File tree

5 files changed

+69
-5
lines changed

5 files changed

+69
-5
lines changed

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.h

+2-1
Original file line number | Diff line number | Diff line change
@@ -263,7 +263,8 @@ std::unique_ptr<Pass> createSparsificationAndBufferizationPass(
263263
bool createSparseDeallocs, bool enableRuntimeLibrary,
264264
bool enableBufferInitialization, unsigned vectorLength,
265265
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
266-
SparseEmitStrategy emitStrategy);
266+
SparseEmitStrategy emitStrategy,
267+
SparseParallelizationStrategy parallelizationStrategy);
267268

268269
//===----------------------------------------------------------------------===//
269270
// Sparse Iteration Transform Passes

mlir/include/mlir/Dialect/SparseTensor/Transforms/Passes.td

+17
Original file line number | Diff line number | Diff line change
@@ -496,6 +496,23 @@ def SparsificationAndBufferization : Pass<"sparsification-and-bufferization", "M
496496
"Emit (experimental) loops (with sparse.iterate)."),
497497
clEnumValN(mlir::SparseEmitStrategy::kDebugInterface, "debug-interface",
498498
"Emit non-functional but easy-to-read interfaces to debug."))}]>,
499+
Option<"parallelization", "parallelization-strategy", "mlir::SparseParallelizationStrategy",
500+
"mlir::SparseParallelizationStrategy::kNone",
501+
"Set the parallelization strategy", [{llvm::cl::values(
502+
clEnumValN(mlir::SparseParallelizationStrategy::kNone, "none",
503+
"Turn off sparse parallelization."),
504+
clEnumValN(mlir::SparseParallelizationStrategy::kDenseOuterLoop,
505+
"dense-outer-loop",
506+
"Enable dense outer loop sparse parallelization."),
507+
clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageOuterLoop,
508+
"any-storage-outer-loop",
509+
"Enable sparse parallelization regardless of storage for the outer loop."),
510+
clEnumValN(mlir::SparseParallelizationStrategy::kDenseAnyLoop,
511+
"dense-any-loop",
512+
"Enable dense parallelization for any loop."),
513+
clEnumValN(mlir::SparseParallelizationStrategy::kAnyStorageAnyLoop,
514+
"any-storage-any-loop",
515+
"Enable sparse parallelization for any storage and loop."))}]>,
499516
];
500517
}
501518

mlir/lib/Dialect/SparseTensor/Pipelines/SparseTensorPipelines.cpp

+2-1
Original file line number | Diff line number | Diff line change
@@ -45,7 +45,8 @@ void mlir::sparse_tensor::buildSparsifier(OpPassManager &pm,
4545
/*enableVLAVectorization=*/options.armSVE,
4646
/*enableSIMDIndex32=*/options.force32BitVectorIndices,
4747
options.enableGPULibgen,
48-
options.sparsificationOptions().sparseEmitStrategy));
48+
options.sparsificationOptions().sparseEmitStrategy,
49+
options.sparsificationOptions().parallelizationStrategy));
4950

5051
// Bail-early for test setup.
5152
if (options.testBufferizationAnalysisOnly)

mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp

+10-3
Original file line number | Diff line number | Diff line change
@@ -78,7 +78,8 @@ class SparsificationAndBufferizationPass
7878
const SparsificationOptions &sparsificationOptions,
7979
bool createSparseDeallocs, bool enableRuntimeLibrary,
8080
bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
81-
bool gpu, SparseEmitStrategy emitStrategy)
81+
bool gpu, SparseEmitStrategy emitStrategy,
82+
SparseParallelizationStrategy parallelizationStrategy)
8283
: bufferizationOptions(bufferizationOptions),
8384
sparsificationOptions(sparsificationOptions),
8485
createSparseDeallocs(createSparseDeallocs),
@@ -90,6 +91,7 @@ class SparsificationAndBufferizationPass
9091
enableSIMDIndex32 = index32;
9192
enableGPULibgen = gpu;
9293
sparseEmitStrategy = emitStrategy;
94+
parallelization = parallelizationStrategy;
9395
}
9496

9597
/// Bufferize all dense ops. This assumes that no further analysis is needed
@@ -124,6 +126,9 @@ class SparsificationAndBufferizationPass
124126
// Overrides the default emit strategy using user-provided value.
125127
this->sparsificationOptions.sparseEmitStrategy = sparseEmitStrategy;
126128

129+
// Overrides the default parallelization strategy using user-provided value.
130+
this->sparsificationOptions.parallelizationStrategy = parallelization;
131+
127132
// Run enabling transformations.
128133
{
129134
OpPassManager pm("builtin.module");
@@ -248,10 +253,12 @@ std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
248253
bool createSparseDeallocs, bool enableRuntimeLibrary,
249254
bool enableBufferInitialization, unsigned vectorLength,
250255
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
251-
SparseEmitStrategy emitStrategy) {
256+
SparseEmitStrategy emitStrategy,
257+
SparseParallelizationStrategy parallelizationStrategy) {
252258
return std::make_unique<
253259
mlir::sparse_tensor::SparsificationAndBufferizationPass>(
254260
bufferizationOptions, sparsificationOptions, createSparseDeallocs,
255261
enableRuntimeLibrary, enableBufferInitialization, vectorLength,
256-
enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy);
262+
enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy,
263+
parallelizationStrategy);
257264
}
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,38 @@
1+
// RUN: mlir-opt %s --sparsification-and-bufferization | FileCheck %s --check-prefix=CHECK-NOPARA
2+
// RUN: mlir-opt %s --sparsification-and-bufferization="parallelization-strategy=any-storage-any-loop" | FileCheck %s --check-prefix=CHECK-PARA
3+
4+
// Test to ensure we can pass parallelization flags into
5+
// the mini sparsification and bufferization pipeline.
6+
7+
#SparseMatrix = #sparse_tensor.encoding<{
8+
map = (d0, d1) -> (d0 : compressed, d1 : compressed)
9+
}>
10+
11+
#trait_ss = {
12+
indexing_maps = [
13+
affine_map<(i,j) -> (i,j)>, // A
14+
affine_map<(i,j) -> (i,j)> // X (out)
15+
],
16+
iterator_types = ["parallel", "parallel"],
17+
doc = "X(i,j) = A(i,j) * SCALE"
18+
}
19+
20+
//
21+
// CHECK-NOPARA-LABEL: func.func @scale_ss
22+
// CHECK-NOPARA: scf.for
23+
//
24+
// CHECK-PARA-LABEL: func.func @scale_ss
25+
// CHECK-PARA: scf.parallel
26+
//
27+
func.func @scale_ss(%scale: f32,
28+
%arga: tensor<?x?xf32, #SparseMatrix>,
29+
%argx: tensor<?x?xf32>) -> tensor<?x?xf32> {
30+
%0 = linalg.generic #trait_ss
31+
ins(%arga: tensor<?x?xf32, #SparseMatrix>)
32+
outs(%argx: tensor<?x?xf32>) {
33+
^bb(%a: f32, %x: f32):
34+
%0 = arith.mulf %a, %scale : f32
35+
linalg.yield %0 : f32
36+
} -> tensor<?x?xf32>
37+
return %0 : tensor<?x?xf32>
38+
}

0 commit comments

Comments (0)