Skip to content

Commit aad9076

Browse files
authored
Merge pull request #413 from schweitzpgi/ch-target4
Support for target specific lowering in the Tilikum bridge.
2 parents cdf0e02 + 6403fed commit aad9076

39 files changed

+1627
-337
lines changed

flang/include/flang/Optimizer/CodeGen/CGPasses.td

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,27 @@
1717
include "mlir/Pass/PassBase.td"
1818

1919
def CodeGenRewrite : FunctionPass<"cg-rewrite"> {
20-
let summary = "Rewrite some FIR ops into their code-gen forms.";
20+
let summary = "Rewrite some FIR ops into their code-gen forms. "
21+
"Fuse specific subgraphs into single Ops for code generation.";
2122
let constructor = "fir::createFirCodeGenRewritePass()";
23+
let dependentDialects = ["fir::FIROpsDialect"];
24+
}
25+
26+
def TargetRewrite : Pass<"target-rewrite", "mlir::ModuleOp"> {
27+
let summary = "Rewrite some FIR dialect into target specific forms. "
28+
"Certain abstractions in the FIR dialect need to be rewritten "
29+
"to reflect representations that may differ based on the "
30+
"target machine.";
31+
let constructor = "fir::createFirTargetRewritePass()";
32+
let dependentDialects = ["fir::FIROpsDialect"];
33+
let options = [
34+
Option<"noCharacterConversion", "no-character-conversion",
35+
"bool", /*default=*/"false",
36+
"Disable target-specific conversion of CHARACTER.">,
37+
Option<"noComplexConversion", "no-complex-conversion",
38+
"bool", /*default=*/"false",
39+
"Disable target-specific conversion of COMPLEX.">
40+
];
2241
}
2342

2443
#endif // FLANG_OPTIMIZER_CODEGEN_PASSES

flang/include/flang/Optimizer/CodeGen/CodeGen.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#ifndef OPTIMIZER_CODEGEN_CODEGEN_H
1010
#define OPTIMIZER_CODEGEN_CODEGEN_H
1111

12+
#include "mlir/IR/Module.h"
1213
#include "mlir/Pass/Pass.h"
1314
#include "mlir/Pass/PassRegistry.h"
1415
#include <memory>
@@ -21,6 +22,17 @@ struct NameUniquer;
2122
/// the code gen (to LLVM-IR dialect) conversion.
2223
std::unique_ptr<mlir::Pass> createFirCodeGenRewritePass();
2324

25+
/// FirTargetRewritePass options.
26+
struct TargetRewriteOptions {
27+
bool noCharacterConversion{};
28+
bool noComplexConversion{};
29+
};
30+
31+
/// Prerequiste pass for code gen. Perform intermediate rewrites to tailor the
32+
/// IR for the chosen target.
33+
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createFirTargetRewritePass(
34+
const TargetRewriteOptions &options = TargetRewriteOptions());
35+
2436
/// Convert FIR to the LLVM IR dialect
2537
std::unique_ptr<mlir::Pass> createFIRToLLVMPass(NameUniquer &uniquer);
2638

flang/include/flang/Optimizer/Dialect/FIRDialect.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@
99
#ifndef OPTIMIZER_DIALECT_FIRDIALECT_H
1010
#define OPTIMIZER_DIALECT_FIRDIALECT_H
1111

12+
#include "mlir/Conversion/Passes.h"
13+
#include "mlir/Dialect/Affine/Passes.h"
1214
#include "mlir/IR/Dialect.h"
1315
#include "mlir/InitAllDialects.h"
1416
#include "mlir/Pass/Pass.h"
1517
#include "mlir/Pass/PassRegistry.h"
16-
#include "mlir/Transforms/Passes.h"
1718
#include "mlir/Transforms/LocationSnapshot.h"
18-
#include "mlir/Dialect/Affine/Passes.h"
19-
#include "mlir/Conversion/Passes.h"
19+
#include "mlir/Transforms/Passes.h"
2020

2121
namespace fir {
2222

@@ -81,9 +81,7 @@ inline void registerGeneralPasses() {
8181
mlir::registerConvertAffineToStandardPass();
8282
}
8383

84-
inline void registerFIRPasses() {
85-
registerGeneralPasses();
86-
}
84+
inline void registerFIRPasses() { registerGeneralPasses(); }
8785

8886
} // namespace fir
8987

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,16 @@ def fir_Type : Type<CPred<"fir::isa_fir_or_std_type($_self)">,
3030
// Fortran intrinsic types
3131
def fir_CharacterType : Type<CPred<"$_self.isa<fir::CharacterType>()">,
3232
"FIR character type">;
33-
def fir_ComplexType : Type<CPred<"$_self.isa<fir::CplxType>()">,
33+
def fir_ComplexType : Type<CPred<"$_self.isa<fir::ComplexType>()">,
3434
"FIR complex type">;
35-
def fir_IntegerType : Type<CPred<"$_self.isa<fir::IntType>()">,
35+
def fir_IntegerType : Type<CPred<"$_self.isa<fir::IntegerType>()">,
3636
"FIR integer type">;
3737
def fir_LogicalType : Type<CPred<"$_self.isa<fir::LogicalType>()">,
3838
"FIR logical type">;
3939
def fir_RealType : Type<CPred<"$_self.isa<fir::RealType>()">,
4040
"FIR real type">;
41+
def fir_VectorType : Type<CPred<"$_self.isa<fir::VectorType>()">,
42+
"FIR vector type">;
4143

