Commit 64bd5bb

rikhuijzer authored and ftynse committed

[mlir] Avoid tensor canonicalizer crash on negative dimensions

Fixes #59703.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D151611

1 parent c76a3e7

2 files changed: 51 additions & 21 deletions

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
37 additions & 21 deletions

@@ -1111,14 +1111,43 @@ LogicalResult GenerateOp::reifyResultShapes(
   return success();
 }
 
+/// Extract operands and shape from a tensor with dynamic extents.
+static void operandsAndShape(TensorType resultType,
+                             Operation::operand_range dynamicExtents,
+                             SmallVectorImpl<Value> &newOperands,
+                             SmallVectorImpl<int64_t> &newShape) {
+  auto operandsIt = dynamicExtents.begin();
+  for (int64_t dim : resultType.getShape()) {
+    if (!ShapedType::isDynamic(dim)) {
+      newShape.push_back(dim);
+      continue;
+    }
+    APInt index;
+    if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
+      newShape.push_back(ShapedType::kDynamic);
+      newOperands.push_back(*operandsIt++);
+      continue;
+    }
+    newShape.push_back(index.getSExtValue());
+    operandsIt++;
+  }
+}
+
 LogicalResult GenerateOp::verify() {
   // Ensure that the tensor type has as many dynamic dimensions as are
   // specified by the operands.
-  RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType());
-  if (getNumOperands() != resultTy.getNumDynamicDims())
+  RankedTensorType resultType = llvm::cast<RankedTensorType>(getType());
+  if (getNumOperands() != resultType.getNumDynamicDims())
     return emitError("must have as many index operands as dynamic extents "
                      "in the result type");
-
+  // Ensure operands are non-negative.
+  SmallVector<Value> newOperands;
+  SmallVector<int64_t> newShape;
+  operandsAndShape(resultType, getDynamicExtents(), newOperands, newShape);
+  for (int64_t newdim : newShape) {
+    if (newdim < 0 && !ShapedType::isDynamic(newdim))
+      return emitError("tensor dimensions must be non-negative");
+  }
   return success();
 }
 
@@ -1176,24 +1205,11 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
     if (resultType.hasStaticShape())
       return failure();
 
-    SmallVector<Value, 4> newOperands;
-    SmallVector<int64_t, 4> newShape;
-    auto operandsIt = tensorFromElements.getDynamicExtents().begin();
-
-    for (int64_t dim : resultType.getShape()) {
-      if (!ShapedType::isDynamic(dim)) {
-        newShape.push_back(dim);
-        continue;
-      }
-      APInt index;
-      if (!matchPattern(*operandsIt, m_ConstantInt(&index))) {
-        newShape.push_back(ShapedType::kDynamic);
-        newOperands.push_back(*operandsIt++);
-        continue;
-      }
-      newShape.push_back(index.getSExtValue());
-      operandsIt++;
-    }
+    Operation::operand_range dynamicExtents =
+        tensorFromElements.getDynamicExtents();
+    SmallVector<Value> newOperands;
+    SmallVector<int64_t> newShape;
+    operandsAndShape(resultType, dynamicExtents, newOperands, newShape);
 
     if (newOperands.size() == tensorFromElements.getDynamicExtents().size())
       return failure();

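The change factors the constant-folding loop that StaticTensorGenerate previously carried inline into the shared helper operandsAndShape, so the verifier and the canonicalization pattern now agree on how constant extents are folded into the shape. Below is a minimal standalone C++ sketch of that logic, not the MLIR API: std::optional<int64_t> stands in for a dynamic-extent operand (holding a value where matchPattern/m_ConstantInt would succeed, std::nullopt otherwise), and kDynamic mirrors ShapedType::kDynamic, which at this revision is a negative sentinel; that is why the new verifier check guards with !ShapedType::isDynamic(newdim) before rejecting negative dimensions. All names here are illustrative.

    // Standalone sketch (assumed stand-ins, not MLIR API) of the folding
    // logic that operandsAndShape factors out.
    #include <cstdint>
    #include <limits>
    #include <optional>
    #include <vector>

    // Stand-in for ShapedType::kDynamic; note the sentinel is negative.
    constexpr int64_t kDynamic = std::numeric_limits<int64_t>::min();

    // Stand-in for a dynamic-extent operand: a constant when known,
    // std::nullopt when the extent is not a compile-time constant.
    using Extent = std::optional<int64_t>;

    static void operandsAndShapeSketch(const std::vector<int64_t> &resultShape,
                                       const std::vector<Extent> &dynamicExtents,
                                       std::vector<Extent> &newOperands,
                                       std::vector<int64_t> &newShape) {
      auto operandsIt = dynamicExtents.begin();
      for (int64_t dim : resultShape) {
        if (dim != kDynamic) {          // Static dimension: keep unchanged.
          newShape.push_back(dim);
          continue;
        }
        if (!operandsIt->has_value()) { // Non-constant extent: stays dynamic.
          newShape.push_back(kDynamic);
          newOperands.push_back(*operandsIt++);
          continue;
        }
        newShape.push_back(**operandsIt); // Constant extent: fold into shape.
        ++operandsIt;
      }
    }

A folded shape can therefore contain a negative entry whenever a constant extent is negative, which is exactly the case the new check in GenerateOp::verify() now rejects instead of letting the canonicalizer build an invalid tensor type.
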
mlir/test/Dialect/Tensor/invalid.mlir
14 additions & 0 deletions

@@ -112,6 +112,20 @@ func.func @tensor.generate(%m : index, %n : index)
   } : tensor<?x3x?xf32>
   return %tnsr : tensor<?x3x?xf32>
 }
+
+// -----
+
+func.func @generate_negative_size() -> tensor<?x8xi32> {
+  %cst = arith.constant 0 : i32
+  %size = index.constant -128
+  // expected-error@+1 {{tensor dimensions must be non-negative}}
+  %tensor = tensor.generate %size {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : i32
+  } : tensor<?x8xi32>
+  return %tensor : tensor<?x8xi32>
+}
+
 // -----
 
 func.func @tensor.reshape_element_type_mismatch(
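
As a quick driver for the sketch above (same assumptions and hypothetical names, appended to that snippet), here is the new test case replayed in stand-in form: a tensor<?x8xi32> whose single dynamic extent is bound to the constant -128. The in-tree test itself runs under mlir-opt with -split-input-file -verify-diagnostics, per the file's RUN line.

    // Driver for the sketch above, mirroring the new invalid.mlir case.
    #include <cstdio>

    int main() {
      const std::vector<int64_t> resultShape = {kDynamic, 8};
      const std::vector<Extent> dynamicExtents = {Extent{-128}};
      std::vector<Extent> newOperands;
      std::vector<int64_t> newShape;
      operandsAndShapeSketch(resultShape, dynamicExtents, newOperands, newShape);
      // The check GenerateOp::verify() now performs, in stand-in form.
      // kDynamic is itself negative, so it must be excluded explicitly.
      for (int64_t dim : newShape)
        if (dim < 0 && dim != kDynamic)
          std::printf("error: tensor dimensions must be non-negative\n");
      return 0;
    }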
