Skip to content

Support for target specific lowering in the Tilikum bridge. #413

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Sep 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion flang/include/flang/Optimizer/CodeGen/CGPasses.td
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,27 @@
include "mlir/Pass/PassBase.td"

def CodeGenRewrite : FunctionPass<"cg-rewrite"> {
let summary = "Rewrite some FIR ops into their code-gen forms.";
let summary = "Rewrite some FIR ops into their code-gen forms. "
"Fuse specific subgraphs into single Ops for code generation.";
let constructor = "fir::createFirCodeGenRewritePass()";
let dependentDialects = ["fir::FIROpsDialect"];
}

def TargetRewrite : Pass<"target-rewrite", "mlir::ModuleOp"> {
let summary = "Rewrite some FIR dialect into target specific forms. "
"Certain abstractions in the FIR dialect need to be rewritten "
"to reflect representations that may differ based on the "
"target machine.";
let constructor = "fir::createFirTargetRewritePass()";
let dependentDialects = ["fir::FIROpsDialect"];
let options = [
Option<"noCharacterConversion", "no-character-conversion",
"bool", /*default=*/"false",
"Disable target-specific conversion of CHARACTER.">,
Option<"noComplexConversion", "no-complex-conversion",
"bool", /*default=*/"false",
"Disable target-specific conversion of COMPLEX.">
];
}

#endif // FLANG_OPTIMIZER_CODEGEN_PASSES
12 changes: 12 additions & 0 deletions flang/include/flang/Optimizer/CodeGen/CodeGen.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#ifndef OPTIMIZER_CODEGEN_CODEGEN_H
#define OPTIMIZER_CODEGEN_CODEGEN_H

#include "mlir/IR/Module.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include <memory>
Expand All @@ -21,6 +22,17 @@ struct NameUniquer;
/// the code gen (to LLVM-IR dialect) conversion.
std::unique_ptr<mlir::Pass> createFirCodeGenRewritePass();

/// FirTargetRewritePass options.
struct TargetRewriteOptions {
bool noCharacterConversion{};
bool noComplexConversion{};
};

/// Prerequiste pass for code gen. Perform intermediate rewrites to tailor the
/// IR for the chosen target.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createFirTargetRewritePass(
const TargetRewriteOptions &options = TargetRewriteOptions());

/// Convert FIR to the LLVM IR dialect
std::unique_ptr<mlir::Pass> createFIRToLLVMPass(NameUniquer &uniquer);

Expand Down
10 changes: 4 additions & 6 deletions flang/include/flang/Optimizer/Dialect/FIRDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
#ifndef OPTIMIZER_DIALECT_FIRDIALECT_H
#define OPTIMIZER_DIALECT_FIRDIALECT_H

#include "mlir/Conversion/Passes.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/InitAllDialects.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Transforms/Passes.h"
#include "mlir/Transforms/LocationSnapshot.h"
#include "mlir/Dialect/Affine/Passes.h"
#include "mlir/Conversion/Passes.h"
#include "mlir/Transforms/Passes.h"

namespace fir {

Expand Down Expand Up @@ -81,9 +81,7 @@ inline void registerGeneralPasses() {
mlir::registerConvertAffineToStandardPass();
}

inline void registerFIRPasses() {
registerGeneralPasses();
}
inline void registerFIRPasses() { registerGeneralPasses(); }

} // namespace fir

Expand Down
36 changes: 21 additions & 15 deletions flang/include/flang/Optimizer/Dialect/FIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,16 @@ def fir_Type : Type<CPred<"fir::isa_fir_or_std_type($_self)">,
// Fortran intrinsic types
def fir_CharacterType : Type<CPred<"$_self.isa<fir::CharacterType>()">,
"FIR character type">;
def fir_ComplexType : Type<CPred<"$_self.isa<fir::CplxType>()">,
def fir_ComplexType : Type<CPred<"$_self.isa<fir::ComplexType>()">,
"FIR complex type">;
def fir_IntegerType : Type<CPred<"$_self.isa<fir::IntType>()">,
def fir_IntegerType : Type<CPred<"$_self.isa<fir::IntegerType>()">,
"FIR integer type">;
def fir_LogicalType : Type<CPred<"$_self.isa<fir::LogicalType>()">,
"FIR logical type">;
def fir_RealType : Type<CPred<"$_self.isa<fir::RealType>()">,
"FIR real type">;
def fir_VectorType : Type<CPred<"$_self.isa<fir::VectorType>()">,
"FIR vector type">;

