Skip to content

Commit 9e48a51

Browse files
wsmosesvchuravy
andauthored
Improve MPI rank/size simplification (rust-lang#484)
* make MPI wrapper actually useful * Add test * Improve julia mpi aliasing Co-authored-by: Valentin Churavy <[email protected]>
1 parent f47e913 commit 9e48a51

File tree

3 files changed

+181
-15
lines changed

3 files changed

+181
-15
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,8 @@ static void handleKnownFunctions(llvm::Function &F) {
230230
}
231231
F.addParamAttr(6, Attribute::WriteOnly);
232232
}
233-
if (F.getName() == "MPI_Comm_rank") {
233+
if (F.getName() == "MPI_Comm_rank" || F.getName() == "PMPI_Comm_rank" ||
234+
F.getName() == "MPI_Comm_size" || F.getName() == "PMPI_Comm_size") {
234235
F.addFnAttr(Attribute::InaccessibleMemOrArgMemOnly);
235236
F.addFnAttr(Attribute::NoUnwind);
236237
F.addFnAttr(Attribute::NoRecurse);
@@ -243,8 +244,10 @@ static void handleKnownFunctions(llvm::Function &F) {
243244
F.addParamAttr(0, Attribute::NoCapture);
244245
F.addParamAttr(0, Attribute::ReadOnly);
245246
}
246-
F.addParamAttr(1, Attribute::WriteOnly);
247-
F.addParamAttr(1, Attribute::NoCapture);
247+
if (F.getFunctionType()->getParamType(1)->isPointerTy()) {
248+
F.addParamAttr(1, Attribute::WriteOnly);
249+
F.addParamAttr(1, Attribute::NoCapture);
250+
}
248251
}
249252
if (F.getName() == "MPI_Wait") {
250253
F.addFnAttr(Attribute::NoUnwind);

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,12 @@ Function *CreateMPIWrapper(Function *F) {
657657
Function *W = Function::Create(FT, GlobalVariable::InternalLinkage, name,
658658
F->getParent());
659659
llvm::Attribute::AttrKind attrs[] = {
660+
#if LLVM_VERSION_MAJOR >= 9
661+
Attribute::WillReturn,
662+
#endif
663+
#if LLVM_VERSION_MAJOR >= 12
664+
Attribute::MustProgress,
665+
#endif
660666
Attribute::ReadOnly,
661667
Attribute::Speculatable,
662668
Attribute::NoUnwind,
@@ -684,6 +690,12 @@ Function *CreateMPIWrapper(Function *F) {
684690
IRBuilder<> B(entry);
685691
auto alloc = B.CreateAlloca(F->getReturnType());
686692
Value *args[] = {W->arg_begin(), alloc};
693+
694+
auto T = F->getFunctionType()->getParamType(1);
695+
if (!isa<PointerType>(T)) {
696+
assert(isa<IntegerType>(T));
697+
args[1] = B.CreatePtrToInt(args[1], T);
698+
}
687699
B.CreateCall(F, args);
688700
#if LLVM_VERSION_MAJOR > 7
689701
B.CreateRet(B.CreateLoad(F->getReturnType(), alloc));
@@ -692,7 +704,8 @@ Function *CreateMPIWrapper(Function *F) {
692704
#endif
693705
return W;
694706
}
695-
static void SimplifyMPIQueries(Function &NewF) {
707+
static void SimplifyMPIQueries(Function &NewF, FunctionAnalysisManager &FAM) {
708+
DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(NewF);
696709
SmallVector<CallInst *, 4> Todo;
697710
SmallVector<CallInst *, 0> OMPBounds;
698711
for (auto &BB : NewF) {
@@ -703,10 +716,8 @@ static void SimplifyMPIQueries(Function &NewF) {
703716
continue;
704717
if (Fn->getName() == "MPI_Comm_rank" ||
705718
Fn->getName() == "PMPI_Comm_rank" ||
706-
Fn->getName() == "MPI_Comm_size") {
707-
if (!CI->use_empty()) {
708-
continue;
709-
}
719+
Fn->getName() == "MPI_Comm_size" ||
720+
Fn->getName() == "PMPI_Comm_size") {
710721
Todo.push_back(CI);
711722
}
712723
if (Fn->getName() == "__kmpc_for_static_init_4" ||
@@ -721,9 +732,56 @@ static void SimplifyMPIQueries(Function &NewF) {
721732
for (auto CI : Todo) {
722733
IRBuilder<> B(CI);
723734
Value *arg[] = {CI->getArgOperand(0)};
724-
auto res = B.CreateCall(CreateMPIWrapper(CI->getCalledFunction()), arg);
725-
B.CreateStore(res, CI->getArgOperand(1));
735+
SmallVector<OperandBundleDef, 2> Defs;
736+
CI->getOperandBundlesAsDefs(Defs);
737+
auto res =
738+
B.CreateCall(CreateMPIWrapper(CI->getCalledFunction()), arg, Defs);
739+
Value *storePointer = CI->getArgOperand(1);
740+
741+
// Comm_rank and Comm_size return Err, assume 0 is success
742+
CI->replaceAllUsesWith(ConstantInt::get(CI->getType(), 0));
726743
CI->eraseFromParent();
744+
745+
while (auto Cast = dyn_cast<CastInst>(storePointer)) {
746+
storePointer = Cast->getOperand(0);
747+
if (Cast->use_empty())
748+
Cast->eraseFromParent();
749+
}
750+
751+
B.SetInsertPoint(res);
752+
753+
if (auto PT = dyn_cast<PointerType>(storePointer->getType())) {
754+
if (PT->getElementType() != res->getType())
755+
storePointer = B.CreateBitCast(
756+
storePointer,
757+
PointerType::get(res->getType(), PT->getAddressSpace()));
758+
} else {
759+
assert(isa<IntegerType>(storePointer->getType()));
760+
storePointer = B.CreateIntToPtr(storePointer,
761+
PointerType::getUnqual(res->getType()));
762+
}
763+
if (isa<AllocaInst>(storePointer)) {
764+
// If this is only loaded from, immedaitely replace
765+
// Immediately replace all dominated stores.
766+
SmallVector<LoadInst *, 2> LI;
767+
bool nonload = false;
768+
for (auto &U : storePointer->uses()) {
769+
if (auto L = dyn_cast<LoadInst>(U.getUser())) {
770+
LI.push_back(L);
771+
} else
772+
nonload = true;
773+
}
774+
if (!nonload) {
775+
for (auto L : LI) {
776+
if (DT.dominates(res, L)) {
777+
L->replaceAllUsesWith(res);
778+
L->eraseFromParent();
779+
}
780+
}
781+
}
782+
}
783+
B.SetInsertPoint(res->getNextNode());
784+
B.CreateStore(res, storePointer);
727785
}
728786
for (auto Bound : OMPBounds) {
729787
for (int i = 4; i <= 6; i++) {
@@ -747,6 +805,13 @@ static void SimplifyMPIQueries(Function &NewF) {
747805
Bound->addParamAttr(i, Attribute::NoCapture);
748806
}
749807
}
808+
PreservedAnalyses PA;
809+
PA.preserve<AssumptionAnalysis>();
810+
PA.preserve<TargetLibraryAnalysis>();
811+
PA.preserve<LoopAnalysis>();
812+
PA.preserve<DominatorTreeAnalysis>();
813+
PA.preserve<PostDominatorTreeAnalysis>();
814+
FAM.invalidate(NewF, PA);
750815
}
751816

752817
/// Perform recursive inlinining on NewF up to the given limit
@@ -1104,11 +1169,7 @@ Function *PreProcessCache::preprocessForClone(Function *F,
11041169
ConstantFoldTerminator(BE);
11051170
}
11061171

1107-
{
1108-
SimplifyMPIQueries(*NewF);
1109-
PreservedAnalyses PA;
1110-
FAM.invalidate(*NewF, PA);
1111-
}
1172+
SimplifyMPIQueries(*NewF, FAM);
11121173

11131174
if (EnzymeLowerGlobals) {
11141175
std::vector<CallInst *> Calls;
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -correlated-propagation -adce -S | FileCheck %s
2+
source_filename = "text"
3+
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
4+
target triple = "x86_64-pc-linux-gnu"
5+
6+
; Function Attrs: cold noreturn nounwind
7+
declare void @llvm.trap() #0
8+
9+
declare dso_local i32 @MPI_Comm_rank(i64, i64)
10+
11+
define double @sum(double* %arg, i64 %comm) {
12+
bb:
13+
%alloc = alloca i32, align 8
14+
%i5 = ptrtoint i32* %alloc to i64
15+
br label %bb11
16+
17+
bb11: ; preds = %bb
18+
%idx = phi i64 [ 0, %bb ], [ %inc, %bb22 ]
19+
%sum = phi double [ 0.000000e+00, %bb ], [ %add, %bb22 ]
20+
%inc = add i64 %idx, 1
21+
%i13 = getelementptr inbounds double, double* %arg, i64 %idx
22+
%i14 = load double, double* %i13, align 8
23+
%i16 = fmul double %i14, %i14
24+
%i19 = call i32 @MPI_Comm_rank(i64 %comm, i64 %i5)
25+
%ld = load i32, i32* %alloc
26+
%cf = uitofp i32 %ld to double
27+
%mm = fmul double %i16, %cf
28+
%add = fadd double %sum, %mm
29+
%i20 = icmp eq i32 %i19, 0
30+
br i1 %i20, label %bb22, label %bb21
31+
32+
bb21: ; preds = %bb11, %bb
33+
call void @llvm.trap() #1
34+
unreachable
35+
36+
bb22:
37+
%cmp = icmp eq i64 %idx, 9
38+
br i1 %cmp, label %exit, label %bb11
39+
40+
exit:
41+
ret double %add
42+
}
43+
44+
define void @dsum(double* %x, double* %xp, i64 %n) {
45+
entry:
46+
%0 = tail call double (double (double*, i64)*, ...) @__enzyme_autodiff(double (double*, i64)* nonnull @sum, double* %x, double* %xp, i64 %n)
47+
ret void
48+
}
49+
50+
declare double @__enzyme_autodiff(double (double*, i64)*, ...)
51+
52+
attributes #0 = { cold noreturn nounwind }
53+
attributes #1 = { noreturn }
54+
55+
; CHECK: define internal void @diffesum(double* %arg, double* %"arg'", i64 %comm, double %differeturn)
56+
; CHECK-NEXT: bb:
57+
; CHECK-NEXT: %0 = alloca i32
58+
; CHECK-NEXT: %1 = alloca i32
59+
; CHECK-NEXT: br label %bb11
60+
61+
; CHECK: bb11: ; preds = %bb11, %bb
62+
; CHECK-NEXT: %iv = phi i64 [ %iv.next, %bb11 ], [ 0, %bb ]
63+
; CHECK-NEXT: %iv.next = add nuw nsw i64 %iv, 1
64+
; CHECK-NEXT: %2 = bitcast i32* %1 to i8*
65+
; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %2)
66+
; CHECK-NEXT: %3 = ptrtoint i32* %1 to i64
67+
; CHECK-NEXT: %4 = call i32 @MPI_Comm_rank(i64 %comm, i64 %3)
68+
; CHECK-NEXT: %5 = bitcast i32* %1 to i8*
69+
; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %5)
70+
; CHECK-NEXT: %cmp = icmp eq i64 %iv, 9
71+
; CHECK-NEXT: br i1 %cmp, label %invertbb22, label %bb11
72+
73+
; CHECK: invertbb: ; preds = %invertbb22
74+
; CHECK-NEXT: ret void
75+
76+
; CHECK: incinvertbb11: ; preds = %invertbb22
77+
; CHECK-NEXT: %6 = add nsw i64 %"iv'ac.0", -1
78+
; CHECK-NEXT: br label %invertbb22
79+
80+
; CHECK: invertbb22: ; preds = %bb11, %incinvertbb11
81+
; CHECK-NEXT: %"iv'ac.0" = phi i64 [ %6, %incinvertbb11 ], [ 9, %bb11 ]
82+
; CHECK-NEXT: %7 = bitcast i32* %0 to i8*
83+
; CHECK-NEXT: call void @llvm.lifetime.start.p0i8(i64 4, i8* nonnull %7)
84+
; CHECK-NEXT: %8 = ptrtoint i32* %0 to i64
85+
; CHECK-NEXT: %9 = call i32 @MPI_Comm_rank(i64 %comm, i64 %8)
86+
; CHECK-NEXT: %10 = load i32, i32* %0
87+
; CHECK-NEXT: %11 = bitcast i32* %0 to i8*
88+
; CHECK-NEXT: call void @llvm.lifetime.end.p0i8(i64 4, i8* nonnull %11)
89+
; CHECK-NEXT: %cf_unwrap = uitofp i32 %10 to double
90+
; CHECK-NEXT: %m0diffei16 = fmul fast double %differeturn, %cf_unwrap
91+
; CHECK-NEXT: %i13_unwrap = getelementptr inbounds double, double* %arg, i64 %"iv'ac.0"
92+
; CHECK-NEXT: %i14_unwrap = load double, double* %i13_unwrap, align 8, !invariant.group !0
93+
; CHECK-NEXT: %m0diffei14 = fmul fast double %m0diffei16, %i14_unwrap
94+
; CHECK-NEXT: %m1diffei14 = fmul fast double %m0diffei16, %i14_unwrap
95+
; CHECK-NEXT: %12 = fadd fast double %m0diffei14, %m1diffei14
96+
; CHECK-NEXT: %"i13'ipg_unwrap" = getelementptr inbounds double, double* %"arg'", i64 %"iv'ac.0"
97+
; CHECK-NEXT: %13 = load double, double* %"i13'ipg_unwrap", align 8
98+
; CHECK-NEXT: %14 = fadd fast double %13, %12
99+
; CHECK-NEXT: store double %14, double* %"i13'ipg_unwrap", align 8
100+
; CHECK-NEXT: %15 = icmp eq i64 %"iv'ac.0", 0
101+
; CHECK-NEXT: br i1 %15, label %invertbb, label %incinvertbb11
102+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)