4244
// Generalized FIR and standard dialect types representing intrinsic types
4345
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
@@ -59,7 +61,7 @@ def fir_SequenceType : Type<CPred<"$_self.isa<fir::SequenceType>()">,
5961
// Composable types
6062
def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
6163
fir_SequenceType.predicate, fir_ComplexType.predicate,
62-
IsTupleTypePred]>, "any composite">;
64+
fir_VectorType.predicate, IsTupleTypePred]>, "any composite">;
6365

6466
// Reference to an entity type
6567
def fir_ReferenceType : Type<CPred<"$_self.isa<fir::ReferenceType>()">,
@@ -77,6 +79,10 @@ def fir_PointerType : Type<CPred<"$_self.isa<fir::PointerType>()">,
7779
def AnyReferenceLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
7880
fir_HeapType.predicate, fir_PointerType.predicate]>, "any reference">;
7981

82+
// The legal types of global symbols
83+
def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
84+
FunctionType.predicate]>, "any addressable">;
85+
8086
// A descriptor tuple (captures a reference to an entity and other information)
8187
def fir_BoxType : Type<CPred<"$_self.isa<fir::BoxType>()">, "box type">;
8288

@@ -723,7 +729,7 @@ class fir_IntegralSwitchTerminatorOp<string mnemonic,
723729
let verifier = [{
724730
if (!(getSelector().getType().isa<mlir::IntegerType>() ||
725731
getSelector().getType().isa<mlir::IndexType>() ||
726-
getSelector().getType().isa<fir::IntType>()))
732+
getSelector().getType().isa<fir::IntegerType>()))
727733
return emitOpError("must be an integer");
728734
auto cases = getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue();
729735
auto count = getNumDest();
@@ -847,7 +853,7 @@ def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> {
847853
let verifier = [{
848854
if (!(getSelector().getType().isa<mlir::IntegerType>() ||
849855
getSelector().getType().isa<mlir::IndexType>() ||
850-
getSelector().getType().isa<fir::IntType>() ||
856+
getSelector().getType().isa<fir::IntegerType>() ||
851857
getSelector().getType().isa<fir::LogicalType>() ||
852858
getSelector().getType().isa<fir::CharacterType>()))
853859
return emitOpError("must be an integer, character, or logical");
@@ -2349,6 +2355,8 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
23492355
let extraClassDeclaration = [{
23502356
static constexpr StringRef calleeAttrName() { return "callee"; }
23512357

2358+
mlir::FunctionType getFunctionType();
2359+
23522360
/// Get the argument operands to the called function.
23532361
operand_range getArgOperands() {
23542362
if (auto calling = getAttrOfType<SymbolRefAttr>(calleeAttrName()))
@@ -2410,7 +2418,6 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
24102418
parser.resolveOperands(
24112419
operands, calleeType.getInputs(), calleeLoc, result.operands))
24122420
return mlir::failure();
2413-
result.addAttribute("fn_type", mlir::TypeAttr::get(calleeType));
24142421
return mlir::success();
24152422
}];
24162423

@@ -2422,10 +2429,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
24222429
p.printOperands(args());
24232430
}
24242431
p << ')';
2425-
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
2426-
auto resTy{getResultTypes()};
2427-
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
2428-
p << " : " << mlir::FunctionType::get(argTy, resTy, getContext());
2432+
p.printOptionalAttrDict(getAttrs(), {"method"});
2433+
p << " : " << getFunctionType();
24292434
}];
24302435

24312436
let extraClassDeclaration = [{
@@ -2676,7 +2681,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
26762681
}];
26772682

26782683
let verifier = [{
2679-
if (!getType().isa<fir::CplxType>())
2684+
if (!getType().isa<fir::ComplexType>())
26802685
return emitOpError("must be a !fir.complex type");
26812686
return mlir::success();
26822687
}];
@@ -2747,7 +2752,8 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
27472752

