Skip to content

Commit ed08252

Browse files
schweitzpgimemfrob
authored and
memfrob
committed
[flang][fir] Add FIR's vector type.
This patch adds support for `!fir.vector`, a rank one, constant length data type. flang-compiler/f18-llvm-project#413 Differential Revision: https://reviews.llvm.org/D96162
1 parent 5542253 commit ed08252

File tree

5 files changed

+113
-11
lines changed

5 files changed

+113
-11
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ def fir_LogicalType : Type<CPred<"$_self.isa<fir::LogicalType>()">,
4040
"FIR logical type">;
4141
def fir_RealType : Type<CPred<"$_self.isa<fir::RealType>()">,
4242
"FIR real type">;
43+
def fir_VectorType : Type<CPred<"$_self.isa<fir::VectorType>()">,
44+
"FIR vector type">;
4345

4446
// Generalized FIR and standard dialect types representing intrinsic types
4547
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
@@ -61,7 +63,7 @@ def fir_SequenceType : Type<CPred<"$_self.isa<fir::SequenceType>()">,
6163
// Composable types
6264
def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
6365
fir_SequenceType.predicate, fir_ComplexType.predicate,
64-
IsTupleTypePred]>, "any composite">;
66+
fir_VectorType.predicate, IsTupleTypePred]>, "any composite">;
6567

6668
// Reference to an entity type
6769
def fir_ReferenceType : Type<CPred<"$_self.isa<fir::ReferenceType>()">,

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ struct RecordTypeStorage;
5555
struct ReferenceTypeStorage;
5656
struct SequenceTypeStorage;
5757
struct TypeDescTypeStorage;
58+
struct VectorTypeStorage;
5859
} // namespace detail
5960

6061
// These isa_ routines follow the precedent of llvm::isa_or_null<>
@@ -363,14 +364,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
363364
llvm::StringRef name);
364365
};
365366

366-
mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
367-
368-
void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
369-
370-
/// Guarantee `type` is a scalar integral type (standard Integer, standard
371-
/// Index, or FIR Int). Aborts execution if condition is false.
372-
void verifyIntegralType(mlir::Type type);
373-
374367
/// Is `t` a FIR Real or MLIR Float type?
375368
inline bool isa_real(mlir::Type t) {
376369
return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
@@ -382,6 +375,33 @@ inline bool isa_integer(mlir::Type t) {
382375
t.isa<fir::IntegerType>();
383376
}
384377

378+
/// Replacement for the builtin vector type.
379+
/// The FIR vector type is always rank one. It's size is always a constant.
380+
/// A vector's element type must be real or integer.
381+
class VectorType : public mlir::Type::TypeBase<fir::VectorType, mlir::Type,
382+
detail::VectorTypeStorage> {
383+
public:
384+
using Base::Base;
385+
386+
static fir::VectorType get(uint64_t len, mlir::Type eleTy);
387+
mlir::Type getEleTy() const;
388+
uint64_t getLen() const;
389+
390+
static mlir::LogicalResult
391+
verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
392+
static bool isValidElementType(mlir::Type t) {
393+
return isa_real(t) || isa_integer(t);
394+
}
395+
};
396+
397+
mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
398+
399+
void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
400+
401+
/// Guarantee `type` is a scalar integral type (standard Integer, standard
402+
/// Index, or FIR Int). Aborts execution if condition is false.
403+
void verifyIntegralType(mlir::Type type);
404+
385405
/// Is `t` a FIR or MLIR Complex type?
386406
inline bool isa_complex(mlir::Type t) {
387407
return t.isa<fir::ComplexType>() || t.isa<mlir::ComplexType>();

flang/lib/Optimizer/Dialect/FIRDialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ fir::FIROpsDialect::FIROpsDialect(mlir::MLIRContext *ctx)
1818
addTypes<BoxType, BoxCharType, BoxProcType, CharacterType, fir::ComplexType,
1919
FieldType, HeapType, fir::IntegerType, LenType, LogicalType,
2020
PointerType, RealType, RecordType, ReferenceType, SequenceType,
21-
TypeDescType>();
21+
TypeDescType, fir::VectorType>();
2222
addAttributes<ClosedIntervalAttr, ExactTypeAttr, LowerBoundAttr,
2323
PointIntervalAttr, RealAttr, SubclassAttr, UpperBoundAttr>();
2424
addOperations<

flang/lib/Optimizer/Dialect/FIRType.cpp

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,19 @@ TypeDescType parseTypeDesc(mlir::DialectAsmParser &parser, mlir::Location loc) {
142142
return parseTypeSingleton<TypeDescType>(parser, loc);
143143
}
144144

145+
// `vector` `<` len `:` type `>`
146+
fir::VectorType parseVector(mlir::DialectAsmParser &parser,
147+
mlir::Location loc) {
148+
int64_t len = 0;
149+
mlir::Type eleTy;
150+
if (parser.parseLess() || parser.parseInteger(len) || parser.parseColon() ||
151+
parser.parseType(eleTy) || parser.parseGreater()) {
152+
parser.emitError(parser.getNameLoc(), "invalid vector type");
153+
return {};
154+
}
155+
return fir::VectorType::get(len, eleTy);
156+
}
157+
145158
// `void`
146159
mlir::Type parseVoid(mlir::DialectAsmParser &parser) {
147160
return parser.getBuilder().getNoneType();
@@ -346,6 +359,8 @@ mlir::Type fir::parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser) {
346359
return parseDerived(parser, loc);
347360
if (typeNameLit == "void")
348361
return parseVoid(parser);
362+
if (typeNameLit == "vector")
363+
return parseVector(parser, loc);
349364

350365
parser.emitError(parser.getNameLoc(), "unknown FIR type " + typeNameLit);
351366
return {};
@@ -790,6 +805,39 @@ struct TypeDescTypeStorage : public mlir::TypeStorage {
790805
explicit TypeDescTypeStorage(mlir::Type ofTy) : ofTy{ofTy} {}
791806
};
792807

808+
/// Vector type storage
809+
struct VectorTypeStorage : public mlir::TypeStorage {
810+
using KeyTy = std::tuple<uint64_t, mlir::Type>;
811+
812+
static unsigned hashKey(const KeyTy &key) {
813+
return llvm::hash_combine(std::get<uint64_t>(key),
814+
std::get<mlir::Type>(key));
815+
}
816+
817+
bool operator==(const KeyTy &key) const {
818+
return key == KeyTy{getLen(), getEleTy()};
819+
}
820+
821+
static VectorTypeStorage *construct(mlir::TypeStorageAllocator &allocator,
822+
const KeyTy &key) {
823+
auto *storage = allocator.allocate<VectorTypeStorage>();
824+
return new (storage)
825+
VectorTypeStorage{std::get<uint64_t>(key), std::get<mlir::Type>(key)};
826+
}
827+
828+
uint64_t getLen() const { return len; }
829+
mlir::Type getEleTy() const { return eleTy; }
830+
831+
protected:
832+
uint64_t len;
833+
mlir::Type eleTy;
834+
835+
private:
836+
VectorTypeStorage() = delete;
837+
explicit VectorTypeStorage(uint64_t len, mlir::Type eleTy)
838+
: len{len}, eleTy{eleTy} {}
839+
};
840+
793841
} // namespace detail
794842

795843
template <typename A, typename B>
@@ -1069,12 +1117,34 @@ mlir::LogicalResult fir::SequenceType::verifyConstructionInvariants(
10691117
eleTy.isa<BoxProcType>() || eleTy.isa<FieldType>() ||
10701118
eleTy.isa<LenType>() || eleTy.isa<HeapType>() ||
10711119
eleTy.isa<PointerType>() || eleTy.isa<ReferenceType>() ||
1072-
eleTy.isa<TypeDescType>() || eleTy.isa<SequenceType>())
1120+
eleTy.isa<TypeDescType>() || eleTy.isa<fir::VectorType>() ||
1121+
eleTy.isa<SequenceType>())
10731122
return mlir::emitError(loc, "cannot build an array of this element type: ")
10741123
<< eleTy << '\n';
10751124
return mlir::success();
10761125
}
10771126

