@@ -1111,14 +1111,43 @@ LogicalResult GenerateOp::reifyResultShapes(
1111
1111
return success ();
1112
1112
}
1113
1113
1114
+ // / Extract operands and shape from a tensor with dynamic extents.
1115
+ static void operandsAndShape (TensorType resultType,
1116
+ Operation::operand_range dynamicExtents,
1117
+ SmallVectorImpl<Value> &newOperands,
1118
+ SmallVectorImpl<int64_t > &newShape) {
1119
+ auto operandsIt = dynamicExtents.begin ();
1120
+ for (int64_t dim : resultType.getShape ()) {
1121
+ if (!ShapedType::isDynamic (dim)) {
1122
+ newShape.push_back (dim);
1123
+ continue ;
1124
+ }
1125
+ APInt index;
1126
+ if (!matchPattern (*operandsIt, m_ConstantInt (&index))) {
1127
+ newShape.push_back (ShapedType::kDynamic );
1128
+ newOperands.push_back (*operandsIt++);
1129
+ continue ;
1130
+ }
1131
+ newShape.push_back (index.getSExtValue ());
1132
+ operandsIt++;
1133
+ }
1134
+ }
1135
+
1114
1136
LogicalResult GenerateOp::verify () {
1115
1137
// Ensure that the tensor type has as many dynamic dimensions as are
1116
1138
// specified by the operands.
1117
- RankedTensorType resultTy = llvm::cast<RankedTensorType>(getType ());
1118
- if (getNumOperands () != resultTy .getNumDynamicDims ())
1139
+ RankedTensorType resultType = llvm::cast<RankedTensorType>(getType ());
1140
+ if (getNumOperands () != resultType .getNumDynamicDims ())
1119
1141
return emitError (" must have as many index operands as dynamic extents "
1120
1142
" in the result type" );
1121
-
1143
+ // Ensure operands are non-negative.
1144
+ SmallVector<Value> newOperands;
1145
+ SmallVector<int64_t > newShape;
1146
+ operandsAndShape (resultType, getDynamicExtents (), newOperands, newShape);
1147
+ for (int64_t newdim : newShape) {
1148
+ if (newdim < 0 && !ShapedType::isDynamic (newdim))
1149
+ return emitError (" tensor dimensions must be non-negative" );
1150
+ }
1122
1151
return success ();
1123
1152
}
1124
1153
@@ -1176,24 +1205,11 @@ struct StaticTensorGenerate : public OpRewritePattern<GenerateOp> {
1176
1205
if (resultType.hasStaticShape ())
1177
1206
return failure ();
1178
1207
1179
- SmallVector<Value, 4 > newOperands;
1180
- SmallVector<int64_t , 4 > newShape;
1181
- auto operandsIt = tensorFromElements.getDynamicExtents ().begin ();
1182
-
1183
- for (int64_t dim : resultType.getShape ()) {
1184
- if (!ShapedType::isDynamic (dim)) {
1185
- newShape.push_back (dim);
1186
- continue ;
1187
- }
1188
- APInt index;
1189
- if (!matchPattern (*operandsIt, m_ConstantInt (&index))) {
1190
- newShape.push_back (ShapedType::kDynamic );
1191
- newOperands.push_back (*operandsIt++);
1192
- continue ;
1193
- }
1194
- newShape.push_back (index.getSExtValue ());
1195
- operandsIt++;
1196
- }
1208
+ Operation::operand_range dynamicExtents =
1209
+ tensorFromElements.getDynamicExtents ();
1210
+ SmallVector<Value> newOperands;
1211
+ SmallVector<int64_t > newShape;
1212
+ operandsAndShape (resultType, dynamicExtents, newOperands, newShape);
1197
1213
1198
1214
if (newOperands.size () == tensorFromElements.getDynamicExtents ().size ())
1199
1215
return failure ();
0 commit comments