27482753
let description = [{
27492754
Convert a symbol (a function or global reference) to an SSA-value to be
2750-
used in other Operations.
2755+
used in other Operations. References to Fortran symbols are distinguished
2756+
via this operation from other arbitrary constant values.
27512757

27522758
```mlir
27532759
%p = fir.address_of(@symbol) : !fir.ref<f64>
@@ -2756,7 +2762,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {
27562762

27572763
let arguments = (ins SymbolRefAttr:$symbol);
27582764

2759-
let results = (outs fir_ReferenceType:$resTy);
2765+
let results = (outs AnyAddressableLike:$resTy);
27602766

27612767
let assemblyFormat = "`(` $symbol `)` attr-dict `:` type($resTy)";
27622768
}
@@ -2814,8 +2820,8 @@ def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {
28142820

28152821
def FortranTypeAttr : Attr<And<[CPred<"$_self.isa<TypeAttr>()">,
28162822
Or<[CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::CharacterType>()">,
2817-
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::CplxType>()">,
2818-
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::IntType>()">,
2823+
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::ComplexType>()">,
2824+
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::IntegerType>()">,
28192825
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::LogicalType>()">,
28202826
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::RealType>()">,
28212827
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::RecordType>()">]>]>,

flang/include/flang/Optimizer/Dialect/FIRType.h

Lines changed: 39 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@ struct BoxTypeStorage;
4343
struct BoxCharTypeStorage;
4444
struct BoxProcTypeStorage;
4545
struct CharacterTypeStorage;
46-
struct CplxTypeStorage;
46+
struct ComplexTypeStorage;
4747
struct FieldTypeStorage;
4848
struct HeapTypeStorage;
49-
struct IntTypeStorage;
49+
struct IntegerTypeStorage;
5050
struct LenTypeStorage;
5151
struct LogicalTypeStorage;
5252
struct PointerTypeStorage;
@@ -58,6 +58,7 @@ struct ShapeTypeStorage;
5858
struct ShapeShiftTypeStorage;
5959
struct SliceTypeStorage;
6060
struct TypeDescTypeStorage;
61+
struct VectorTypeStorage;
6162
} // namespace detail
6263

6364
// These isa_ routines follow the precedent of llvm::isa_or_null<>
@@ -125,11 +126,11 @@ class CharacterType
125126
/// Model of a Fortran COMPLEX intrinsic type, including the KIND type
126127
/// parameter. COMPLEX is a floating point type with a real and imaginary
127128
/// member.
128-
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
129-
detail::CplxTypeStorage> {
129+
class ComplexType : public mlir::Type::TypeBase<fir::ComplexType, mlir::Type,
130+
detail::ComplexTypeStorage> {
130131
public:
131132
using Base::Base;
132-
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
133+
static fir::ComplexType get(mlir::MLIRContext *ctxt, KindTy kind);
133134

134135
/// Get the corresponding fir.real<k> type.
135136
mlir::Type getElementType() const;
@@ -139,19 +140,18 @@ class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
139140

140141
/// Model of a Fortran INTEGER intrinsic type, including the KIND type
141142
/// parameter.
142-
class IntType
143-
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
143+
class IntegerType : public mlir::Type::TypeBase<fir::IntegerType, mlir::Type,
144+
detail::IntegerTypeStorage> {
144145
public:
145146
using Base::Base;
146-
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
147+
static fir::IntegerType get(mlir::MLIRContext *ctxt, KindTy kind);
147148
KindTy getFKind() const;
148149
};
149150

150151
/// Model of a Fortran LOGICAL intrinsic type, including the KIND type
151152
/// parameter.
152-
class LogicalType
153-
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
154-
detail::LogicalTypeStorage> {
153+
class LogicalType : public mlir::Type::TypeBase<LogicalType, mlir::Type,
154+
detail::LogicalTypeStorage> {
155155
public:
156156
using Base::Base;
157157
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
@@ -414,14 +414,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
414414
llvm::StringRef name);
415415
};
416416

417-
mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
418-
419-
void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
420-
421-
/// Guarantee `type` is a scalar integral type (standard Integer, standard
422-
/// Index, or FIR Int). Aborts execution if condition is false.
423-
void verifyIntegralType(mlir::Type type);
424-
425417
/// Is `t` a FIR Real or MLIR Float type?
426418
inline bool isa_real(mlir::Type t) {
427419
return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
@@ -430,12 +422,38 @@ inline bool isa_real(mlir::Type t) {
430422
/// Is `t` an integral type?
431423
inline bool isa_integer(mlir::Type t) {
432424
return t.isa<mlir::IndexType>() || t.isa<mlir::IntegerType>() ||
433-
t.isa<fir::IntType>();
425+
t.isa<fir::IntegerType>();
434426
}
435427

428+
/// Replacement for the standard dialect's vector type. Relaxes some of the
429+
/// constraints and imposes some new ones.
430+
class VectorType : public mlir::Type::TypeBase<fir::VectorType, mlir::Type,
431+
detail::VectorTypeStorage> {
432+
public:
433+
using Base::Base;
434+
435+
static fir::VectorType get(uint64_t len, mlir::Type eleTy);
436+
mlir::Type getEleTy() const;
437+
uint64_t getLen() const;
438+
439+
static mlir::LogicalResult
440+
verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
441+
static bool isValidElementType(mlir::Type t) {
442+
return isa_real(t) || isa_integer(t);
443+
}
444+
};
445+
446+
mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);
447+
448+
void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);
449+
450+
/// Guarantee `type` is a scalar integral type (standard Integer, standard
451+
/// Index, or FIR Int). Aborts execution if condition is false.
452+
void verifyIntegralType(mlir::Type type);
453+
436454
/// Is `t` a FIR or MLIR Complex type?
437455
inline bool isa_complex(mlir::Type t) {
438-
return t.isa<fir::CplxType>() || t.isa<mlir::ComplexType>();
456+
return t.isa<fir::ComplexType>() || t.isa<mlir::ComplexType>();
439457
}
440458

441459
inline bool isa_char_string(mlir::Type t) {

flang/include/flang/Optimizer/Support/FIRContext.h

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,21 +29,24 @@ namespace fir {
2929
class KindMapping;
3030
struct NameUniquer;
3131

32-
/// Set the target triple for the module.
32+
/// Set the target triple for the module. `triple` must not be deallocated while
33+
/// module `mod` is still live.
3334
void setTargetTriple(mlir::ModuleOp mod, llvm::Triple &triple);
3435

3536
/// Get a pointer to the Triple instance from the Module. If none was set,
3637
/// returns a nullptr.
3738
llvm::Triple *getTargetTriple(mlir::ModuleOp mod);
3839

39-
/// Set the name uniquer for the module.
40+
/// Set the name uniquer for the module. `uniquer` must not be deallocated while
41+
/// module `mod` is still live.
4042
void setNameUniquer(mlir::ModuleOp mod, NameUniquer &uniquer);
4143

4244
/// Get a pointer to the NameUniquer instance from the Module. If none was set,
4345
/// returns a nullptr.
4446
NameUniquer *getNameUniquer(mlir::ModuleOp mod);
4547

46-
/// Set the kind mapping for the module.
48+
/// Set the kind mapping for the module. `kindMap` must not be deallocated while
49+
/// module `mod` is still live.
4750
void setKindMapping(mlir::ModuleOp mod, KindMapping &kindMap);
4851

4952
/// Get a pointer to the KindMapping instance from the Module. If none was set,
@@ -53,6 +56,9 @@ KindMapping *getKindMapping(mlir::ModuleOp mod);
5356
/// Helper for determining the target from the host, etc. Tools may use this
5457
/// function to provide a consistent interpretation of the `--target=<string>`
5558
/// command-line option.
59+
/// An empty string ("") or "default" will specify that the default triple
60+
/// should be used. "native" will specify that the host machine be used to
61+
/// construct the triple.
5662
std::string determineTargetTriple(llvm::StringRef triple);
5763

5864
} // namespace fir

flang/include/flang/Optimizer/Transforms/RewritePatterns.td

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,9 @@ def createConstantOp
5151
"rewriter.getIndexAttr($1.dyn_cast<IntegerAttr>().getInt()))">;
5252

5353
def ForwardConstantConvertPattern
54-
: Pat<(fir_ConvertOp:$res (ConstantOp $attr)),
54+
: Pat<(fir_ConvertOp:$res (ConstantOp:$cnt $attr)),
5555
(createConstantOp $res, $attr),
56-
[(IndexTypePred $res)]>;
56+
[(IndexTypePred $res)
57+
,(IntegerTypePred $cnt)]>;
5758

5859
#endif // FIR_REWRITE_PATTERNS

flang/lib/Lower/ComplexExpr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
mlir::Type
1717
Fortran::lower::ComplexExprHelper::getComplexPartType(mlir::Type complexType) {
1818
return Fortran::lower::convertReal(
19-
builder.getContext(), complexType.cast<fir::CplxType>().getFKind());
19+
builder.getContext(), complexType.cast<fir::ComplexType>().getFKind());
2020
}
2121

2222
mlir::Type
@@ -27,7 +27,7 @@ Fortran::lower::ComplexExprHelper::getComplexPartType(mlir::Value cplx) {
2727
mlir::Value Fortran::lower::ComplexExprHelper::createComplex(fir::KindTy kind,
2828
mlir::Value real,
2929
mlir::Value imag) {
30-
auto complexTy = fir::CplxType::get(builder.getContext(), kind);
30+
auto complexTy = fir::ComplexType::get(builder.getContext(), kind);
3131
mlir::Value und = builder.create<fir::UndefOp>(loc, complexTy);
3232
return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
3333
}

0 commit comments

Comments
 (0)