Skip to content

Commit b53b148

Browse files
authored
Fix memcpy addrspaces (rust-lang#571)
1 parent 5c89a86 commit b53b148

File tree

2 files changed

+45
-6
lines changed

2 files changed

+45
-6
lines changed

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6466,7 +6466,6 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode,
64666466
dsto, Type::getInt8PtrTy(dsto->getContext()));
64676467
unsigned dstaddr =
64686468
cast<PointerType>(dsto->getType())->getAddressSpace();
6469-
auto secretpt = PointerType::get(secretty, dstaddr);
64706469
if (offset != 0) {
64716470
#if LLVM_VERSION_MAJOR > 7
64726471
dsto = Builder2.CreateConstInBoundsGEP1_64(
@@ -6480,13 +6479,13 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode,
64806479
? shadow_src
64816480
: gutils->lookupM(shadow_src, Builder2);
64826481
if (mode != DerivativeMode::ForwardModeSplit)
6483-
dsto = Builder2.CreatePointerCast(dsto, secretpt);
6482+
dsto = Builder2.CreatePointerCast(
6483+
dsto, PointerType::get(secretty, dstaddr));
64846484
if (srco->getType()->isIntegerTy())
64856485
srco = Builder2.CreateIntToPtr(
64866486
srco, Type::getInt8PtrTy(srco->getContext()));
64876487
unsigned srcaddr =
64886488
cast<PointerType>(srco->getType())->getAddressSpace();
6489-
secretpt = PointerType::get(secretty, srcaddr);
64906489
if (offset != 0) {
64916490
#if LLVM_VERSION_MAJOR > 7
64926491
srco = Builder2.CreateConstInBoundsGEP1_64(
@@ -6496,7 +6495,8 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode,
64966495
#endif
64976496
}
64986497
if (mode != DerivativeMode::ForwardModeSplit)
6499-
srco = Builder2.CreatePointerCast(srco, secretpt);
6498+
srco = Builder2.CreatePointerCast(
6499+
srco, PointerType::get(secretty, srcaddr));
65006500

65016501
if (mode == DerivativeMode::ForwardModeSplit) {
65026502
#if LLVM_VERSION_MAJOR >= 11
@@ -6518,8 +6518,10 @@ void SubTransferHelper(GradientUtils *gutils, DerivativeMode mode,
65186518
}
65196519
} else {
65206520
Value *args[]{
6521-
Builder2.CreatePointerCast(dsto, secretpt),
6522-
Builder2.CreatePointerCast(srco, secretpt),
6521+
Builder2.CreatePointerCast(dsto,
6522+
PointerType::get(secretty, dstaddr)),
6523+
Builder2.CreatePointerCast(srco,
6524+
PointerType::get(secretty, srcaddr)),
65236525
Builder2.CreateUDiv(
65246526
gutils->lookupM(length, Builder2),
65256527
ConstantInt::get(length->getType(),
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -simplifycfg -dce -instcombine -S | FileCheck %s
2+
3+
; Function Attrs: nounwind uwtable
4+
define dso_local void @memcpy_float(double addrspace(13)* nocapture %dst, double addrspace(10)* nocapture readonly %src, i64 %num) #0 {
5+
entry:
6+
%0 = bitcast double addrspace(13)* %dst to i8 addrspace(13)*
7+
%1 = bitcast double addrspace(10)* %src to i8 addrspace(10)*
8+
tail call void @llvm.memcpy.p13i8.p10i8.i64(i8 addrspace(13)* align 1 %0, i8 addrspace(10)* align 1 %1, i64 %num, i1 false)
9+
ret void
10+
}
11+
12+
; Function Attrs: argmemonly nounwind
13+
declare void @llvm.memcpy.p13i8.p10i8.i64(i8 addrspace(13)* nocapture writeonly, i8 addrspace(10)* nocapture readonly, i64, i1) #1
14+
15+
; Function Attrs: nounwind uwtable
16+
define dso_local void @dmemcpy_float(double addrspace(13)* %dst, double addrspace(13)* %dstp, double addrspace(10)* %src, double addrspace(10)* %srcp, i64 %n) local_unnamed_addr #0 {
17+
entry:
18+
tail call void (...) @__enzyme_autodiff.f64(void (double addrspace(13)*, double addrspace(10)*, i64)* nonnull @memcpy_float, double addrspace(13)* %dst, double addrspace(13)* %dstp, double addrspace(10)* %src, double addrspace(10)* %srcp, i64 %n) #3
19+
ret void
20+
}
21+
22+
declare void @__enzyme_autodiff.f64(...)
23+
24+
attributes #0 = { nounwind uwtable }
25+
attributes #1 = { argmemonly nounwind }
26+
attributes #2 = { noinline nounwind uwtable }
27+
attributes #3 = { nounwind }
28+
29+
; CHECK: define internal void @diffememcpy_float(double addrspace(13)* nocapture %dst, double addrspace(13)* nocapture %"dst'", double addrspace(10)* nocapture readonly %src, double addrspace(10)* nocapture %"src'", i64 %num)
30+
; CHECK-NEXT: entry:
31+
; CHECK-NEXT: %0 = bitcast double addrspace(13)* %dst to i8 addrspace(13)*
32+
; CHECK-NEXT: %1 = bitcast double addrspace(10)* %src to i8 addrspace(10)*
33+
; CHECK-NEXT: tail call void @llvm.memcpy.p13i8.p10i8.i64(i8 addrspace(13)* align 1 %0, i8 addrspace(10)* align 1 %1, i64 %num, i1 false)
34+
; CHECK-NEXT: %2 = lshr i64 %num, 3
35+
; CHECK-NEXT: call void @__enzyme_memcpyadd_doubleda1sa1dadd13sadd10(double addrspace(13)* %"dst'", double addrspace(10)* %"src'", i64 %2)
36+
; CHECK-NEXT: ret void
37+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)