Skip to content

Commit c0bc775

Browse files
[MLIR][LLVM] Add CG Profile module flags support (llvm#137115)
Dialect only accept arbitrary module flag values in face of simple types like int and string. Whenever metadata is a bit more complex use specific attributes to map functionality. This PR adds an attribute to represent "CG Profile" entries, verifiers, import / translate support.
1 parent b6746b0 commit c0bc775

File tree

9 files changed

+189
-13
lines changed

9 files changed

+189
-13
lines changed

mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td

+33-6
Original file line numberDiff line numberDiff line change
@@ -1322,7 +1322,7 @@ def LLVM_DereferenceableAttr : LLVM_Attr<"Dereferenceable", "dereferenceable"> {
13221322
}
13231323

13241324
//===----------------------------------------------------------------------===//
1325-
// ModuleFlagAttr
1325+
// ModuleFlagAttr & related
13261326
//===----------------------------------------------------------------------===//
13271327

13281328
def ModuleFlagAttr
@@ -1332,14 +1332,22 @@ def ModuleFlagAttr
13321332
Represents a single entry of llvm.module.flags metadata
13331333
(llvm::Module::ModuleFlagEntry in LLVM). The first element is a behavior
13341334
flag described by `ModFlagBehaviorAttr`, the second is a string ID
1335-
and third is the value of the flag. Current supported types of values:
1336-
- Integer constants
1337-
- Strings
1335+
and third is the value of the flag. Supported keys and values include:
1336+
- Arbitrary `key`s holding integer constants or strings.
1337+
- Domain specific keys (e.g "CG Profile"), holding lists of supported
1338+
module flag values (e.g. `llvm.cgprofile_entry`).
13381339

13391340
Example:
13401341
```mlir
1341-
#llvm.mlir.module_flag<error, "wchar_size", 4>
1342-
#llvm.mlir.module_flag<error, "probe-stack", "inline-asm">
1342+
llvm.module_flags [
1343+
#llvm.mlir.module_flag<error, "wchar_size", 4>,
1344+
#llvm.mlir.module_flag<error, "probe-stack", "inline-asm">,
1345+
#llvm.mlir.module_flag<append, "CG Profile", [
1346+
#llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
1347+
#llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
1348+
#llvm.cgprofile_entry<from = @to, to = @from, count = 222>
1349+
]
1350+
>]
13431351
```
13441352
}];
13451353
let parameters = (ins "ModFlagBehavior":$behavior,
@@ -1349,6 +1357,25 @@ def ModuleFlagAttr
13491357
let genVerifyDecl = 1;
13501358
}
13511359

1360+
def ModuleFlagCGProfileEntryAttr
1361+
: LLVM_Attr<"ModuleFlagCGProfileEntry", "cgprofile_entry"> {
1362+
let summary = "CG profile module flag entry";
1363+
let description = [{
1364+
Describes a single entry for a CG profile module flag. Example:
1365+
```mlir
1366+
llvm.module_flags [
1367+
#llvm.mlir.module_flag<append, "CG Profile",
1368+
[#llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
1369+
...
1370+
]>]
1371+
```
1372+
}];
1373+
let parameters = (ins "FlatSymbolRefAttr":$from,
1374+
"FlatSymbolRefAttr":$to,
1375+
"uint64_t":$count);
1376+
let assemblyFormat = "`<` struct(params) `>`";
1377+
}
1378+
13521379
//===----------------------------------------------------------------------===//
13531380
// LLVM_DependentLibrariesAttr
13541381
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.td

+5
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@ def LLVM_Dialect : Dialect {
8888
return "llvm.dependent_libraries";
8989
}
9090

91+
/// Names of known llvm module flag keys.
92+
static StringRef getModuleFlagKeyCGProfileName() {
93+
return "CG Profile";
94+
}
95+
9196
/// Returns `true` if the given type is compatible with the LLVM dialect.
9297
static bool isCompatibleType(Type);
9398

mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp

