Skip to content

Commit 0447c03

Browse files
committed
Handle copysign
1 parent 95dae5d commit 0447c03

File tree

4 files changed

+80
-0
lines changed

4 files changed

+80
-0
lines changed

Diff for: enzyme/Enzyme/ActivityAnalysis.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,12 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
168168
if (F->getIntrinsicID() == Intrinsic::trap)
169169
return true;
170170

171+
/// Only the first argument (magnitude) of copysign is active; a use of the value as the second (sign) argument is an inactive use
172+
if (F->getIntrinsicID() == Intrinsic::copysign &&
173+
CI->getArgOperand(0) != val) {
174+
return true;
175+
}
176+
171177
/// Use of the value as a non-src/dst in memset/memcpy/memmove is an inactive
172178
/// use
173179
if (F->getIntrinsicID() == Intrinsic::memset && CI->getArgOperand(0) != val &&

Diff for: enzyme/Enzyme/AdjointGenerator.h

+43
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,7 @@ class AdjointGenerator
13521352
case Intrinsic::log10:
13531353
case Intrinsic::exp:
13541354
case Intrinsic::exp2:
1355+
case Intrinsic::copysign:
13551356
case Intrinsic::pow:
13561357
case Intrinsic::powi:
13571358
#if LLVM_VERSION_MAJOR >= 9
@@ -1631,6 +1632,48 @@ class AdjointGenerator
16311632
}
16321633
return;
16331634
}
1635+
case Intrinsic::copysign: {
1636+
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
1637+
1638+
Value *xsign = nullptr;
1639+
{
1640+
Type *tys[] = {orig_ops[0]->getType()};
1641+
SmallVector<Value *, 2> args = {
1642+
ConstantFP::get(tys[0], 1.0),
1643+
lookup(gutils->getNewFromOriginal(orig_ops[0]), Builder2)};
1644+
1645+
auto cal = cast<CallInst>(Builder2.CreateCall(
1646+
Intrinsic::getDeclaration(M, Intrinsic::copysign, tys), args));
1647+
cal->copyIRFlags(&II);
1648+
cal->setAttributes(II.getAttributes());
1649+
cal->setCallingConv(II.getCallingConv());
1650+
cal->setTailCallKind(II.getTailCallKind());
1651+
cal->setDebugLoc(gutils->getNewFromOriginal(II.getDebugLoc()));
1652+
xsign = cal;
1653+
}
1654+
1655+
Value *ysign = nullptr;
1656+
{
1657+
Type *tys[] = {orig_ops[1]->getType()};
1658+
SmallVector<Value *, 2> args = {
1659+
ConstantFP::get(tys[0], 1.0),
1660+
lookup(gutils->getNewFromOriginal(orig_ops[1]), Builder2)};
1661+
1662+
auto cal = cast<CallInst>(Builder2.CreateCall(
1663+
Intrinsic::getDeclaration(M, Intrinsic::copysign, tys), args));
1664+
cal->copyIRFlags(&II);
1665+
cal->setAttributes(II.getAttributes());
1666+
cal->setCallingConv(II.getCallingConv());
1667+
cal->setTailCallKind(II.getTailCallKind());
1668+
cal->setDebugLoc(gutils->getNewFromOriginal(II.getDebugLoc()));
1669+
ysign = cal;
1670+
}
1671+
Value *dif0 =
1672+
Builder2.CreateFMul(Builder2.CreateFMul(xsign, ysign), vdiff);
1673+
addToDiffe(orig_ops[0], dif0, Builder2, II.getType());
1674+
}
1675+
return;
1676+
}
16341677
case Intrinsic::powi: {
16351678
if (vdiff && !gutils->isConstantValue(orig_ops[0])) {
16361679

Diff for: enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,7 @@ void TypeAnalyzer::visitIntrinsicInst(llvm::IntrinsicInst &I) {
17041704
#if LLVM_VERSION_MAJOR >= 9
17051705
case Intrinsic::experimental_vector_reduce_v2_fadd:
17061706
#endif
1707+
case Intrinsic::copysign:
17071708
case Intrinsic::maxnum:
17081709
case Intrinsic::minnum:
17091710
case Intrinsic::pow:

Diff for: enzyme/test/Enzyme/copysign.ll

+30
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -simplifycfg -S | FileCheck %s
2+
3+
; tester(x, y) = copysign(x, y): the magnitude of %x combined with the sign of %y
4+
define double @tester(double %x, double %y) {
5+
entry:
6+
%0 = tail call fast double @llvm.copysign.f64(double %x, double %y)
7+
ret double %0
8+
}
9+
10+
define double @test_derivative(double %x, double %y) {
11+
entry:
12+
%0 = tail call double (double (double, double)*, ...) @__enzyme_autodiff(double (double, double)* nonnull @tester, double %x, double %y)
13+
ret double %0
14+
}
15+
16+
declare double @llvm.copysign.f64(double, double)
17+
18+
; Function Attrs: nounwind
19+
declare double @__enzyme_autodiff(double (double, double)*, ...)
20+
21+
; CHECK: define internal {{(dso_local )?}}{ double, double } @diffetester(double %x, double %y, double %[[differet:.+]])
22+
; CHECK-NEXT: entry:
23+
; CHECK-NEXT: %0 = tail call fast double @llvm.copysign.f64(double 1.000000e+00, double %x)
24+
; CHECK-NEXT: %1 = tail call fast double @llvm.copysign.f64(double 1.000000e+00, double %y)
25+
; CHECK-NEXT: %2 = fmul fast double %0, %1
26+
; CHECK-NEXT: %3 = fmul fast double %2, %[[differet]]
27+
; CHECK-NEXT: %4 = insertvalue { double, double } undef, double %3, 0
28+
; CHECK-NEXT: %5 = insertvalue { double, double } %4, double 0.000000e+00, 1
29+
; CHECK-NEXT: ret { double, double } %5
30+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)