// Generalized FIR and standard dialect types representing intrinsic types
def AnyIntegerLike : TypeConstraint<Or<[SignlessIntegerLike.predicate,
Expand All @@ -59,7 +61,7 @@ def fir_SequenceType : Type<CPred<"$_self.isa<fir::SequenceType>()">,
// Composable types
def AnyCompositeLike : TypeConstraint<Or<[fir_RecordType.predicate,
fir_SequenceType.predicate, fir_ComplexType.predicate,
IsTupleTypePred]>, "any composite">;
fir_VectorType.predicate, IsTupleTypePred]>, "any composite">;

// Reference to an entity type
def fir_ReferenceType : Type<CPred<"$_self.isa<fir::ReferenceType>()">,
Expand All @@ -77,6 +79,10 @@ def fir_PointerType : Type<CPred<"$_self.isa<fir::PointerType>()">,
def AnyReferenceLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
fir_HeapType.predicate, fir_PointerType.predicate]>, "any reference">;

// The legal types of global symbols
def AnyAddressableLike : TypeConstraint<Or<[fir_ReferenceType.predicate,
FunctionType.predicate]>, "any addressable">;

// A descriptor tuple (captures a reference to an entity and other information)
def fir_BoxType : Type<CPred<"$_self.isa<fir::BoxType>()">, "box type">;