+16-4
Original file line numberDiff line numberDiff line change
@@ -380,8 +380,20 @@ LogicalResult
380380
ModuleFlagAttr::verify(function_ref<InFlightDiagnostic()> emitError,
381381
LLVM::ModFlagBehavior flagBehavior, StringAttr key,
382382
Attribute value) {
383-
if (!isa<IntegerAttr, StringAttr>(value))
384-
return emitError()
385-
<< "only integer and string values are currently supported";
386-
return success();
383+
if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
384+
auto arrayAttr = dyn_cast<ArrayAttr>(value);
385+
if ((!arrayAttr) || (!llvm::all_of(arrayAttr, [](Attribute attr) {
386+
return isa<ModuleFlagCGProfileEntryAttr>(attr);
387+
})))
388+
return emitError()
389+
<< "'CG Profile' key expects an array of '#llvm.cgprofile_entry'";
390+
return success();
391+
}
392+
393+
if (isa<IntegerAttr, StringAttr>(value))
394+
return success();
395+
396+
return emitError() << "only integer and string values are currently "
397+
"supported for unknown key '"
398+
<< key << "'";
387399
}

mlir/lib/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,31 @@ static void convertLinkerOptionsOp(ArrayAttr options,
271271
linkerMDNode->addOperand(listMDNode);
272272
}
273273

274+
static llvm::Metadata *
275+
convertModuleFlagValue(StringRef key, ArrayAttr arrayAttr,
276+
llvm::IRBuilderBase &builder,
277+
LLVM::ModuleTranslation &moduleTranslation) {
278+
llvm::LLVMContext &context = builder.getContext();
279+
llvm::MDBuilder mdb(context);
280+
SmallVector<llvm::Metadata *> nodes;
281+
282+
if (key == LLVMDialect::getModuleFlagKeyCGProfileName()) {
283+
for (auto entry : arrayAttr.getAsRange<ModuleFlagCGProfileEntryAttr>()) {
284+
llvm::Function *fromFn =
285+
moduleTranslation.lookupFunction(entry.getFrom().getValue());
286+
llvm::Function *toFn =
287+
moduleTranslation.lookupFunction(entry.getTo().getValue());
288+
llvm::Metadata *vals[] = {
289+
llvm::ValueAsMetadata::get(fromFn), llvm::ValueAsMetadata::get(toFn),
290+
mdb.createConstant(llvm::ConstantInt::get(
291+
llvm::Type::getInt64Ty(context), entry.getCount()))};
292+
nodes.push_back(llvm::MDNode::get(context, vals));
293+
}
294+
return llvm::MDTuple::getDistinct(context, nodes);
295+
}
296+
return nullptr;
297+
}
298+
274299
static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
275300
LLVM::ModuleTranslation &moduleTranslation) {
276301
llvm::Module *llvmModule = moduleTranslation.getLLVMModule();
@@ -286,6 +311,11 @@ static void convertModuleFlagsOp(ArrayAttr flags, llvm::IRBuilderBase &builder,
286311
llvm::Type::getInt32Ty(builder.getContext()),
287312
intAttr.getInt()));
288313
})
314+
.Case<ArrayAttr>([&](auto arrayAttr) {
315+
return convertModuleFlagValue(flagAttr.getKey().getValue(),
316+
arrayAttr, builder,
317+
moduleTranslation);
318+
})
289319
.Default([](auto) { return nullptr; });
290320

291321
assert(valueMetadata && "expected valid metadata");

mlir/lib/Target/LLVMIR/ModuleImport.cpp

+39-1
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,39 @@ void ModuleImport::addDebugIntrinsic(llvm::CallInst *intrinsic) {
519519
debugIntrinsics.insert(intrinsic);
520520
}
521521

