Skip to content

Commit e26b63a

Browse files
authored
Preserve stack alignment (rust-lang#487)
1 parent e40a859 commit e26b63a

File tree

4 files changed

+58
-5
lines changed

4 files changed

+58
-5
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8167,14 +8167,26 @@ class AdjointGenerator
81678167
if (Mode == DerivativeMode::ReverseModeGradient) {
81688168
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
81698169
} else {
8170-
if (hasMetadata(orig, "enzyme_fromstack")) {
8170+
if (auto MD = hasMetadata(orig, "enzyme_fromstack")) {
81718171
IRBuilder<> B(newCall);
81728172
if (auto CI = dyn_cast<ConstantInt>(orig->getArgOperand(0))) {
81738173
B.SetInsertPoint(gutils->inversionAllocs);
81748174
}
81758175
auto replacement = B.CreateAlloca(
81768176
Type::getInt8Ty(orig->getContext()),
81778177
gutils->getNewFromOriginal(orig->getArgOperand(0)));
8178+
auto Alignment =
8179+
cast<ConstantInt>(
8180+
cast<ConstantAsMetadata>(MD->getOperand(0))->getValue())
8181+
->getLimitedValue();
8182+
// Don't set zero alignment
8183+
if (Alignment) {
8184+
#if LLVM_VERSION_MAJOR >= 10
8185+
replacement->setAlignment(Align(Alignment));
8186+
#else
8187+
replacement->setAlignment(Alignment);
8188+
#endif
8189+
}
81788190
gutils->replaceAWithB(newCall, replacement);
81798191
gutils->erase(newCall);
81808192
}

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,11 @@ static inline void UpgradeAllocasToMallocs(Function *NewF,
312312
CallInst *CI = dyn_cast<CallInst>(rep);
313313
if (auto C = dyn_cast<CastInst>(rep))
314314
CI = cast<CallInst>(C->getOperand(0));
315-
CI->setMetadata("enzyme_fromstack", MDNode::get(CI->getContext(), {}));
315+
CI->setMetadata("enzyme_fromstack",
316+
MDNode::get(CI->getContext(),
317+
{ConstantAsMetadata::get(ConstantInt::get(
318+
IntegerType::get(AI->getContext(), 64),
319+
AI->getAlignment()))}));
316320
#if LLVM_VERSION_MAJOR >= 14
317321
CI->addRetAttr(Attribute::NoAlias);
318322
#else

enzyme/Enzyme/Utils.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,9 @@ static inline bool hasMetadata(const llvm::GlobalObject *O,
203203
}
204204

205205
/// Check if an instruction has metadata
206-
static inline bool hasMetadata(const llvm::Instruction *O,
207-
llvm::StringRef kind) {
208-
return O->getMetadata(kind) != nullptr;
206+
static inline llvm::MDNode *hasMetadata(const llvm::Instruction *O,
207+
llvm::StringRef kind) {
208+
return O->getMetadata(kind);
209209
}
210210

211211
/// Potential return type of generated functions
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -S | FileCheck %s
2+
3+
define float @f(float %this) {
4+
entry:
5+
%call = tail call float @sub(float %this)
6+
%res = fmul float %call, %call
7+
ret float %res
8+
}
9+
10+
declare void @julia.write_barrier(float* readnone nocapture)
11+
12+
define float @sub(float %this) {
13+
entry:
14+
%alloc = alloca float, align 256
15+
store float %this, float* %alloc, align 8
16+
call void @julia.write_barrier(float* %alloc)
17+
ret float %this
18+
}
19+
20+
define float @g(float %t) {
21+
entry:
22+
%0 = tail call float (float (float)*, ...) @__enzyme_autodiff(float (float)* @f, float %t)
23+
ret float %0
24+
}
25+
26+
declare float @__enzyme_autodiff(float (float)*, ...)
27+
28+
; ensure both alignment is maintained and that the alloca
29+
; is not preserved for the reverse
30+
; CHECK: define internal float @augmented_sub(float %this)
31+
; CHECK-NEXT: entry:
32+
; CHECK-NEXT: %0 = alloca i8, i64 4, align 256
33+
; CHECK-NEXT: %alloc = bitcast i8* %0 to float*
34+
; CHECK-NEXT: store float %this, float* %alloc, align 8
35+
; CHECK-NEXT: call void @julia.write_barrier(float* %alloc)
36+
; CHECK-NEXT: ret float %this
37+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)