Skip to content

Commit ed8e5f8

Browse files
authored
Implement enzyme_dupnoneedv (rust-lang#636)
1 parent b71d59c commit ed8e5f8

File tree

2 files changed

+116
-0
lines changed

2 files changed

+116
-0
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -671,6 +671,20 @@ class Enzyme : public ModulePass {
671671
return false;
672672
}
673673
continue;
674+
} else if (*metaString == "enzyme_dupnoneedv") {
675+
// Same activity as "enzyme_dupnoneed" (DUP_NONEED), but this marker is
// followed by an integer offset operand that is consumed here.
ty = DIFFE_TYPE::DUP_NONEED;
676+
// Step from the marker to the operand that should hold the offset.
// NOTE(review): no range check on `i` is visible in this branch — confirm
// the call is guaranteed to have an operand here.
++i;
677+
Value *offset_arg = CI->getArgOperand(i);
678+
if (offset_arg->getType()->isIntegerTy()) {
679+
// Record the offset under index i + 1, i.e. the operand immediately after
// the offset. NOTE(review): presumably the pointer the offset applies to.
batchOffset[i + 1] = offset_arg;
680+
} else {
681+
// Name the marker actually being parsed so the diagnostic is actionable.
EmitFailure("IllegalVectorOffset", CI->getDebugLoc(), CI,
682+
"enzyme_dupnoneedv must be followed by an integer "
683+
"offset.",
684+
*CI->getArgOperand(i), " in", *CI);
685+
return false;
686+
}
687+
// Marker and offset are both consumed; move on to the next operand.
continue;
674688
} else if (*metaString == "enzyme_dupnoneed") {
675689
ty = DIFFE_TYPE::DUP_NONEED;
676690
} else if (*metaString == "enzyme_out") {
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -S | FileCheck %s
2+
3+
; External i32 marker globals. @dsquare loads each one and passes the loaded
; value as an argument to @__enzyme_fwddiff.
; NOTE(review): presumably Enzyme recognizes them by global name - confirm.
@enzyme_width = external global i32, align 4
4+
@enzyme_dupv = external global i32, align 4
5+
@enzyme_dupnoneedv = external global i32, align 4
6+
7+
8+
; Primal function under test: loads the double at %x, multiplies it by itself,
; and stores the product to %out.
define void @square(double* nocapture readonly %x, double* nocapture %out) {
9+
entry:
10+
%0 = load double, double* %x, align 8
11+
%mul = fmul double %0, %0
12+
store double %mul, double* %out, align 8
13+
ret void
14+
}
15+
16+
; Driver: loads the three marker globals and calls @__enzyme_fwddiff on @square.
; Operands after the function pointer, in order:
;   %0 (enzyme_width), i32 3
;   %1 (enzyme_dupv), i64 16, %x, %dx
;   %1 (enzyme_dupv), i64 16
;   %2 (enzyme_dupnoneedv), i32 16, %out, %dout
; The expected output below derives three pointers each from %dx and %dout at
; byte offsets 0, 16 and 32 and passes them as [3 x double*] aggregates.
; NOTE(review): the second `i32 %1, i64 16` pair is followed directly by the
; enzyme_dupnoneedv marker rather than by a pointer pair - confirm intended.
define void @dsquare(double* %x, double* %dx, double* %out, double* %dout) {
17+
entry:
18+
%0 = load i32, i32* @enzyme_width, align 4
19+
%1 = load i32, i32* @enzyme_dupv, align 4
20+
%2 = load i32, i32* @enzyme_dupnoneedv, align 4
21+
call void (i8*, ...) @__enzyme_fwddiff(i8* bitcast (void (double*, double*)* @square to i8*), i32 %0, i32 3, i32 %1, i64 16, double* %x, double* %dx, i32 %1, i64 16, i32 %2, i32 16, double* %out, double* %dout)
22+
ret void
23+
}
24+
25+
; Variadic placeholder called by @dsquare; the expected output below has that
; call replaced by a direct call to @fwddiffe3square.
declare void @__enzyme_fwddiff(i8*, ...)
26+
27+
; CHECK: define void @dsquare(double* %x, double* %dx, double* %out, double* %dout)
28+
; CHECK-NEXT: entry:
29+
; CHECK-NEXT: %0 = load i32, i32* @enzyme_width
30+
; CHECK-NEXT: %1 = load i32, i32* @enzyme_dupv
31+
; CHECK-NEXT: %2 = load i32, i32* @enzyme_dupnoneedv
32+
; CHECK-NEXT: %3 = bitcast double* %dx to i8*
33+
; CHECK-NEXT: %4 = getelementptr i8, i8* %3, i64 0
34+
; CHECK-NEXT: %5 = bitcast i8* %4 to double*
35+
; CHECK-NEXT: %6 = insertvalue [3 x double*] undef, double* %5, 0
36+
; CHECK-NEXT: %7 = bitcast double* %dx to i8*
37+
; CHECK-NEXT: %8 = getelementptr i8, i8* %7, i64 16
38+
; CHECK-NEXT: %9 = bitcast i8* %8 to double*
39+
; CHECK-NEXT: %10 = insertvalue [3 x double*] %6, double* %9, 1
40+
; CHECK-NEXT: %11 = bitcast double* %dx to i8*
41+
; CHECK-NEXT: %12 = getelementptr i8, i8* %11, i64 32
42+
; CHECK-NEXT: %13 = bitcast i8* %12 to double*
43+
; CHECK-NEXT: %14 = insertvalue [3 x double*] %10, double* %13, 2
44+
; CHECK-NEXT: %15 = bitcast double* %dout to i8*
45+
; CHECK-NEXT: %16 = getelementptr i8, i8* %15, i32 0
46+
; CHECK-NEXT: %17 = bitcast i8* %16 to double*
47+
; CHECK-NEXT: %18 = insertvalue [3 x double*] undef, double* %17, 0
48+
; CHECK-NEXT: %19 = bitcast double* %dout to i8*
49+
; CHECK-NEXT: %20 = getelementptr i8, i8* %19, i32 16
50+
; CHECK-NEXT: %21 = bitcast i8* %20 to double*
51+
; CHECK-NEXT: %22 = insertvalue [3 x double*] %18, double* %21, 1
52+
; CHECK-NEXT: %23 = bitcast double* %dout to i8*
53+
; CHECK-NEXT: %24 = getelementptr i8, i8* %23, i32 32
54+
; CHECK-NEXT: %25 = bitcast i8* %24 to double*
55+
; CHECK-NEXT: %26 = insertvalue [3 x double*] %22, double* %25, 2
56+
; CHECK-NEXT: call void @fwddiffe3square(double* %x, [3 x double*] %14, double* %out, [3 x double*] %26)
57+
; CHECK-NEXT: ret void
58+
; CHECK-NEXT: }
59+
60+
; CHECK: define internal void @fwddiffe3square(double* nocapture readonly %x, [3 x double*] %"x'", double* nocapture %out, [3 x double*] %"out'")
61+
; CHECK-NEXT: entry:
62+
; CHECK-NEXT: %0 = load double, double* %x
63+
; CHECK-NEXT: %1 = extractvalue [3 x double*] %"x'", 0
64+
; CHECK-NEXT: %2 = load double, double* %1
65+
; CHECK-NEXT: %3 = insertvalue [3 x double] undef, double %2, 0
66+
; CHECK-NEXT: %4 = extractvalue [3 x double*] %"x'", 1
67+
; CHECK-NEXT: %5 = load double, double* %4
68+
; CHECK-NEXT: %6 = insertvalue [3 x double] %3, double %5, 1
69+
; CHECK-NEXT: %7 = extractvalue [3 x double*] %"x'", 2
70+
; CHECK-NEXT: %8 = load double, double* %7
71+
; CHECK-NEXT: %9 = insertvalue [3 x double] %6, double %8, 2
72+
; CHECK-NEXT: %mul = fmul double %0, %0
73+
; CHECK-NEXT: %10 = extractvalue [3 x double] %9, 0
74+
; CHECK-NEXT: %11 = extractvalue [3 x double] %9, 0
75+
; CHECK-NEXT: %12 = fmul fast double %10, %0
76+
; CHECK-NEXT: %13 = fmul fast double %11, %0
77+
; CHECK-NEXT: %14 = fadd fast double %12, %13
78+
; CHECK-NEXT: %15 = insertvalue [3 x double] undef, double %14, 0
79+
; CHECK-NEXT: %16 = extractvalue [3 x double] %9, 1
80+
; CHECK-NEXT: %17 = extractvalue [3 x double] %9, 1
81+
; CHECK-NEXT: %18 = fmul fast double %16, %0
82+
; CHECK-NEXT: %19 = fmul fast double %17, %0
83+
; CHECK-NEXT: %20 = fadd fast double %18, %19
84+
; CHECK-NEXT: %21 = insertvalue [3 x double] %15, double %20, 1
85+
; CHECK-NEXT: %22 = extractvalue [3 x double] %9, 2
86+
; CHECK-NEXT: %23 = extractvalue [3 x double] %9, 2
87+
; CHECK-NEXT: %24 = fmul fast double %22, %0
88+
; CHECK-NEXT: %25 = fmul fast double %23, %0
89+
; CHECK-NEXT: %26 = fadd fast double %24, %25
90+
; CHECK-NEXT: %27 = insertvalue [3 x double] %21, double %26, 2
91+
; CHECK-NEXT: store double %mul, double* %out
92+
; CHECK-NEXT: %28 = extractvalue [3 x double*] %"out'", 0
93+
; CHECK-NEXT: %29 = extractvalue [3 x double] %27, 0
94+
; CHECK-NEXT: store double %29, double* %28
95+
; CHECK-NEXT: %30 = extractvalue [3 x double*] %"out'", 1
96+
; CHECK-NEXT: %31 = extractvalue [3 x double] %27, 1
97+
; CHECK-NEXT: store double %31, double* %30,
98+
; CHECK-NEXT: %32 = extractvalue [3 x double*] %"out'", 2
99+
; CHECK-NEXT: %33 = extractvalue [3 x double] %27, 2
100+
; CHECK-NEXT: store double %33, double* %32
101+
; CHECK-NEXT: ret void
102+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)