@@ -78,7 +78,8 @@ class SparsificationAndBufferizationPass
78
78
const SparsificationOptions &sparsificationOptions,
79
79
bool createSparseDeallocs, bool enableRuntimeLibrary,
80
80
bool enableBufferInitialization, unsigned vl, bool vla, bool index32,
81
- bool gpu, SparseEmitStrategy emitStrategy)
81
+ bool gpu, SparseEmitStrategy emitStrategy,
82
+ SparseParallelizationStrategy parallelizationStrategy)
82
83
: bufferizationOptions(bufferizationOptions),
83
84
sparsificationOptions(sparsificationOptions),
84
85
createSparseDeallocs(createSparseDeallocs),
@@ -90,6 +91,7 @@ class SparsificationAndBufferizationPass
90
91
enableSIMDIndex32 = index32;
91
92
enableGPULibgen = gpu;
92
93
sparseEmitStrategy = emitStrategy;
94
+ parallelization = parallelizationStrategy;
93
95
}
94
96
95
97
// / Bufferize all dense ops. This assumes that no further analysis is needed
@@ -124,6 +126,9 @@ class SparsificationAndBufferizationPass
124
126
// Overrides the default emit strategy using user-provided value.
125
127
this ->sparsificationOptions .sparseEmitStrategy = sparseEmitStrategy;
126
128
129
+ // Overrides the default parallelization strategy using user-provided value.
130
+ this ->sparsificationOptions .parallelizationStrategy = parallelization;
131
+
127
132
// Run enabling transformations.
128
133
{
129
134
OpPassManager pm (" builtin.module" );
@@ -248,10 +253,12 @@ std::unique_ptr<mlir::Pass> mlir::createSparsificationAndBufferizationPass(
248
253
bool createSparseDeallocs, bool enableRuntimeLibrary,
249
254
bool enableBufferInitialization, unsigned vectorLength,
250
255
bool enableVLAVectorization, bool enableSIMDIndex32, bool enableGPULibgen,
251
- SparseEmitStrategy emitStrategy) {
256
+ SparseEmitStrategy emitStrategy,
257
+ SparseParallelizationStrategy parallelizationStrategy) {
252
258
return std::make_unique<
253
259
mlir::sparse_tensor::SparsificationAndBufferizationPass>(
254
260
bufferizationOptions, sparsificationOptions, createSparseDeallocs,
255
261
enableRuntimeLibrary, enableBufferInitialization, vectorLength,
256
- enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy);
262
+ enableVLAVectorization, enableSIMDIndex32, enableGPULibgen, emitStrategy,
263
+ parallelizationStrategy);
257
264
}
0 commit comments