@@ -142,6 +142,19 @@ TypeDescType parseTypeDesc(mlir::DialectAsmParser &parser, mlir::Location loc) {
142
142
return parseTypeSingleton<TypeDescType>(parser, loc);
143
143
}
144
144
145
+ // `vector` `<` len `:` type `>`
146
+ fir::VectorType parseVector (mlir::DialectAsmParser &parser,
147
+ mlir::Location loc) {
148
+ int64_t len = 0 ;
149
+ mlir::Type eleTy;
150
+ if (parser.parseLess () || parser.parseInteger (len) || parser.parseColon () ||
151
+ parser.parseType (eleTy) || parser.parseGreater ()) {
152
+ parser.emitError (parser.getNameLoc (), " invalid vector type" );
153
+ return {};
154
+ }
155
+ return fir::VectorType::get (len, eleTy);
156
+ }
157
+
145
158
// `void`
146
159
mlir::Type parseVoid (mlir::DialectAsmParser &parser) {
147
160
return parser.getBuilder ().getNoneType ();
@@ -346,6 +359,8 @@ mlir::Type fir::parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser) {
346
359
return parseDerived (parser, loc);
347
360
if (typeNameLit == " void" )
348
361
return parseVoid (parser);
362
+ if (typeNameLit == " vector" )
363
+ return parseVector (parser, loc);
349
364
350
365
parser.emitError (parser.getNameLoc (), " unknown FIR type " + typeNameLit);
351
366
return {};
@@ -790,6 +805,39 @@ struct TypeDescTypeStorage : public mlir::TypeStorage {
790
805
explicit TypeDescTypeStorage (mlir::Type ofTy) : ofTy{ofTy} {}
791
806
};
792
807
808
+ // / Vector type storage
809
+ struct VectorTypeStorage : public mlir ::TypeStorage {
810
+ using KeyTy = std::tuple<uint64_t , mlir::Type>;
811
+
812
+ static unsigned hashKey (const KeyTy &key) {
813
+ return llvm::hash_combine (std::get<uint64_t >(key),
814
+ std::get<mlir::Type>(key));
815
+ }
816
+
817
+ bool operator ==(const KeyTy &key) const {
818
+ return key == KeyTy{getLen (), getEleTy ()};
819
+ }
820
+
821
+ static VectorTypeStorage *construct (mlir::TypeStorageAllocator &allocator,
822
+ const KeyTy &key) {
823
+ auto *storage = allocator.allocate <VectorTypeStorage>();
824
+ return new (storage)
825
+ VectorTypeStorage{std::get<uint64_t >(key), std::get<mlir::Type>(key)};
826
+ }
827
+
828
+ uint64_t getLen () const { return len; }
829
+ mlir::Type getEleTy () const { return eleTy; }
830
+
831
+ protected:
832
+ uint64_t len;
833
+ mlir::Type eleTy;
834
+
835
+ private:
836
+ VectorTypeStorage () = delete ;
837
+ explicit VectorTypeStorage (uint64_t len, mlir::Type eleTy)
838
+ : len{len}, eleTy{eleTy} {}
839
+ };
840
+
793
841
} // namespace detail
794
842
795
843
template <typename A, typename B>
@@ -1069,12 +1117,34 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
1069
1117
eleTy.isa <BoxProcType>() || eleTy.isa <FieldType>() ||
1070
1118
eleTy.isa <LenType>() || eleTy.isa <HeapType>() ||
1071
1119
eleTy.isa <PointerType>() || eleTy.isa <ReferenceType>() ||
1072
- eleTy.isa <TypeDescType>() || eleTy.isa <SequenceType>())
1120
+ eleTy.isa <TypeDescType>() || eleTy.isa <fir::VectorType>() ||
1121
+ eleTy.isa <SequenceType>())
1073
1122
return mlir::emitError (loc, " cannot build an array of this element type: " )
1074
1123
<< eleTy << ' \n ' ;
1075
1124
return mlir::success ();
1076
1125
}
1077
1126
1127
+ // ===----------------------------------------------------------------------===//
1128
+ // Vector type
1129
+ // ===----------------------------------------------------------------------===//
1130
+
1131
+ fir::VectorType fir::VectorType::get (uint64_t len, mlir::Type eleTy) {
1132
+ return Base::get (eleTy.getContext (), len, eleTy);
1133
+ }
1134
+
1135
+ mlir::Type fir::VectorType::getEleTy () const { return getImpl ()->getEleTy (); }
1136
+
1137
+ uint64_t fir::VectorType::getLen () const { return getImpl ()->getLen (); }
1138
+
1139
+ mlir::LogicalResult
1140
+ fir::VectorType::verifyConstructionInvariants (mlir::Location loc, uint64_t len,
1141
+ mlir::Type eleTy) {
1142
+ if (!(fir::isa_real (eleTy) || fir::isa_integer (eleTy)))
1143
+ return mlir::emitError (loc, " cannot build a vector of type " )
1144
+ << eleTy << ' \n ' ;
1145
+ return mlir::success ();
1146
+ }
1147
+
1078
1148
// compare if two shapes are equivalent
1079
1149
bool fir::operator ==(const SequenceType::Shape &sh_1,
1080
1150
const SequenceType::Shape &sh_2) {
@@ -1302,4 +1372,10 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
1302
1372
os << ' >' ;
1303
1373
return ;
1304
1374
}
1375
+ if (auto type = ty.dyn_cast <fir::VectorType>()) {
1376
+ os << " vector<" << type.getLen () << ' :' ;
1377
+ p.printType (type.getEleTy ());
1378
+ os << ' >' ;
1379
+ return ;
1380
+ }
1305
1381
}
0 commit comments