Skip to content

Commit d26eb82

Browse files
committed
[mlir][bufferization] DeallocOp canonicalizer removing memrefs that are never deallocated
This simplifies the op and avoids unnecessary alias checks introduced during the lowering to memref. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D156807
1 parent 6676027 commit d26eb82

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -866,11 +866,60 @@ struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
866866
}
867867
};
868868

869+
/// Removes memrefs from the deallocation list if their associated condition is
870+
/// always 'false'.
871+
///
872+
/// Example:
873+
/// ```
874+
/// %0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
875+
/// if (%arg2, %false)
876+
/// ```
877+
/// becomes
878+
/// ```
879+
/// %0 = bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
880+
/// ```
881+
struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
882+
using OpRewritePattern<DeallocOp>::OpRewritePattern;
883+
884+
LogicalResult matchAndRewrite(DeallocOp deallocOp,
885+
PatternRewriter &rewriter) const override {
886+
SmallVector<Value> newMemrefs, newConditions;
887+
SmallVector<Value> replacements;
888+
889+
for (auto [res, memref, cond] :
890+
llvm::zip(deallocOp.getUpdatedConditions(), deallocOp.getMemrefs(),
891+
deallocOp.getConditions())) {
892+
if (matchPattern(cond, m_Zero())) {
893+
replacements.push_back(cond);
894+
continue;
895+
}
896+
newMemrefs.push_back(memref);
897+
newConditions.push_back(cond);
898+
replacements.push_back({});
899+
}
900+
901+
if (newMemrefs.size() == deallocOp.getMemrefs().size())
902+
return failure();
903+
904+
auto newDeallocOp = rewriter.create<DeallocOp>(
905+
deallocOp.getLoc(), newMemrefs, newConditions, deallocOp.getRetained());
906+
unsigned i = 0;
907+
for (auto &repl : replacements)
908+
if (!repl)
909+
repl = newDeallocOp.getResult(i++);
910+
911+
rewriter.replaceOp(deallocOp, replacements);
912+
return success();
913+
}
914+
};
915+
869916
} // anonymous namespace
870917

871918
void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
872919
MLIRContext *context) {
873-
results.add<DeallocRemoveDuplicates, EraseEmptyDealloc>(context);
920+
results
921+
.add<DeallocRemoveDuplicates, EraseEmptyDealloc, EraseAlwaysFalseDealloc>(
922+
context);
874923
}
875924

876925
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,17 @@ func.func @dealloc_canonicalize_retained_and_deallocated(%arg0: memref<2xi32>, %
309309
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: i1, [[ARG2:%.+]]: memref<2xi32>)
310310
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG2]] : memref<2xi32>) if ([[ARG1]]) retain ([[ARG0]] : memref<2xi32>)
311311
// CHECK-NEXT: return [[ARG1]], [[ARG1]], [[V0]] :
312+
313+
// -----
314+
315+
func.func @dealloc_always_false_condition(%arg0: memref<2xi32>, %arg1: memref<2xi32>, %arg2: i1) -> (i1, i1) {
316+
%false = arith.constant false
317+
%0:2 = bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>) if (%false, %arg2)
318+
return %0#0, %0#1 : i1, i1
319+
}
320+
321+
// CHECK-LABEL: func @dealloc_always_false_condition
322+
// CHECK-SAME: ([[ARG0:%.+]]: memref<2xi32>, [[ARG1:%.+]]: memref<2xi32>, [[ARG2:%.+]]: i1)
323+
// CHECK-NEXT: [[FALSE:%.+]] = arith.constant false
324+
// CHECK-NEXT: [[V0:%.+]] = bufferization.dealloc ([[ARG1]] : {{.*}}) if ([[ARG2]])
325+
// CHECK-NEXT: return [[FALSE]], [[V0]] :

0 commit comments

Comments
 (0)