Skip to content

Commit b9f3a8f

Browse files
authored
Implement enzyme_dupv (rust-lang#618)
* Implement enzyme_dupv
* Add test
1 parent cf73e23 commit b9f3a8f

File tree

2 files changed

+132
-1
lines changed

2 files changed

+132
-1
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,7 @@ class Enzyme : public ModulePass {
458458
IRBuilder<> Builder(CI);
459459
unsigned truei = 0;
460460
unsigned width = 1;
461+
std::map<unsigned, Value *> batchOffset;
461462
bool returnUsed = !cast<Function>(fn)->getReturnType()->isVoidTy() &&
462463
!cast<Function>(fn)->getReturnType()->isEmptyTy();
463464

@@ -656,6 +657,20 @@ class Enzyme : public ModulePass {
656657
if (metaString && metaString.getValue().startswith("enzyme_")) {
657658
if (*metaString == "enzyme_dup") {
658659
ty = DIFFE_TYPE::DUP_ARG;
660+
} else if (*metaString == "enzyme_dupv") {
661+
ty = DIFFE_TYPE::DUP_ARG;
662+
++i;
663+
Value *offset_arg = CI->getArgOperand(i);
664+
if (auto cint = dyn_cast<IntegerType>(offset_arg->getType())) {
665+
batchOffset[i + 1] = offset_arg;
666+
} else {
667+
EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
668+
"enzyme_dupv must be followed by an integer "
669+
"offset.",
670+
*CI->getArgOperand(i), " in", *CI);
671+
return false;
672+
}
673+
continue;
659674
} else if (*metaString == "enzyme_dupnoneed") {
660675
ty = DIFFE_TYPE::DUP_NONEED;
661676
} else if (*metaString == "enzyme_out") {
@@ -753,6 +768,7 @@ class Enzyme : public ModulePass {
753768
++i;
754769

755770
Value *res = nullptr;
771+
bool batch = batchOffset.count(i - 1) != 0;
756772

757773
for (unsigned v = 0; v < width; ++v) {
758774
#if LLVM_VERSION_MAJOR >= 14
@@ -771,6 +787,21 @@ class Enzyme : public ModulePass {
771787

772788
// cast diffe
773789
Value *element = CI->getArgOperand(i);
790+
if (batch) {
791+
if (auto elementPtrTy = dyn_cast<PointerType>(element->getType())) {
792+
element = Builder.CreateBitCast(
793+
element, PointerType::get(Type::getInt8Ty(CI->getContext()),
794+
elementPtrTy->getAddressSpace()));
795+
element = Builder.CreateGEP(
796+
element,
797+
Builder.CreateMul(
798+
batchOffset[i - 1],
799+
ConstantInt::get(batchOffset[i - 1]->getType(), v)));
800+
element = Builder.CreateBitCast(element, elementPtrTy);
801+
} else {
802+
return false;
803+
}
804+
}
774805
if (PTy != element->getType()) {
775806
element = castToDiffeFunctionArgType(Builder, CI, FT, PTy, i, mode,
776807
element, truei);
@@ -786,7 +817,7 @@ class Enzyme : public ModulePass {
786817
element->getType(), width)),
787818
element, {v});
788819

789-
if (v < width - 1) {
820+
if (v < width - 1 && !batch) {
790821
++i;
791822
}
792823

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s
2+
3+
4+
@enzyme_width = external global i32, align 4
5+
@enzyme_dupv = external global i32, align 4
6+
7+
define void @square(double* nocapture readonly %x, double* nocapture %out) {
8+
entry:
9+
%0 = load double, double* %x, align 8
10+
%mul = fmul double %0, %0
11+
store double %mul, double* %out, align 8
12+
ret void
13+
}
14+
15+
define void @dsquare(double* %x, double* %dx, double* %out, double* %dout) {
16+
entry:
17+
%0 = load i32, i32* @enzyme_width, align 4
18+
%1 = load i32, i32* @enzyme_dupv, align 4
19+
call void (i8*, ...) @__enzyme_fwddiff(i8* bitcast (void (double*, double*)* @square to i8*), i32 %0, i32 3, i32 %1, i64 16, double* %x, double* %dx, i32 %1, i64 16, double* %out, double* %dout)
20+
ret void
21+
}
22+
23+
declare void @__enzyme_fwddiff(i8*, ...)
24+
25+
26+
; CHECK: define void @dsquare(double* %x, double* %dx, double* %out, double* %dout)
27+
; CHECK-NEXT: entry:
28+
; CHECK-NEXT: %0 = load i32, i32* @enzyme_width, align 4
29+
; CHECK-NEXT: %1 = load i32, i32* @enzyme_dupv, align 4
30+
; CHECK-NEXT: %2 = bitcast double* %dx to i8*
31+
; CHECK-NEXT: %3 = getelementptr i8, i8* %2, i64 0
32+
; CHECK-NEXT: %4 = bitcast i8* %3 to double*
33+
; CHECK-NEXT: %5 = insertvalue [3 x double*] undef, double* %4, 0
34+
; CHECK-NEXT: %6 = bitcast double* %dx to i8*
35+
; CHECK-NEXT: %7 = getelementptr i8, i8* %6, i64 16
36+
; CHECK-NEXT: %8 = bitcast i8* %7 to double*
37+
; CHECK-NEXT: %9 = insertvalue [3 x double*] %5, double* %8, 1
38+
; CHECK-NEXT: %10 = bitcast double* %dx to i8*
39+
; CHECK-NEXT: %11 = getelementptr i8, i8* %10, i64 32
40+
; CHECK-NEXT: %12 = bitcast i8* %11 to double*
41+
; CHECK-NEXT: %13 = insertvalue [3 x double*] %9, double* %12, 2
42+
; CHECK-NEXT: %14 = bitcast double* %dout to i8*
43+
; CHECK-NEXT: %15 = getelementptr i8, i8* %14, i64 0
44+
; CHECK-NEXT: %16 = bitcast i8* %15 to double*
45+
; CHECK-NEXT: %17 = insertvalue [3 x double*] undef, double* %16, 0
46+
; CHECK-NEXT: %18 = bitcast double* %dout to i8*
47+
; CHECK-NEXT: %19 = getelementptr i8, i8* %18, i64 16
48+
; CHECK-NEXT: %20 = bitcast i8* %19 to double*
49+
; CHECK-NEXT: %21 = insertvalue [3 x double*] %17, double* %20, 1
50+
; CHECK-NEXT: %22 = bitcast double* %dout to i8*
51+
; CHECK-NEXT: %23 = getelementptr i8, i8* %22, i64 32
52+
; CHECK-NEXT: %24 = bitcast i8* %23 to double*
53+
; CHECK-NEXT: %25 = insertvalue [3 x double*] %21, double* %24, 2
54+
; CHECK-NEXT: call void @fwddiffe3square(double* %x, [3 x double*] %13, double* %out, [3 x double*] %25)
55+
; CHECK-NEXT: ret void
56+
; CHECK-NEXT: }
57+
58+
; CHECK: define internal void @fwddiffe3square(double* nocapture readonly %x, [3 x double*] %"x'", double* nocapture %out, [3 x double*] %"out'")
59+
; CHECK-NEXT: entry:
60+
; CHECK-NEXT: %0 = load double, double* %x, align 8
61+
; CHECK-NEXT: %1 = extractvalue [3 x double*] %"x'", 0
62+
; CHECK-NEXT: %2 = load double, double* %1, align 8
63+
; CHECK-NEXT: %3 = insertvalue [3 x double] undef, double %2, 0
64+
; CHECK-NEXT: %4 = extractvalue [3 x double*] %"x'", 1
65+
; CHECK-NEXT: %5 = load double, double* %4, align 8
66+
; CHECK-NEXT: %6 = insertvalue [3 x double] %3, double %5, 1
67+
; CHECK-NEXT: %7 = extractvalue [3 x double*] %"x'", 2
68+
; CHECK-NEXT: %8 = load double, double* %7, align 8
69+
; CHECK-NEXT: %9 = insertvalue [3 x double] %6, double %8, 2
70+
; CHECK-NEXT: %mul = fmul double %0, %0
71+
; CHECK-NEXT: %10 = extractvalue [3 x double] %9, 0
72+
; CHECK-NEXT: %11 = extractvalue [3 x double] %9, 0
73+
; CHECK-NEXT: %12 = fmul fast double %10, %0
74+
; CHECK-NEXT: %13 = fmul fast double %11, %0
75+
; CHECK-NEXT: %14 = fadd fast double %12, %13
76+
; CHECK-NEXT: %15 = insertvalue [3 x double] undef, double %14, 0
77+
; CHECK-NEXT: %16 = extractvalue [3 x double] %9, 1
78+
; CHECK-NEXT: %17 = extractvalue [3 x double] %9, 1
79+
; CHECK-NEXT: %18 = fmul fast double %16, %0
80+
; CHECK-NEXT: %19 = fmul fast double %17, %0
81+
; CHECK-NEXT: %20 = fadd fast double %18, %19
82+
; CHECK-NEXT: %21 = insertvalue [3 x double] %15, double %20, 1
83+
; CHECK-NEXT: %22 = extractvalue [3 x double] %9, 2
84+
; CHECK-NEXT: %23 = extractvalue [3 x double] %9, 2
85+
; CHECK-NEXT: %24 = fmul fast double %22, %0
86+
; CHECK-NEXT: %25 = fmul fast double %23, %0
87+
; CHECK-NEXT: %26 = fadd fast double %24, %25
88+
; CHECK-NEXT: %27 = insertvalue [3 x double] %21, double %26, 2
89+
; CHECK-NEXT: store double %mul, double* %out, align 8
90+
; CHECK-NEXT: %28 = extractvalue [3 x double*] %"out'", 0
91+
; CHECK-NEXT: %29 = extractvalue [3 x double] %27, 0
92+
; CHECK-NEXT: store double %29, double* %28, align 8
93+
; CHECK-NEXT: %30 = extractvalue [3 x double*] %"out'", 1
94+
; CHECK-NEXT: %31 = extractvalue [3 x double] %27, 1
95+
; CHECK-NEXT: store double %31, double* %30, align 8
96+
; CHECK-NEXT: %32 = extractvalue [3 x double*] %"out'", 2
97+
; CHECK-NEXT: %33 = extractvalue [3 x double] %27, 2
98+
; CHECK-NEXT: store double %33, double* %32, align 8
99+
; CHECK-NEXT: ret void
100+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)