522+
static Attribute convertCGProfileModuleFlagValue(ModuleOp mlirModule,
523+
llvm::MDTuple *mdTuple) {
524+
auto getFunctionSymbol = [&](const llvm::MDOperand &funcMDO) {
525+
auto *f = cast<llvm::ValueAsMetadata>(funcMDO);
526+
auto *llvmFn = cast<llvm::Function>(f->getValue()->stripPointerCasts());
527+
return FlatSymbolRefAttr::get(mlirModule->getContext(), llvmFn->getName());
528+
};
529+
530+
// Each tuple element becomes one ModuleFlagCGProfileEntryAttr.
531+
SmallVector<Attribute> cgProfile;
532+
for (unsigned i = 0; i < mdTuple->getNumOperands(); i++) {
533+
const llvm::MDOperand &mdo = mdTuple->getOperand(i);
534+
auto *cgEntry = cast<llvm::MDNode>(mdo);
535+
llvm::Constant *llvmConstant =
536+
cast<llvm::ConstantAsMetadata>(cgEntry->getOperand(2))->getValue();
537+
uint64_t count = cast<llvm::ConstantInt>(llvmConstant)->getZExtValue();
538+
cgProfile.push_back(ModuleFlagCGProfileEntryAttr::get(
539+
mlirModule->getContext(), getFunctionSymbol(cgEntry->getOperand(0)),
540+
getFunctionSymbol(cgEntry->getOperand(1)), count));
541+
}
542+
return ArrayAttr::get(mlirModule->getContext(), cgProfile);
543+
}
544+
545+
/// Invoke specific handlers for each known module flag value, returns nullptr
546+
/// if the key is unknown or unimplemented.
547+
static Attribute convertModuleFlagValueFromMDTuple(ModuleOp mlirModule,
548+
StringRef key,
549+
llvm::MDTuple *mdTuple) {
550+
if (key == LLVMDialect::getModuleFlagKeyCGProfileName())
551+
return convertCGProfileModuleFlagValue(mlirModule, mdTuple);
552+
return nullptr;
553+
}
554+
522555
LogicalResult ModuleImport::convertModuleFlagsMetadata() {
523556
SmallVector<llvm::Module::ModuleFlagEntry> llvmModuleFlags;
524557
llvmModule->getModuleFlagsMetadata(llvmModuleFlags);
@@ -530,7 +563,12 @@ LogicalResult ModuleImport::convertModuleFlagsMetadata() {
530563
valAttr = builder.getI32IntegerAttr(constInt->getZExtValue());
531564
} else if (auto *mdString = dyn_cast<llvm::MDString>(val)) {
532565
valAttr = builder.getStringAttr(mdString->getString());
533-
} else {
566+
} else if (auto *mdTuple = dyn_cast<llvm::MDTuple>(val)) {
567+
valAttr = convertModuleFlagValueFromMDTuple(mlirModule, key->getString(),
568+
mdTuple);
569+
}
570+
571+
if (!valAttr) {
534572
emitWarning(mlirModule.getLoc())
535573
<< "unsupported module flag value: " << diagMD(val, llvmModule.get());
536574
continue;

mlir/test/Dialect/LLVMIR/invalid.mlir

+16
Original file line numberDiff line numberDiff line change
@@ -1784,6 +1784,22 @@ module {
17841784

17851785
// -----
17861786

1787+
module {
1788+
// expected-error@below {{'CG Profile' key expects an array of '#llvm.cgprofile_entry'}}
1789+
llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
1790+
"yo"
1791+
]>]
1792+
}
1793+
1794+
// -----
1795+
1796+
module {
1797+
// expected-error@below {{'CG Profile' key expects an array of '#llvm.cgprofile_entry'}}
1798+
llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", 3 : i64>]
1799+
}
1800+
1801+
// -----
1802+
17871803
llvm.func @t0() -> !llvm.ptr {
17881804
%0 = llvm.blockaddress <function = @t0, tag = <id = 1>> : !llvm.ptr
17891805
llvm.blocktag <id = 1>

mlir/test/Dialect/LLVMIR/module-roundtrip.mlir

+12-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@ module {
66
#llvm.mlir.module_flag<max, "PIE Level", 2 : i32>,
77
#llvm.mlir.module_flag<max, "uwtable", 2 : i32>,
88
#llvm.mlir.module_flag<max, "frame-pointer", 1 : i32>,
9-
#llvm.mlir.module_flag<override, "probe-stack", "inline-asm">]
9+
#llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
10+
#llvm.mlir.module_flag<append, "CG Profile", [
11+
#llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
12+
#llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
13+
#llvm.cgprofile_entry<from = @to, to = @from, count = 222>
14+
]>]
1015
}
1116