1127+
//===----------------------------------------------------------------------===//
1128+
// Vector type
1129+
//===----------------------------------------------------------------------===//
1130+
1131+
fir::VectorType fir::VectorType::get(uint64_t len, mlir::Type eleTy) {
1132+
return Base::get(eleTy.getContext(), len, eleTy);
1133+
}
1134+
1135+
mlir::Type fir::VectorType::getEleTy() const { return getImpl()->getEleTy(); }
1136+
1137+
uint64_t fir::VectorType::getLen() const { return getImpl()->getLen(); }
1138+
1139+
mlir::LogicalResult
1140+
fir::VectorType::verifyConstructionInvariants(mlir::Location loc, uint64_t len,
1141+
mlir::Type eleTy) {
1142+
if (!(fir::isa_real(eleTy) || fir::isa_integer(eleTy)))
1143+
return mlir::emitError(loc, "cannot build a vector of type ")
1144+
<< eleTy << '\n';
1145+
return mlir::success();
1146+
}
1147+
10781148
// compare if two shapes are equivalent
10791149
bool fir::operator==(const SequenceType::Shape &sh_1,
10801150
const SequenceType::Shape &sh_2) {
@@ -1302,4 +1372,10 @@ void fir::printFirType(FIROpsDialect *, mlir::Type ty,
13021372
os << '>';
13031373
return;
13041374
}
1375+
if (auto type = ty.dyn_cast<fir::VectorType>()) {
1376+
os << "vector<" << type.getLen() << ':';
1377+
p.printType(type.getEleTy());
1378+
os << '>';
1379+
return;
1380+
}
13051381
}

flang/test/Fir/fir-types.fir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,7 @@ func private @box5() -> !fir.box<!fir.type<derived3{f:f32}>>
7171
// CHECK-LABEL: func private @oth3() -> !fir.tdesc<!fir.type<derived7{f1:f32,f2:f32}>>
7272
func private @oth2() -> !fir.field
7373
func private @oth3() -> !fir.tdesc<!fir.type<derived7{f1:f32,f2:f32}>>
74+
75+
// FIR vector
76+
// CHECK-LABEL: func private @vecty(i1) -> !fir.vector<10:i32>
77+
func private @vecty(i1) -> !fir.vector<10:i32>

0 commit comments

Comments
 (0)