Skip to content

Commit 98eead8

Browse files
committed
[mlir][Value] Add v.getDefiningOp<OpTy>()
Summary: This makes a common pattern of `dyn_cast_or_null<OpTy>(v.getDefiningOp())` more concise. Differential Revision: https://reviews.llvm.org/D79681
1 parent 51e6fc4 commit 98eead8

File tree

26 files changed

+56
-66
lines changed

26 files changed

+56
-66
lines changed

mlir/docs/Tutorials/Toy/Ch-3.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
9191
mlir::PatternRewriter &rewriter) const override {
9292
// Look through the input of the current transpose.
9393
mlir::Value transposeInput = op.getOperand();
94-
TransposeOp transposeInputOp =
95-
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
94+
TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
9695
9796
// Input defined by another transpose? If not, no match.
9897
if (!transposeInputOp)

mlir/examples/toy/Ch3/mlir/ToyCombine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
4040
mlir::PatternRewriter &rewriter) const override {
4141
// Look through the input of the current transpose.
4242
mlir::Value transposeInput = op.getOperand();
43-
TransposeOp transposeInputOp =
44-
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
43+
TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
4544

4645
// Input defined by another transpose? If not, no match.
4746
if (!transposeInputOp)

mlir/examples/toy/Ch4/mlir/ToyCombine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
4545
mlir::PatternRewriter &rewriter) const override {
4646
// Look through the input of the current transpose.
4747
mlir::Value transposeInput = op.getOperand();
48-
TransposeOp transposeInputOp =
49-
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
48+
TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
5049

5150
// Input defined by another transpose? If not, no match.
5251
if (!transposeInputOp)

mlir/examples/toy/Ch5/mlir/ToyCombine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
4545
mlir::PatternRewriter &rewriter) const override {
4646
// Look through the input of the current transpose.
4747
mlir::Value transposeInput = op.getOperand();
48-
TransposeOp transposeInputOp =
49-
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
48+
TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
5049

5150
// Input defined by another transpose? If not, no match.
5251
if (!transposeInputOp)

mlir/examples/toy/Ch6/mlir/ToyCombine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
4545
mlir::PatternRewriter &rewriter) const override {
4646
// Look through the input of the current transpose.
4747
mlir::Value transposeInput = op.getOperand();
48-
TransposeOp transposeInputOp =
49-
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
48+
TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
5049

5150
// Input defined by another transpose? If not, no match.
5251
if (!transposeInputOp)

mlir/examples/toy/Ch7/mlir/ToyCombine.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,7 @@ struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
6363
mlir::PatternRewriter &rewriter) const override {
6464
// Look through the input of the current transpose.
6565
mlir::Value transposeInput = op.getOperand();
66-
TransposeOp transposeInputOp =
67-
llvm::dyn_cast_or_null<TransposeOp>(transposeInput.getDefiningOp());
66+
TransposeOp transposeInputOp = transposeInput.getDefiningOp<TransposeOp>();
6867

6968
// Input defined by another transpose? If not, no match.
7069
if (!transposeInputOp)

mlir/include/mlir/IR/Value.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,13 @@ class Value {
116116
/// defines it.
117117
Operation *getDefiningOp() const;
118118

119+
/// If this value is the result of an operation of type OpTy, return the
120+
/// operation that defines it.
121+
template <typename OpTy>
122+
OpTy getDefiningOp() const {
123+
return llvm::dyn_cast_or_null<OpTy>(getDefiningOp());
124+
}
125+
119126
/// If this value is the result of an operation, use it as a location,
120127
/// otherwise return an unknown location.
121128
Location getLoc() const;

mlir/lib/Analysis/AffineAnalysis.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ addMemRefAccessConstraints(const AffineValueMap &srcAccessMap,
453453
auto symbol = operands[i];
454454
assert(isValidSymbol(symbol));
455455
// Check if the symbol is a constant.
456-
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(symbol.getDefiningOp()))
456+
if (auto cOp = symbol.getDefiningOp<ConstantIndexOp>())
457457
dependenceDomain->setIdToConstant(valuePosMap.getSymPos(symbol),
458458
cOp.getValue());
459459
}

mlir/lib/Analysis/AffineStructures.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,7 @@ void FlatAffineConstraints::addInductionVarOrTerminalSymbol(Value id) {
665665
// Add top level symbol.
666666
addSymbolId(getNumSymbolIds(), id);
667667
// Check if the symbol is a constant.
668-
if (auto constOp = dyn_cast_or_null<ConstantIndexOp>(id.getDefiningOp()))
668+
if (auto constOp = id.getDefiningOp<ConstantIndexOp>())
669669
setIdToConstant(id, constOp.getValue());
670670
}
671671

mlir/lib/Analysis/Utils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ ComputationSliceState::getAsConstraints(FlatAffineConstraints *cst) {
6464
assert(cst->containsId(value) && "value expected to be present");
6565
if (isValidSymbol(value)) {
6666
// Check if the symbol is a constant.
67-
if (auto cOp = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
67+
if (auto cOp = value.getDefiningOp<ConstantIndexOp>())
6868
cst->setIdToConstant(value, cOp.getValue());
6969
} else if (auto loop = getForInductionVarOwner(value)) {
7070
if (failed(cst->addAffineForOpDomain(loop)))

mlir/lib/Conversion/LoopsToGPU/LoopsToGPU.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ struct LoopToGpuConverter {
219219

220220
// Return true if the value is obviously a constant "one".
221221
static bool isConstantOne(Value value) {
222-
if (auto def = dyn_cast_or_null<ConstantIndexOp>(value.getDefiningOp()))
222+
if (auto def = value.getDefiningOp<ConstantIndexOp>())
223223
return def.getValue() == 1;
224224
return false;
225225
}
@@ -505,11 +505,11 @@ struct ParallelToGpuLaunchLowering : public OpRewritePattern<ParallelOp> {
505505
/// `upperBound`.
506506
static Value deriveStaticUpperBound(Value upperBound,
507507
PatternRewriter &rewriter) {
508-
if (auto op = dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp())) {
508+
if (auto op = upperBound.getDefiningOp<ConstantIndexOp>()) {
509509
return op;
510510
}
511511

512-
if (auto minOp = dyn_cast_or_null<AffineMinOp>(upperBound.getDefiningOp())) {
512+
if (auto minOp = upperBound.getDefiningOp<AffineMinOp>()) {
513513
for (const AffineExpr &result : minOp.map().getResults()) {
514514
if (auto constExpr = result.dyn_cast<AffineConstantExpr>()) {
515515
return rewriter.create<ConstantIndexOp>(minOp.getLoc(),
@@ -518,7 +518,7 @@ static Value deriveStaticUpperBound(Value upperBound,
518518
}
519519
}
520520

521-
if (auto multiplyOp = dyn_cast_or_null<MulIOp>(upperBound.getDefiningOp())) {
521+
if (auto multiplyOp = upperBound.getDefiningOp<MulIOp>()) {
522522
if (auto lhs = dyn_cast_or_null<ConstantIndexOp>(
523523
deriveStaticUpperBound(multiplyOp.getOperand(0), rewriter)
524524
.getDefiningOp()))
@@ -607,7 +607,7 @@ static LogicalResult processParallelLoop(
607607
launchIndependent](Value val) -> Value {
608608
if (launchIndependent(val))
609609
return val;
610-
if (ConstantOp constOp = dyn_cast_or_null<ConstantOp>(val.getDefiningOp()))
610+
if (ConstantOp constOp = val.getDefiningOp<ConstantOp>())
611611
return rewriter.create<ConstantOp>(constOp.getLoc(), constOp.getValue());
612612
return {};
613613
};

mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ resolveSourceIndices(Location loc, PatternRewriter &rewriter,
110110
LogicalResult
111111
LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
112112
PatternRewriter &rewriter) const {
113-
auto subViewOp = dyn_cast_or_null<SubViewOp>(loadOp.memref().getDefiningOp());
113+
auto subViewOp = loadOp.memref().getDefiningOp<SubViewOp>();
114114
if (!subViewOp) {
115115
return failure();
116116
}
@@ -131,8 +131,7 @@ LoadOpOfSubViewFolder::matchAndRewrite(LoadOp loadOp,
131131
LogicalResult
132132
StoreOpOfSubViewFolder::matchAndRewrite(StoreOp storeOp,
133133
PatternRewriter &rewriter) const {
134-
auto subViewOp =
135-
dyn_cast_or_null<SubViewOp>(storeOp.memref().getDefiningOp());
134+
auto subViewOp = storeOp.memref().getDefiningOp<SubViewOp>();
136135
if (!subViewOp) {
137136
return failure();
138137
}

mlir/lib/Dialect/Affine/EDSC/Builders.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims,
9393
unsigned &numSymbols) {
9494
AffineExpr d;
9595
Value resultVal = nullptr;
96-
if (auto constant = dyn_cast_or_null<ConstantIndexOp>(val.getDefiningOp())) {
96+
if (auto constant = val.getDefiningOp<ConstantIndexOp>()) {
9797
d = getAffineConstantExpr(constant.getValue(), context);
9898
} else if (isValidSymbol(val) && !isValidDim(val)) {
9999
d = getAffineSymbolExpr(numSymbols++, context);

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -591,7 +591,7 @@ AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
591591
// 2. Compose AffineApplyOps and dispatch dims or symbols.
592592
for (unsigned i = 0, e = operands.size(); i < e; ++i) {
593593
auto t = operands[i];
594-
auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
594+
auto affineApply = t.getDefiningOp<AffineApplyOp>();
595595
if (affineApply) {
596596
// a. Compose affine.apply operations.
597597
LLVM_DEBUG(affineApply.getOperation()->print(
@@ -912,7 +912,7 @@ void AffineApplyOp::getCanonicalizationPatterns(
912912
static LogicalResult foldMemRefCast(Operation *op) {
913913
bool folded = false;
914914
for (OpOperand &operand : op->getOpOperands()) {
915-
auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
915+
auto cast = operand.get().getDefiningOp<MemRefCastOp>();
916916
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
917917
operand.set(cast.getOperand());
918918
folded = true;

mlir/lib/Dialect/Affine/Transforms/SuperVectorize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -965,7 +965,7 @@ static Value vectorizeOperand(Value operand, Operation *op,
965965
return nullptr;
966966
}
967967
// 3. vectorize constant.
968-
if (auto constant = dyn_cast_or_null<ConstantOp>(operand.getDefiningOp())) {
968+
if (auto constant = operand.getDefiningOp<ConstantOp>()) {
969969
return vectorizeConstant(
970970
op, constant,
971971
VectorType::get(state->strategy->vectorSizes, operand.getType()));

mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -425,19 +425,18 @@ static LogicalResult verify(LandingpadOp op) {
425425
} else {
426426
// catch - global addresses only.
427427
// Bitcast ops should have global addresses as their args.
428-
if (auto bcOp = dyn_cast_or_null<BitcastOp>(value.getDefiningOp())) {
429-
if (auto addrOp =
430-
dyn_cast_or_null<AddressOfOp>(bcOp.arg().getDefiningOp()))
428+
if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
429+
if (auto addrOp = bcOp.arg().getDefiningOp<AddressOfOp>())
431430
continue;
432431
return op.emitError("constant clauses expected")
433432
.attachNote(bcOp.getLoc())
434433
<< "global addresses expected as operand to "
435434
"bitcast used in clauses for landingpad";
436435
}
437436
// NullOp and AddressOfOp allowed
438-
if (dyn_cast_or_null<NullOp>(value.getDefiningOp()))
437+
if (value.getDefiningOp<NullOp>())
439438
continue;
440-
if (dyn_cast_or_null<AddressOfOp>(value.getDefiningOp()))
439+
if (value.getDefiningOp<AddressOfOp>())
441440
continue;
442441
return op.emitError("clause #")
443442
<< idx << " is not a known constant - null, addressof, bitcast";

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op);
5252
static LogicalResult foldMemRefCast(Operation *op) {
5353
bool folded = false;
5454
for (OpOperand &operand : op->getOpOperands()) {
55-
auto castOp = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
55+
auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
5656
if (castOp && canFoldIntoConsumerOp(castOp)) {
5757
operand.set(castOp.getOperand());
5858
folded = true;

mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -319,8 +319,8 @@ fuseProducerOfDep(OpBuilder &b, LinalgOp consumer, unsigned consumerIdx,
319319

320320
// Must be a subview or a slice to guarantee there are loops we can fuse
321321
// into.
322-
auto subView = dyn_cast_or_null<SubViewOp>(consumedView.getDefiningOp());
323-
auto slice = dyn_cast_or_null<SliceOp>(consumedView.getDefiningOp());
322+
auto subView = consumedView.getDefiningOp<SubViewOp>();
323+
auto slice = consumedView.getDefiningOp<SliceOp>();
324324
if (!subView && !slice) {
325325
LLVM_DEBUG(dbgs() << "\nNot fusable (not a subview or slice)");
326326
continue;

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
8888
/// Otherwise return size.
8989
static Value extractSmallestConstantBoundingSize(OpBuilder &b, Location loc,
9090
Value size) {
91-
auto affineMinOp = dyn_cast_or_null<AffineMinOp>(size.getDefiningOp());
91+
auto affineMinOp = size.getDefiningOp<AffineMinOp>();
9292
if (!affineMinOp)
9393
return size;
9494
int64_t minConst = std::numeric_limits<int64_t>::max();
@@ -112,7 +112,7 @@ static Value allocBuffer(Type elementType, Value size, bool dynamicBuffers,
112112
alignment_attr =
113113
IntegerAttr::get(IntegerType::get(64, ctx), alignment.getValue());
114114
if (!dynamicBuffers)
115-
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp()))
115+
if (auto cst = size.getDefiningOp<ConstantIndexOp>())
116116
return std_alloc(
117117
MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)),
118118
ValueRange{}, alignment_attr);

mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp,
287287
// accesses, unless we statically know the subview size divides the view
288288
// size evenly.
289289
int64_t viewSize = viewType.getDimSize(r);
290-
auto sizeCst = dyn_cast_or_null<ConstantIndexOp>(size.getDefiningOp());
290+
auto sizeCst = size.getDefiningOp<ConstantIndexOp>();
291291
if (ShapedType::isDynamic(viewSize) || !sizeCst ||
292292
(viewSize % sizeCst.getValue()) != 0) {
293293
// Compute min(size, dim - offset) to avoid out-of-bounds accesses.

mlir/lib/Dialect/Quant/IR/QuantOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ QuantizationDialect::QuantizationDialect(MLIRContext *context)
3636
OpFoldResult StorageCastOp::fold(ArrayRef<Attribute> operands) {
3737
// Matches x -> [scast -> scast] -> y, replacing the second scast with the
3838
// value of x if the casts invert each other.
39-
auto srcScastOp = dyn_cast_or_null<StorageCastOp>(arg().getDefiningOp());
39+
auto srcScastOp = arg().getDefiningOp<StorageCastOp>();
4040
if (!srcScastOp || srcScastOp.arg().getType() != getType())
4141
return OpFoldResult();
4242
return srcScastOp.arg();

mlir/lib/Dialect/SCF/SCF.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
5555
}
5656

5757
static LogicalResult verify(ForOp op) {
58-
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(op.step().getDefiningOp()))
58+
if (auto cst = op.step().getDefiningOp<ConstantIndexOp>())
5959
if (cst.getValue() <= 0)
6060
return op.emitOpError("constant step operand must be positive");
6161

@@ -403,7 +403,7 @@ static LogicalResult verify(ParallelOp op) {
403403

404404
// Check whether all constant step values are positive.
405405
for (Value stepValue : stepValues)
406-
if (auto cst = dyn_cast_or_null<ConstantIndexOp>(stepValue.getDefiningOp()))
406+
if (auto cst = stepValue.getDefiningOp<ConstantIndexOp>())
407407
if (cst.getValue() <= 0)
408408
return op.emitOpError("constant step operand must be positive");
409409

mlir/lib/Dialect/SCF/Transforms/ParallelLoopSpecialization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ static void specializeLoopForUnrolling(ParallelOp op) {
2929
SmallVector<int64_t, 2> constantIndices;
3030
constantIndices.reserve(op.upperBound().size());
3131
for (auto bound : op.upperBound()) {
32-
auto minOp = dyn_cast_or_null<AffineMinOp>(bound.getDefiningOp());
32+
auto minOp = bound.getDefiningOp<AffineMinOp>();
3333
if (!minOp)
3434
return;
3535
int64_t minConstant = std::numeric_limits<int64_t>::max();

mlir/lib/Dialect/StandardOps/IR/Ops.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
209209
static LogicalResult foldMemRefCast(Operation *op) {
210210
bool folded = false;
211211
for (OpOperand &operand : op->getOpOperands()) {
212-
auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
212+
auto cast = operand.get().getDefiningOp<MemRefCastOp>();
213213
if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
214214
operand.set(cast.getOperand());
215215
folded = true;
@@ -1696,7 +1696,7 @@ bool IndexCastOp::areCastCompatible(Type a, Type b) {
16961696

16971697
OpFoldResult IndexCastOp::fold(ArrayRef<Attribute> cstOperands) {
16981698
// Fold IndexCast(IndexCast(x)) -> x
1699-
auto cast = dyn_cast_or_null<IndexCastOp>(getOperand().getDefiningOp());
1699+
auto cast = getOperand().getDefiningOp<IndexCastOp>();
17001700
if (cast && cast.getOperand().getType() == getType())
17011701
return cast.getOperand();
17021702

@@ -2617,8 +2617,7 @@ OpFoldResult SubViewOp::fold(ArrayRef<Attribute>) {
26172617
auto folds = [](Operation *op) {
26182618
bool folded = false;
26192619
for (OpOperand &operand : op->getOpOperands()) {
2620-
auto castOp =
2621-
dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
2620+
auto castOp = operand.get().getDefiningOp<MemRefCastOp>();
26222621
if (castOp && canFoldIntoConsumerOp(castOp)) {
26232622
operand.set(castOp.getOperand());
26242623
folded = true;
@@ -2890,12 +2889,11 @@ struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
28902889
LogicalResult matchAndRewrite(ViewOp viewOp,
28912890
PatternRewriter &rewriter) const override {
28922891
Value memrefOperand = viewOp.getOperand(0);
2893-
MemRefCastOp memrefCastOp =
2894-
dyn_cast_or_null<MemRefCastOp>(memrefOperand.getDefiningOp());
2892+
MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp<MemRefCastOp>();
28952893
if (!memrefCastOp)
28962894
return failure();
28972895
Value allocOperand = memrefCastOp.getOperand();
2898-
AllocOp allocOp = dyn_cast_or_null<AllocOp>(allocOperand.getDefiningOp());
2896+
AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
28992897
if (!allocOp)
29002898
return failure();
29012899
rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1611,7 +1611,7 @@ class TransposeFolder final : public OpRewritePattern<TransposeOp> {
16111611

16121612
// Return if the input of 'transposeOp' is not defined by another transpose.
16131613
TransposeOp parentTransposeOp =
1614-
dyn_cast_or_null<TransposeOp>(transposeOp.vector().getDefiningOp());
1614+
transposeOp.vector().getDefiningOp<TransposeOp>();
16151615
if (!parentTransposeOp)
16161616
return failure();
16171617

@@ -1684,7 +1684,7 @@ OpFoldResult TupleGetOp::fold(ArrayRef<Attribute> operands) {
16841684
// into:
16851685
// %t = vector.tuple .., %e_i, .. // one less use
16861686
// %x = %e_i
1687-
if (auto tupleOp = dyn_cast_or_null<TupleOp>(getOperand().getDefiningOp()))
1687+
if (auto tupleOp = getOperand().getDefiningOp<TupleOp>())
16881688
return tupleOp.getOperand(getIndex());
16891689
return {};
16901690
}

0 commit comments

Comments
 (0)