1217
// CHECK: llvm.module_flags [
@@ -15,4 +20,9 @@ module {
1520
// CHECK-SAME: #llvm.mlir.module_flag<max, "PIE Level", 2 : i32>,
1621
// CHECK-SAME: #llvm.mlir.module_flag<max, "uwtable", 2 : i32>,
1722
// CHECK-SAME: #llvm.mlir.module_flag<max, "frame-pointer", 1 : i32>,
18-
// CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">]
23+
// CHECK-SAME: #llvm.mlir.module_flag<override, "probe-stack", "inline-asm">,
24+
// CHECK-SAME: #llvm.mlir.module_flag<append, "CG Profile", [
25+
// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
26+
// CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
27+
// CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
28+
// CHECK-SAME: ]>]

mlir/test/Target/LLVMIR/Import/module-flags.ll

+19
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,22 @@
2525
!12 = !{ i32 2, !"qux", i32 42 }
2626
!13 = !{ i32 3, !"qux", !{ !"foo", i32 1 }}
2727
!llvm.module.flags = !{ !10, !11, !12, !13 }
28+
29+
; // -----
30+
31+
declare void @from(i32)
32+
declare void @to()
33+
34+
!llvm.module.flags = !{!20}
35+
36+
!20 = !{i32 5, !"CG Profile", !21}
37+
!21 = distinct !{!22, !23, !24}
38+
!22 = !{ptr @from, ptr @to, i64 222}
39+
!23 = !{ptr @from, ptr @from, i64 222}
40+
!24 = !{ptr @to, ptr @from, i64 222}
41+
42+
; CHECK: llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
43+
; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
44+
; CHECK-SAME: #llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
45+
; CHECK-SAME: #llvm.cgprofile_entry<from = @to, to = @from, count = 222>
46+
; CHECK-SAME: ]>]

mlir/test/Target/LLVMIR/llvmir.mlir

+19
Original file line numberDiff line numberDiff line change
@@ -2838,6 +2838,25 @@ module {
28382838

28392839
// -----
28402840

2841+
llvm.module_flags [#llvm.mlir.module_flag<append, "CG Profile", [
2842+
#llvm.cgprofile_entry<from = @from, to = @to, count = 222>,
2843+
#llvm.cgprofile_entry<from = @from, to = @from, count = 222>,
2844+
#llvm.cgprofile_entry<from = @to, to = @from, count = 222>
2845+
]>]
2846+
llvm.func @from(i32)
2847+
llvm.func @to()
2848+
2849+
// CHECK: !llvm.module.flags = !{![[#CGPROF:]], ![[#DBG:]]}
2850+
2851+
// CHECK: ![[#CGPROF]] = !{i32 5, !"CG Profile", ![[#LIST:]]}
2852+
// CHECK: ![[#LIST]] = distinct !{![[#ENTRY_A:]], ![[#ENTRY_B:]], ![[#ENTRY_C:]]}
2853+
// CHECK: ![[#ENTRY_A]] = !{ptr @from, ptr @to, i64 222}
2854+
// CHECK: ![[#ENTRY_B]] = !{ptr @from, ptr @from, i64 222}
2855+
// CHECK: ![[#ENTRY_C]] = !{ptr @to, ptr @from, i64 222}
2856+
// CHECK: ![[#DBG]] = !{i32 2, !"Debug Info Version", i32 3}
2857+
2858+
// -----
2859+
28412860
module attributes {llvm.dependent_libraries = ["foo", "bar"]} {}
28422861

28432862
// CHECK: !llvm.dependent-libraries = !{![[#LIBFOO:]], ![[#LIBBAR:]]}

0 commit comments

Comments
 (0)