Skip to content

Commit 6382c2a

Browse files
authored
Forward mode: memset (rust-lang#490)
* Forward mode: memset * Fix activity check * Add tests
1 parent 0ed7544 commit 6382c2a

File tree

3 files changed

+128
-21
lines changed

3 files changed

+128
-21
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2521,29 +2521,29 @@ class AdjointGenerator
25212521
void visitMemSetInst(llvm::MemSetInst &MS) {
25222522
eraseIfUnused(MS);
25232523

2524-
if (gutils->isConstantInstruction(&MS))
2525-
return;
2526-
25272524
Value *orig_op0 = MS.getOperand(0);
25282525
Value *orig_op1 = MS.getOperand(1);
2529-
Value *op1 = gutils->getNewFromOriginal(orig_op1);
2530-
Value *op2 = gutils->getNewFromOriginal(MS.getOperand(2));
2531-
Value *op3 = gutils->getNewFromOriginal(MS.getOperand(3));
25322526

25332527
// TODO this should 1) assert that the value being meset is constant
25342528
// 2) duplicate the memset for the inverted pointer
25352529

2530+
if (gutils->isConstantInstruction(&MS) &&
2531+
Mode != DerivativeMode::ForwardMode) {
2532+
return;
2533+
}
2534+
2535+
// If constant destination then no operation needs doing
2536+
if (gutils->isConstantValue(orig_op0)) {
2537+
return;
2538+
}
2539+
25362540
if (!gutils->isConstantValue(orig_op1)) {
25372541
llvm::errs() << "couldn't handle non constant inst in memset to "
25382542
"propagate differential to\n"
25392543
<< MS;
25402544
report_fatal_error("non constant in memset");
25412545
}
25422546

2543-
// If constant destination then no operation needs doing
2544-
if (gutils->isConstantValue(orig_op0))
2545-
return;
2546-
25472547
bool backwardsShadow = false;
25482548
bool forwardsShadow = true;
25492549
for (auto pair : gutils->backwardsOnlyShadows) {
@@ -2560,23 +2560,31 @@ class AdjointGenerator
25602560
if ((Mode == DerivativeMode::ReverseModePrimal && forwardsShadow) ||
25612561
(Mode == DerivativeMode::ReverseModeGradient && backwardsShadow) ||
25622562
(Mode == DerivativeMode::ReverseModeCombined &&
2563-
(forwardsShadow && backwardsShadow))) {
2564-
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MS));
2565-
2566-
SmallVector<Value *, 4> args;
2567-
args.push_back(gutils->invertPointerM(orig_op0, BuilderZ));
2568-
args.push_back(gutils->lookupM(op1, BuilderZ));
2569-
args.push_back(gutils->lookupM(op2, BuilderZ));
2570-
args.push_back(gutils->lookupM(op3, BuilderZ));
2571-
2572-
Type *tys[] = {args[0]->getType(), args[2]->getType()};
2563+
(forwardsShadow && backwardsShadow)) ||
2564+
Mode == DerivativeMode::ForwardMode) {
2565+
IRBuilder<> BuilderZ(&MS);
2566+
getForwardBuilder(BuilderZ);
25732567

2568+
bool forwardMode = Mode == DerivativeMode::ForwardMode;
2569+
2570+
Value *op0 = gutils->invertPointerM(orig_op0, BuilderZ);
2571+
Value *op1 = gutils->getNewFromOriginal(MS.getOperand(1));
2572+
if (!forwardMode)
2573+
op1 = gutils->lookupM(op1, BuilderZ);
2574+
Value *op2 = gutils->getNewFromOriginal(MS.getOperand(2));
2575+
if (!forwardMode)
2576+
op2 = gutils->lookupM(op2, BuilderZ);
2577+
Value *op3 = gutils->getNewFromOriginal(MS.getOperand(3));
2578+
if (!forwardMode)
2579+
op3 = gutils->lookupM(op3, BuilderZ);
2580+
2581+
Type *tys[] = {op0->getType(), op2->getType()};
2582+
Value *args[] = {op0, op1, op2, op3};
25742583
auto Defs =
25752584
gutils->getInvertedBundles(&MS,
25762585
{ValueType::Shadow, ValueType::Primal,
25772586
ValueType::Primal, ValueType::Primal},
25782587
BuilderZ, /*lookup*/ false);
2579-
25802588
auto cal = BuilderZ.CreateCall(
25812589
Intrinsic::getDeclaration(MS.getParent()->getParent()->getParent(),
25822590
Intrinsic::memset, tys),
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s
2+
3+
declare void @__enzyme_fwddiff(i8*, double*, double*, double*, double*)
4+
5+
declare void @llvm.memset.p0i8.i64(i8*, i8, i64, i1)
6+
7+
define void @f(double* %x, double* %y) {
8+
entry:
9+
%x1 = load double, double* %x
10+
%yptr = bitcast double* %y to i8*
11+
call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
12+
%y1 = load double, double* %y
13+
%x2 = fmul double %x1, %y1
14+
store double %x2, double* %x
15+
store double %x2, double* %y
16+
call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
17+
ret void
18+
}
19+
20+
define void @df(double* %x, double* %xp, double* %y, double* %dy) {
21+
entry:
22+
tail call void @__enzyme_fwddiff(i8* bitcast (void (double*, double*)* @f to i8*), double* %x, double* %xp, double* %y, double* %dy)
23+
ret void
24+
}
25+
26+
27+
; CHECK: define internal void @fwddiffef(double* %x, double* %"x'", double* %y, double* %"y'")
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %x1 = load double, double* %x
30+
; CHECK-NEXT: %0 = load double, double* %"x'"
31+
; CHECK-NEXT: %"yptr'ipc" = bitcast double* %"y'" to i8*
32+
; CHECK-NEXT: %yptr = bitcast double* %y to i8*
33+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
34+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %"yptr'ipc", i8 0, i64 8, i1 false)
35+
; CHECK-NEXT: %y1 = load double, double* %y
36+
; CHECK-NEXT: %1 = load double, double* %"y'"
37+
; CHECK-NEXT: %x2 = fmul double %x1, %y1
38+
; CHECK-NEXT: %2 = fmul fast double %0, %y1
39+
; CHECK-NEXT: %3 = fmul fast double %1, %x1
40+
; CHECK-NEXT: %4 = fadd fast double %2, %3
41+
; CHECK-NEXT: store double %x2, double* %x
42+
; CHECK-NEXT: store double %4, double* %"x'"
43+
; CHECK-NEXT: store double %x2, double* %y
44+
; CHECK-NEXT: store double %4, double* %"y'"
45+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
46+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %"yptr'ipc", i8 0, i64 8, i1 false)
47+
; CHECK-NEXT: ret void
48+
; CHECK-NEXT: }
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s
2+
3+
declare void @__enzyme_autodiff(i8*, double*, double*, double*, double*)
4+
5+
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1)
6+
7+
define void @f(double* %x, double* %y) {
8+
%x1 = load double, double* %x, align 8
9+
%yptr = bitcast double* %y to i8*
10+
call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
11+
%y1 = load double, double* %y, align 8
12+
%x2 = fmul double %x1, %y1
13+
store double %x2, double* %x, align 8
14+
store double %x2, double* %y, align 8
15+
call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
16+
ret void
17+
}
18+
19+
define void @df(double* %x, double* %xp, double* %y, double* %dy) {
20+
tail call void @__enzyme_autodiff(i8* bitcast (void (double*, double*)* @f to i8*), double* %x, double* %xp, double* %y, double* %dy)
21+
ret void
22+
}
23+
24+
; CHECK: define internal void @diffef(double* %x, double* %"x'", double* %y, double* %"y'")
25+
; CHECK-NEXT: invert:
26+
; CHECK-NEXT: %x1 = load double, double* %x
27+
; CHECK-NEXT: %yptr = bitcast double* %y to i8*
28+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
29+
; CHECK-NEXT: %y1 = load double, double* %y
30+
; CHECK-NEXT: %x2 = fmul double %x1, %y1
31+
; CHECK-NEXT: store double %x2, double* %x
32+
; CHECK-NEXT: store double %x2, double* %y
33+
; CHECK-NEXT: call void @llvm.memset.p0i8.i64(i8* %yptr, i8 0, i64 8, i1 false)
34+
; CHECK-NEXT: %0 = load double, double* %"y'"
35+
; CHECK-NEXT: store double 0.000000e+00, double* %"y'"
36+
; CHECK-NEXT: %1 = fadd fast double 0.000000e+00, %0
37+
; CHECK-NEXT: %2 = load double, double* %"x'"
38+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
39+
; CHECK-NEXT: %3 = fadd fast double %1, %2
40+
; CHECK-NEXT: %m0diffex1 = fmul fast double %3, %y1
41+
; CHECK-NEXT: %m1diffey1 = fmul fast double %3, %x1
42+
; CHECK-NEXT: %4 = fadd fast double 0.000000e+00, %m0diffex1
43+
; CHECK-NEXT: %5 = fadd fast double 0.000000e+00, %m1diffey1
44+
; CHECK-NEXT: %6 = load double, double* %"y'"
45+
; CHECK-NEXT: %7 = fadd fast double %6, %5
46+
; CHECK-NEXT: store double %7, double* %"y'"
47+
; CHECK-NEXT: %8 = load double, double* %"x'"
48+
; CHECK-NEXT: %9 = fadd fast double %8, %4
49+
; CHECK-NEXT: store double %9, double* %"x'"
50+
; CHECK-NEXT: ret void
51+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)