Skip to content

Commit e5d4699

Browse files
authored
Implement FNeg for Forward Mode (rust-lang#232)
* implement fneg for forward mode * fix tests * fail early
1 parent 8cbd07e commit e5d4699

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -235,8 +235,6 @@ class AdjointGenerator
235235
void visitInstruction(llvm::Instruction &inst) {
236236
// TODO explicitly handle all instructions rather than using the catch all
237237
// below
238-
if (Mode == DerivativeMode::ReverseModePrimal)
239-
return;
240238

241239
#if LLVM_VERSION_MAJOR >= 10
242240
if (auto *FPMO = dyn_cast<FPMathOperator>(&inst)) {
@@ -248,16 +246,34 @@ class AdjointGenerator
248246
Value *orig_op1 = FPMO->getOperand(0);
249247
bool constantval1 = gutils->isConstantValue(orig_op1);
250248

251-
IRBuilder<> Builder2(inst.getParent());
252-
getReverseBuilder(Builder2);
249+
if (constantval1) {
250+
return;
251+
}
253252

254-
Value *idiff = diffe(FPMO, Builder2);
253+
switch (Mode) {
254+
case DerivativeMode::ReverseModeCombined:
255+
case DerivativeMode::ReverseModeGradient: {
256+
IRBuilder<> Builder2(inst.getParent());
257+
getReverseBuilder(Builder2);
255258

256-
if (!constantval1) {
259+
Value *idiff = diffe(FPMO, Builder2);
257260
Value *dif1 = Builder2.CreateFNeg(idiff);
258261
setDiffe(FPMO, Constant::getNullValue(FPMO->getType()), Builder2);
259262
addToDiffe(orig_op1, dif1, Builder2,
260263
dif1->getType()->getScalarType());
264+
break;
265+
}
266+
case DerivativeMode::ForwardMode: {
267+
IRBuilder<> Builder2(&inst);
268+
getForwardBuilder(Builder2);
269+
270+
Value *idiff = diffe(orig_op1, Builder2);
271+
Value *dif1 = Builder2.CreateFNeg(idiff);
272+
setDiffe(FPMO, dif1, Builder2);
273+
break;
274+
}
275+
case DerivativeMode::ReverseModePrimal:
276+
return;
261277
}
262278
return;
263279
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
; RUN: if [ %llvmver -ge 10 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
2+
3+
; extern double __enzyme_fwddiff(void*, double, double);
4+
;
5+
; double fneg(double x) {
6+
; return -x;
7+
; }
8+
;
9+
; double dfneg(double x) {
10+
; return __enzyme_fwddiff((void*)fneg, x, 1.0);
11+
; }
12+
13+
14+
define double @fneg(double %x) {
15+
%fneg = fneg double %x
16+
ret double %fneg
17+
}
18+
19+
define double @dfneg(double %x) {
20+
%1 = call double @__enzyme_fwddiff(double (double)* @fneg, double %x, double 1.0)
21+
ret double %1
22+
}
23+
24+
declare double @__enzyme_fwddiff(double (double)*, double, double)
25+
26+
27+
; CHECK: define internal { double } @diffefneg(double %x, double %"x'") {
28+
; CHECK-NEXT: %1 = fneg fast double %"x'"
29+
; CHECK-NEXT: %2 = insertvalue { double } undef, double %1, 0
30+
; CHECK-NEXT: ret { double } %2
31+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)