Skip to content

Commit 79e17bc

Browse files
committed
add new tests for autodiff batching and update old ones
1 parent b7c63a9 commit 79e17bc

File tree

7 files changed

+251
-54
lines changed

7 files changed

+251
-54
lines changed

tests/codegen/autodiff.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn square(x: &f64) -> f64 {
1111
x * x
1212
}
1313

14-
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture align 8 %"x'"
14+
// CHECK:define internal fastcc double @diffesquare(double %x.0.val, ptr nocapture nonnull align 8 %"x'"
1515
// CHECK-NEXT:invertstart:
1616
// CHECK-NEXT: %_0 = fmul double %x.0.val, %x.0.val
1717
// CHECK-NEXT: %0 = fadd fast double %x.0.val, %x.0.val
@@ -22,7 +22,7 @@ fn square(x: &f64) -> f64 {
2222
// CHECK-NEXT:}
2323

2424
fn main() {
25-
let x = 3.0;
25+
let x = std::hint::black_box(3.0);
2626
let output = square(&x);
2727
assert_eq!(9.0, output);
2828

tests/codegen/autodiffv.rs

+116
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
//
5+
// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
6+
// breakages. One benefit is that we match the IR generated by Enzyme only after running it
7+
// through LLVM's O3 pipeline, which will remove most of the noise.
8+
// However, our integration test could also be affected by changes in how rustc lowers MIR into
9+
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
10+
// reduce this test to only match the first lines and the ret instructions.
11+
12+
#![feature(autodiff)]
13+
14+
use std::autodiff::autodiff;
15+
16+
#[autodiff(d_square3, Forward, Dual, DualOnly)]
17+
#[autodiff(d_square2, Forward, 4, Dual, DualOnly)]
18+
#[autodiff(d_square1, Forward, 4, Dual, Dual)]
19+
#[no_mangle]
20+
fn square(x: &f32) -> f32 {
21+
x * x
22+
}
23+
24+
// d_sqaure2
25+
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
26+
// CHECK-NEXT: start:
27+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
28+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
29+
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
30+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
31+
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
32+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
33+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
34+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
35+
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
36+
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
37+
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
38+
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
39+
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
40+
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
41+
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
42+
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
43+
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
44+
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
45+
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
46+
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
47+
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
48+
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
49+
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
50+
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
51+
// CHECK-NEXT: ret [4 x float] %19
52+
// CHECK-NEXT: }
53+
54+
// d_square3, the extra float is the original return value (x * x)
55+
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
56+
// CHECK-NEXT: start:
57+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
58+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
59+
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
60+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
61+
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
62+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
63+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
64+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
65+
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
66+
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
67+
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
68+
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
69+
// CHECK-NEXT: %7 = insertelement <4 x float> %6, float %"_2'ipl3", i64 3
70+
// CHECK-NEXT: %8 = fadd fast <4 x float> %7, %7
71+
// CHECK-NEXT: %9 = insertelement <4 x float> poison, float %x.0.val, i64 0
72+
// CHECK-NEXT: %10 = shufflevector <4 x float> %9, <4 x float> poison, <4 x i32> zeroinitializer
73+
// CHECK-NEXT: %11 = fmul fast <4 x float> %8, %10
74+
// CHECK-NEXT: %12 = extractelement <4 x float> %11, i64 0
75+
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
76+
// CHECK-NEXT: %14 = extractelement <4 x float> %11, i64 1
77+
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
78+
// CHECK-NEXT: %16 = extractelement <4 x float> %11, i64 2
79+
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
80+
// CHECK-NEXT: %18 = extractelement <4 x float> %11, i64 3
81+
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
82+
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
83+
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
84+
// CHECK-NEXT: ret { float, [4 x float] } %21
85+
// CHECK-NEXT: }
86+
87+
fn main() {
88+
let x = std::hint::black_box(3.0);
89+
let output = square(&x);
90+
dbg!(&output);
91+
assert_eq!(9.0, output);
92+
dbg!(square(&x));
93+
94+
let mut df_dx1 = 1.0;
95+
let mut df_dx2 = 2.0;
96+
let mut df_dx3 = 3.0;
97+
let mut df_dx4 = 0.0;
98+
let [o1, o2, o3, o4] = d_square2(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
99+
dbg!(o1, o2, o3, o4);
100+
let [output2, o1, o2, o3, o4] =
101+
d_square1(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4);
102+
dbg!(o1, o2, o3, o4);
103+
assert_eq!(output, output2);
104+
assert!((6.0 - o1).abs() < 1e-10);
105+
assert!((12.0 - o2).abs() < 1e-10);
106+
assert!((18.0 - o3).abs() < 1e-10);
107+
assert!((0.0 - o4).abs() < 1e-10);
108+
assert_eq!(1.0, df_dx1);
109+
assert_eq!(2.0, df_dx2);
110+
assert_eq!(3.0, df_dx3);
111+
assert_eq!(0.0, df_dx4);
112+
assert_eq!(d_square3(&x, &mut df_dx1), 2.0 * o1);
113+
assert_eq!(d_square3(&x, &mut df_dx2), 2.0 * o2);
114+
assert_eq!(d_square3(&x, &mut df_dx3), 2.0 * o3);
115+
assert_eq!(d_square3(&x, &mut df_dx4), 2.0 * o4);
116+
}

tests/pretty/autodiff_forward.pp

+79-21
Original file line numberDiff line numberDiff line change
@@ -25,48 +25,52 @@
2525

2626
// We want to be sure that the same function can be differentiated in different ways
2727

28+
29+
// Make sure, that we add the None for the default return.
30+
31+
2832
::core::panicking::panic("not implemented")
2933
}
30-
#[rustc_autodiff(Forward, Dual, Const, Dual,)]
34+
#[rustc_autodiff(Forward, 1, Dual, Const, Dual)]
3135
#[inline(never)]
32-
pub fn df1(x: &[f64], bx: &[f64], y: f64) -> (f64, f64) {
36+
pub fn df1(x: &[f64], bx_0: &[f64], y: f64) -> (f64, f64) {
3337
unsafe { asm!("NOP", options(pure, nomem)); };
3438
::core::hint::black_box(f1(x, y));
35-
::core::hint::black_box((bx,));
36-
::core::hint::black_box((f1(x, y), f64::default()))
39+
::core::hint::black_box((bx_0,));
40+
::core::hint::black_box(<(f64, f64)>::default())
3741
}
3842
#[rustc_autodiff]
3943
#[inline(never)]
4044
pub fn f2(x: &[f64], y: f64) -> f64 {
4145
::core::panicking::panic("not implemented")
4246
}
43-
#[rustc_autodiff(Forward, Dual, Const, Const,)]
47+
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
4448
#[inline(never)]
45-
pub fn df2(x: &[f64], bx: &[f64], y: f64) -> f64 {
49+
pub fn df2(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
4650
unsafe { asm!("NOP", options(pure, nomem)); };
4751
::core::hint::black_box(f2(x, y));
48-
::core::hint::black_box((bx,));
52+
::core::hint::black_box((bx_0,));
4953
::core::hint::black_box(f2(x, y))
5054
}
5155
#[rustc_autodiff]
5256
#[inline(never)]
5357
pub fn f3(x: &[f64], y: f64) -> f64 {
5458
::core::panicking::panic("not implemented")
5559
}
56-
#[rustc_autodiff(Forward, Dual, Const, Const,)]
60+
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
5761
#[inline(never)]
58-
pub fn df3(x: &[f64], bx: &[f64], y: f64) -> f64 {
62+
pub fn df3(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
5963
unsafe { asm!("NOP", options(pure, nomem)); };
6064
::core::hint::black_box(f3(x, y));
61-
::core::hint::black_box((bx,));
65+
::core::hint::black_box((bx_0,));
6266
::core::hint::black_box(f3(x, y))
6367
}
6468
#[rustc_autodiff]
6569
#[inline(never)]
6670
pub fn f4() {}
67-
#[rustc_autodiff(Forward, None)]
71+
#[rustc_autodiff(Forward, 1, None)]
6872
#[inline(never)]
69-
pub fn df4() {
73+
pub fn df4() -> () {
7074
unsafe { asm!("NOP", options(pure, nomem)); };
7175
::core::hint::black_box(f4());
7276
::core::hint::black_box(());
@@ -76,28 +80,82 @@
7680
pub fn f5(x: &[f64], y: f64) -> f64 {
7781
::core::panicking::panic("not implemented")
7882
}
79-
#[rustc_autodiff(Forward, Const, Dual, Const,)]
83+
#[rustc_autodiff(Forward, 1, Const, Dual, Const)]
8084
#[inline(never)]
81-
pub fn df5_y(x: &[f64], y: f64, by: f64) -> f64 {
85+
pub fn df5_y(x: &[f64], y: f64, by_0: f64) -> f64 {
8286
unsafe { asm!("NOP", options(pure, nomem)); };
8387
::core::hint::black_box(f5(x, y));
84-
::core::hint::black_box((by,));
88+
::core::hint::black_box((by_0,));
8589
::core::hint::black_box(f5(x, y))
8690
}
87-
#[rustc_autodiff(Forward, Dual, Const, Const,)]
91+
#[rustc_autodiff(Forward, 1, Dual, Const, Const)]
8892
#[inline(never)]
89-
pub fn df5_x(x: &[f64], bx: &[f64], y: f64) -> f64 {
93+
pub fn df5_x(x: &[f64], bx_0: &[f64], y: f64) -> f64 {
9094
unsafe { asm!("NOP", options(pure, nomem)); };
9195
::core::hint::black_box(f5(x, y));
92-
::core::hint::black_box((bx,));
96+
::core::hint::black_box((bx_0,));
9397
::core::hint::black_box(f5(x, y))
9498
}
95-
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
99+
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
96100
#[inline(never)]
97-
pub fn df5_rev(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
101+
pub fn df5_rev(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
98102
unsafe { asm!("NOP", options(pure, nomem)); };
99103
::core::hint::black_box(f5(x, y));
100-
::core::hint::black_box((dx, dret));
104+
::core::hint::black_box((dx_0, dret));
101105
::core::hint::black_box(f5(x, y))
102106
}
107+
struct DoesNotImplDefault;
108+
#[rustc_autodiff]
109+
#[inline(never)]
110+
pub fn f6() -> DoesNotImplDefault {
111+
::core::panicking::panic("not implemented")
112+
}
113+
#[rustc_autodiff(Forward, 1, Const)]
114+
#[inline(never)]
115+
pub fn df6() -> DoesNotImplDefault {
116+
unsafe { asm!("NOP", options(pure, nomem)); };
117+
::core::hint::black_box(f6());
118+
::core::hint::black_box(());
119+
::core::hint::black_box(f6())
120+
}
121+
#[rustc_autodiff]
122+
#[inline(never)]
123+
pub fn f7(x: f32) -> () {}
124+
#[rustc_autodiff(Forward, 1, Const, None)]
125+
#[inline(never)]
126+
pub fn df7(x: f32) -> () {
127+
unsafe { asm!("NOP", options(pure, nomem)); };
128+
::core::hint::black_box(f7(x));
129+
::core::hint::black_box(());
130+
}
131+
#[no_mangle]
132+
#[rustc_autodiff]
133+
#[inline(never)]
134+
fn f8(x: &f32) -> f32 { ::core::panicking::panic("not implemented") }
135+
#[rustc_autodiff(Forward, 4, Dual, Dual)]
136+
#[inline(never)]
137+
fn f8_3(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
138+
-> [f32; 5usize] {
139+
unsafe { asm!("NOP", options(pure, nomem)); };
140+
::core::hint::black_box(f8(x));
141+
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
142+
::core::hint::black_box(<[f32; 5usize]>::default())
143+
}
144+
#[rustc_autodiff(Forward, 4, Dual, DualOnly)]
145+
#[inline(never)]
146+
fn f8_2(x: &f32, bx_0: &f32, bx_1: &f32, bx_2: &f32, bx_3: &f32)
147+
-> [f32; 4usize] {
148+
unsafe { asm!("NOP", options(pure, nomem)); };
149+
::core::hint::black_box(f8(x));
150+
::core::hint::black_box((bx_0, bx_1, bx_2, bx_3));
151+
::core::hint::black_box(<[f32; 4usize]>::default())
152+
}
153+
#[rustc_autodiff(Forward, 1, Dual, DualOnly)]
154+
#[inline(never)]
155+
fn f8_1(x: &f32, bx_0: &f32) -> f32 {
156+
unsafe { asm!("NOP", options(pure, nomem)); };
157+
::core::hint::black_box(f8(x));
158+
::core::hint::black_box((bx_0,));
159+
::core::hint::black_box(<f32>::default())
160+
}
103161
fn main() {}

tests/pretty/autodiff_forward.rs

+18
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,22 @@ pub fn f5(x: &[f64], y: f64) -> f64 {
3636
unimplemented!()
3737
}
3838

39+
struct DoesNotImplDefault;
40+
#[autodiff(df6, Forward, Const)]
41+
pub fn f6() -> DoesNotImplDefault {
42+
unimplemented!()
43+
}
44+
45+
// Make sure, that we add the None for the default return.
46+
#[autodiff(df7, Forward, Const)]
47+
pub fn f7(x: f32) -> () {}
48+
49+
#[autodiff(f8_1, Forward, Dual, DualOnly)]
50+
#[autodiff(f8_2, Forward, 4, Dual, DualOnly)]
51+
#[autodiff(f8_3, Forward, 4, Dual, Dual)]
52+
#[no_mangle]
53+
fn f8(x: &f32) -> f32 {
54+
unimplemented!()
55+
}
56+
3957
fn main() {}

tests/pretty/autodiff_reverse.pp

+11-11
Original file line numberDiff line numberDiff line change
@@ -28,18 +28,18 @@
2828
2929
::core::panicking::panic("not implemented")
3030
}
31-
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
31+
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
3232
#[inline(never)]
33-
pub fn df1(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
33+
pub fn df1(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
3434
unsafe { asm!("NOP", options(pure, nomem)); };
3535
::core::hint::black_box(f1(x, y));
36-
::core::hint::black_box((dx, dret));
36+
::core::hint::black_box((dx_0, dret));
3737
::core::hint::black_box(f1(x, y))
3838
}
3939
#[rustc_autodiff]
4040
#[inline(never)]
4141
pub fn f2() {}
42-
#[rustc_autodiff(Reverse, None)]
42+
#[rustc_autodiff(Reverse, 1, None)]
4343
#[inline(never)]
4444
pub fn df2() {
4545
unsafe { asm!("NOP", options(pure, nomem)); };
@@ -51,20 +51,20 @@
5151
pub fn f3(x: &[f64], y: f64) -> f64 {
5252
::core::panicking::panic("not implemented")
5353
}
54-
#[rustc_autodiff(Reverse, Duplicated, Const, Active,)]
54+
#[rustc_autodiff(Reverse, 1, Duplicated, Const, Active)]
5555
#[inline(never)]
56-
pub fn df3(x: &[f64], dx: &mut [f64], y: f64, dret: f64) -> f64 {
56+
pub fn df3(x: &[f64], dx_0: &mut [f64], y: f64, dret: f64) -> f64 {
5757
unsafe { asm!("NOP", options(pure, nomem)); };
5858
::core::hint::black_box(f3(x, y));
59-
::core::hint::black_box((dx, dret));
59+
::core::hint::black_box((dx_0, dret));
6060
::core::hint::black_box(f3(x, y))
6161
}
6262
enum Foo { Reverse, }
6363
use Foo::Reverse;
6464
#[rustc_autodiff]
6565
#[inline(never)]
6666
pub fn f4(x: f32) { ::core::panicking::panic("not implemented") }
67-
#[rustc_autodiff(Reverse, Const, None)]
67+
#[rustc_autodiff(Reverse, 1, Const, None)]
6868
#[inline(never)]
6969
pub fn df4(x: f32) {
7070
unsafe { asm!("NOP", options(pure, nomem)); };
@@ -76,11 +76,11 @@
7676
pub fn f5(x: *const f32, y: &f32) {
7777
::core::panicking::panic("not implemented")
7878
}
79-
#[rustc_autodiff(Reverse, DuplicatedOnly, Duplicated, None)]
79+
#[rustc_autodiff(Reverse, 1, DuplicatedOnly, Duplicated, None)]
8080
#[inline(never)]
81-
pub unsafe fn df5(x: *const f32, dx: *mut f32, y: &f32, dy: &mut f32) {
81+
pub unsafe fn df5(x: *const f32, dx_0: *mut f32, y: &f32, dy_0: &mut f32) {
8282
unsafe { asm!("NOP", options(pure, nomem)); };
8383
::core::hint::black_box(f5(x, y));
84-
::core::hint::black_box((dx, dy));
84+
::core::hint::black_box((dx_0, dy_0));
8585
}
8686
fn main() {}

tests/ui/autodiff/autodiff_illegal.rs

+7
Original file line numberDiff line numberDiff line change
@@ -177,4 +177,11 @@ fn f21(x: f32) -> f32 {
177177
unimplemented!()
178178
}
179179

180+
struct DoesNotImplDefault;
181+
#[autodiff(df22, Forward, Dual)]
182+
pub fn f22() -> DoesNotImplDefault {
183+
//~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
184+
unimplemented!()
185+
}
186+
180187
fn main() {}

0 commit comments

Comments
 (0)