Expand Down Expand Up @@ -723,7 +729,7 @@ class fir_IntegralSwitchTerminatorOp<string mnemonic,
let verifier = [{
if (!(getSelector().getType().isa<mlir::IntegerType>() ||
getSelector().getType().isa<mlir::IndexType>() ||
getSelector().getType().isa<fir::IntType>()))
getSelector().getType().isa<fir::IntegerType>()))
return emitOpError("must be an integer");
auto cases = getAttrOfType<mlir::ArrayAttr>(getCasesAttr()).getValue();
auto count = getNumDest();
Expand Down Expand Up @@ -847,7 +853,7 @@ def fir_SelectCaseOp : fir_SwitchTerminatorOp<"select_case"> {
let verifier = [{
if (!(getSelector().getType().isa<mlir::IntegerType>() ||
getSelector().getType().isa<mlir::IndexType>() ||
getSelector().getType().isa<fir::IntType>() ||
getSelector().getType().isa<fir::IntegerType>() ||
getSelector().getType().isa<fir::LogicalType>() ||
getSelector().getType().isa<fir::CharacterType>()))
return emitOpError("must be an integer, character, or logical");
Expand Down Expand Up @@ -2349,6 +2355,8 @@ def fir_CallOp : fir_Op<"call", [CallOpInterface]> {
let extraClassDeclaration = [{
static constexpr StringRef calleeAttrName() { return "callee"; }

mlir::FunctionType getFunctionType();

/// Get the argument operands to the called function.
operand_range getArgOperands() {
if (auto calling = getAttrOfType<SymbolRefAttr>(calleeAttrName()))
Expand Down Expand Up @@ -2410,7 +2418,6 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
parser.resolveOperands(
operands, calleeType.getInputs(), calleeLoc, result.operands))
return mlir::failure();
result.addAttribute("fn_type", mlir::TypeAttr::get(calleeType));
return mlir::success();
}];

Expand All @@ -2422,10 +2429,8 @@ def fir_DispatchOp : fir_Op<"dispatch", []> {
p.printOperands(args());
}
p << ')';
p.printOptionalAttrDict(getAttrs(), {"fn_type", "method"});
auto resTy{getResultTypes()};
llvm::SmallVector<mlir::Type, 8> argTy(getOperandTypes());
p << " : " << mlir::FunctionType::get(argTy, resTy, getContext());
p.printOptionalAttrDict(getAttrs(), {"method"});
p << " : " << getFunctionType();
}];

let extraClassDeclaration = [{
Expand Down Expand Up @@ -2676,7 +2681,7 @@ def fir_ConstcOp : fir_Op<"constc", [NoSideEffect]> {
}];

let verifier = [{
if (!getType().isa<fir::CplxType>())
if (!getType().isa<fir::ComplexType>())
return emitOpError("must be a !fir.complex type");
return mlir::success();
}];
Expand Down Expand Up @@ -2747,7 +2752,8 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {

let description = [{
Convert a symbol (a function or global reference) to an SSA-value to be
used in other Operations.
used in other Operations. References to Fortran symbols are distinguished
via this operation from other arbitrary constant values.

```mlir
%p = fir.address_of(@symbol) : !fir.ref<f64>
Expand All @@ -2756,7 +2762,7 @@ def fir_AddrOfOp : fir_OneResultOp<"address_of", [NoSideEffect]> {

let arguments = (ins SymbolRefAttr:$symbol);

let results = (outs fir_ReferenceType:$resTy);
let results = (outs AnyAddressableLike:$resTy);

let assemblyFormat = "`(` $symbol `)` attr-dict `:` type($resTy)";
}
Expand Down Expand Up @@ -2814,8 +2820,8 @@ def fir_ConvertOp : fir_OneResultOp<"convert", [NoSideEffect]> {

def FortranTypeAttr : Attr<And<[CPred<"$_self.isa<TypeAttr>()">,
Or<[CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::CharacterType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::CplxType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::IntType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::ComplexType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::IntegerType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::LogicalType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::RealType>()">,
CPred<"$_self.cast<TypeAttr>().getValue().isa<fir::RecordType>()">]>]>,
Expand Down
60 changes: 39 additions & 21 deletions flang/include/flang/Optimizer/Dialect/FIRType.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ struct BoxTypeStorage;
struct BoxCharTypeStorage;
struct BoxProcTypeStorage;
struct CharacterTypeStorage;
struct CplxTypeStorage;
struct ComplexTypeStorage;
struct FieldTypeStorage;
struct HeapTypeStorage;
struct IntTypeStorage;
struct IntegerTypeStorage;
struct LenTypeStorage;
struct LogicalTypeStorage;
struct PointerTypeStorage;
Expand All @@ -58,6 +58,7 @@ struct ShapeTypeStorage;
struct ShapeShiftTypeStorage;
struct SliceTypeStorage;
struct TypeDescTypeStorage;
struct VectorTypeStorage;
} // namespace detail

// These isa_ routines follow the precedent of llvm::isa_or_null<>
Expand Down Expand Up @@ -125,11 +126,11 @@ class CharacterType
/// Model of a Fortran COMPLEX intrinsic type, including the KIND type
/// parameter. COMPLEX is a floating point type with a real and imaginary
/// member.
class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,
detail::CplxTypeStorage> {
class ComplexType : public mlir::Type::TypeBase<fir::ComplexType, mlir::Type,
detail::ComplexTypeStorage> {
public:
using Base::Base;
static CplxType get(mlir::MLIRContext *ctxt, KindTy kind);
static fir::ComplexType get(mlir::MLIRContext *ctxt, KindTy kind);

/// Get the corresponding fir.real<k> type.
mlir::Type getElementType() const;
Expand All @@ -139,19 +140,18 @@ class CplxType : public mlir::Type::TypeBase<CplxType, mlir::Type,

/// Model of a Fortran INTEGER intrinsic type, including the KIND type
/// parameter.
class IntType
: public mlir::Type::TypeBase<IntType, mlir::Type, detail::IntTypeStorage> {
class IntegerType : public mlir::Type::TypeBase<fir::IntegerType, mlir::Type,
detail::IntegerTypeStorage> {
public:
using Base::Base;
static IntType get(mlir::MLIRContext *ctxt, KindTy kind);
static fir::IntegerType get(mlir::MLIRContext *ctxt, KindTy kind);
KindTy getFKind() const;
};

/// Model of a Fortran LOGICAL intrinsic type, including the KIND type
/// parameter.
class LogicalType
: public mlir::Type::TypeBase<LogicalType, mlir::Type,
detail::LogicalTypeStorage> {
class LogicalType : public mlir::Type::TypeBase<LogicalType, mlir::Type,
detail::LogicalTypeStorage> {
public:
using Base::Base;
static LogicalType get(mlir::MLIRContext *ctxt, KindTy kind);
Expand Down Expand Up @@ -414,14 +414,6 @@ class RecordType : public mlir::Type::TypeBase<RecordType, mlir::Type,
llvm::StringRef name);
};

mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);

void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);

/// Guarantee `type` is a scalar integral type (standard Integer, standard
/// Index, or FIR Int). Aborts execution if condition is false.
void verifyIntegralType(mlir::Type type);

/// Is `t` a FIR Real or MLIR Float type?
inline bool isa_real(mlir::Type t) {
return t.isa<fir::RealType>() || t.isa<mlir::FloatType>();
Expand All @@ -430,12 +422,38 @@ inline bool isa_real(mlir::Type t) {
/// Is `t` an integral type?
inline bool isa_integer(mlir::Type t) {
return t.isa<mlir::IndexType>() || t.isa<mlir::IntegerType>() ||
t.isa<fir::IntType>();
t.isa<fir::IntegerType>();
}

/// Replacement for the standard dialect's vector type. Relaxes some of the
/// constraints and imposes some new ones.
class VectorType : public mlir::Type::TypeBase<fir::VectorType, mlir::Type,
detail::VectorTypeStorage> {
public:
using Base::Base;

static fir::VectorType get(uint64_t len, mlir::Type eleTy);
mlir::Type getEleTy() const;
uint64_t getLen() const;

static mlir::LogicalResult
verifyConstructionInvariants(mlir::Location, uint64_t len, mlir::Type eleTy);
static bool isValidElementType(mlir::Type t) {
return isa_real(t) || isa_integer(t);
}
};

mlir::Type parseFirType(FIROpsDialect *, mlir::DialectAsmParser &parser);

void printFirType(FIROpsDialect *, mlir::Type ty, mlir::DialectAsmPrinter &p);

/// Guarantee `type` is a scalar integral type (standard Integer, standard
/// Index, or FIR Int). Aborts execution if condition is false.
void verifyIntegralType(mlir::Type type);

/// Is `t` a FIR or MLIR Complex type?
inline bool isa_complex(mlir::Type t) {
return t.isa<fir::CplxType>() || t.isa<mlir::ComplexType>();
return t.isa<fir::ComplexType>() || t.isa<mlir::ComplexType>();
}

inline bool isa_char_string(mlir::Type t) {
Expand Down
12 changes: 9 additions & 3 deletions flang/include/flang/Optimizer/Support/FIRContext.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,21 +29,24 @@ namespace fir {
class KindMapping;
struct NameUniquer;

/// Set the target triple for the module.
/// Set the target triple for the module. `triple` must not be deallocated while
/// module `mod` is still live.
void setTargetTriple(mlir::ModuleOp mod, llvm::Triple &triple);

/// Get a pointer to the Triple instance from the Module. If none was set,
/// returns a nullptr.
llvm::Triple *getTargetTriple(mlir::ModuleOp mod);

/// Set the name uniquer for the module.
/// Set the name uniquer for the module. `uniquer` must not be deallocated while
/// module `mod` is still live.
void setNameUniquer(mlir::ModuleOp mod, NameUniquer &uniquer);

/// Get a pointer to the NameUniquer instance from the Module. If none was set,
/// returns a nullptr.
NameUniquer *getNameUniquer(mlir::ModuleOp mod);

/// Set the kind mapping for the module.
/// Set the kind mapping for the module. `kindMap` must not be deallocated while
/// module `mod` is still live.
void setKindMapping(mlir::ModuleOp mod, KindMapping &kindMap);

/// Get a pointer to the KindMapping instance from the Module. If none was set,
Expand All @@ -53,6 +56,9 @@ KindMapping *getKindMapping(mlir::ModuleOp mod);
/// Helper for determining the target from the host, etc. Tools may use this
/// function to provide a consistent interpretation of the `--target=<string>`
/// command-line option.
/// An empty string ("") or "default" will specify that the default triple
/// should be used. "native" will specify that the host machine be used to
/// construct the triple.
std::string determineTargetTriple(llvm::StringRef triple);

} // namespace fir
Expand Down
5 changes: 3 additions & 2 deletions flang/include/flang/Optimizer/Transforms/RewritePatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,9 @@ def createConstantOp
"rewriter.getIndexAttr($1.dyn_cast<IntegerAttr>().getInt()))">;

def ForwardConstantConvertPattern
: Pat<(fir_ConvertOp:$res (ConstantOp $attr)),
: Pat<(fir_ConvertOp:$res (ConstantOp:$cnt $attr)),
(createConstantOp $res, $attr),
[(IndexTypePred $res)]>;
[(IndexTypePred $res)
,(IntegerTypePred $cnt)]>;

#endif // FIR_REWRITE_PATTERNS
4 changes: 2 additions & 2 deletions flang/lib/Lower/ComplexExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
mlir::Type
Fortran::lower::ComplexExprHelper::getComplexPartType(mlir::Type complexType) {
return Fortran::lower::convertReal(
builder.getContext(), complexType.cast<fir::CplxType>().getFKind());
builder.getContext(), complexType.cast<fir::ComplexType>().getFKind());
}

mlir::Type
Expand All @@ -27,7 +27,7 @@ Fortran::lower::ComplexExprHelper::getComplexPartType(mlir::Value cplx) {
mlir::Value Fortran::lower::ComplexExprHelper::createComplex(fir::KindTy kind,
mlir::Value real,
mlir::Value imag) {
auto complexTy = fir::CplxType::get(builder.getContext(), kind);
auto complexTy = fir::ComplexType::get(builder.getContext(), kind);
mlir::Value und = builder.create<fir::UndefOp>(loc, complexTy);
return insert<Part::Imag>(insert<Part::Real>(und, real), imag);
}
Expand Down
Loading