Skip to content

Commit b5c19df

Browse files
authored
Fix notype heuristic (rust-lang#581)
1 parent 902f67e commit b5c19df

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2729,23 +2729,23 @@ class AdjointGenerator
27292729
if (auto CI = dyn_cast<CastInst>(val)) {
27302730
if (auto PT = dyn_cast<PointerType>(CI->getSrcTy())) {
27312731
auto ET = PT->getPointerElementType();
2732+
while (1) {
2733+
if (auto ST = dyn_cast<StructType>(ET)) {
2734+
if (ST->getNumElements()) {
2735+
ET = ST->getElementType(0);
2736+
continue;
2737+
}
2738+
}
2739+
if (auto AT = dyn_cast<ArrayType>(ET)) {
2740+
ET = AT->getElementType();
2741+
continue;
2742+
}
2743+
break;
2744+
}
27322745
if (ET->isFPOrFPVectorTy()) {
27332746
vd = TypeTree(ConcreteType(ET->getScalarType())).Only(0);
27342747
goto known;
27352748
}
2736-
if (ET->isIntOrIntVectorTy()) {
2737-
vd = TypeTree(BaseType::Integer).Only(0);
2738-
goto known;
2739-
}
2740-
if (ET->isPointerTy()) {
2741-
vd = TypeTree(BaseType::Pointer).Only(0);
2742-
goto known;
2743-
}
2744-
while (auto ST = dyn_cast<StructType>(ET)) {
2745-
if (!ST->getNumElements())
2746-
break;
2747-
ET = ST->getElementType(0);
2748-
}
27492749
if (ET->isPointerTy()) {
27502750
vd = TypeTree(BaseType::Pointer).Only(0);
27512751
goto known;

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ Value *GradientUtils::unwrapM(Value *const val, IRBuilder<> &BuilderM,
203203
this->mode == DerivativeMode::ReverseModeCombined)
204204
if (auto inst = dyn_cast<Instruction>(val)) {
205205
if (inst->getParent()->getParent() == newFunc) {
206-
if (unwrapMode == UnwrapMode::LegalFullUnwrap) {
206+
if (unwrapMode == UnwrapMode::LegalFullUnwrap &&
207+
this->mode != DerivativeMode::ReverseModeCombined) {
207208
// TODO this isOriginal is a bottleneck, the new mapping of
208209
// knownRecompute should be precomputed and maintained to lookup
209210
// instead
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -mem2reg -instsimplify -simplifycfg -S -enzyme-loose-types | FileCheck %s
2+
3+
4+
%struct.tensor2 = type { %struct.tensor1 }
5+
%struct.tensor1 = type { [3 x double] }
6+
7+
; Function under test: this is the function handed to __enzyme_autodiff below.
; It casts both [3 x double]* arguments to i8* and copies 24 bytes (i64 24)
; from %A into %ref.tmp with a non-volatile llvm.memcpy.
define void @_Z9transposePK7tensor2([3 x double]* %A, [3 x double]* %ref.tmp) {
8+
entry:
9+
%a0 = bitcast [3 x double]* %ref.tmp to i8*
10+
%a1 = bitcast [3 x double]* %A to i8*
11+
; %a0 (from %ref.tmp) is the destination, %a1 (from %A) is the source.
call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %a0, i8* align 8 %a1, i64 24, i1 false)
12+
ret void
13+
}
14+
; Driver for the test: calls the variadic @_Z17__enzyme_autodiffPvz with
; @_Z9transposePK7tensor2 (bitcast to i8*) followed by %A, %dA, %A, %dA.
; The arguments passed here are %struct.tensor2*, while the target function is
; declared with [3 x double]* parameters.
define dso_local void @_Z4callPK7tensor2S1_(%struct.tensor2* %A, %struct.tensor2* %dA) {
15+
entry:
16+
call void (i8*, ...) @_Z17__enzyme_autodiffPvz(i8* bitcast (void ([3 x double]*, [3 x double]*)* @_Z9transposePK7tensor2 to i8*), %struct.tensor2* %A, %struct.tensor2* %dA, %struct.tensor2* %A, %struct.tensor2* %dA)
17+
ret void
18+
}
19+
20+
declare void @_Z17__enzyme_autodiffPvz(i8*, ...)
21+
22+
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture, i8* noalias nocapture readonly, i64, i1)
23+
24+
; CHECK: define internal void @diffe_Z9transposePK7tensor2([3 x double]* %A, [3 x double]* %"A'", [3 x double]* %ref.tmp, [3 x double]* %"ref.tmp'")
25+
; CHECK-NEXT: entry:
26+
; CHECK-NEXT: %"a0'ipc" = bitcast [3 x double]* %"ref.tmp'" to i8*
27+
; CHECK-NEXT: %a0 = bitcast [3 x double]* %ref.tmp to i8*
28+
; CHECK-NEXT: %"a1'ipc" = bitcast [3 x double]* %"A'" to i8*
29+
; CHECK-NEXT: %a1 = bitcast [3 x double]* %A to i8*
30+
; CHECK-NEXT: call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %a0, i8* align 8 %a1, i64 24, i1 false)
31+
; CHECK-NEXT: %0 = bitcast i8* %"a0'ipc" to double*
32+
; CHECK-NEXT: %1 = bitcast i8* %"a1'ipc" to double*
33+
; CHECK-NEXT: call void @__enzyme_memcpyadd_doubleda8sa8(double* %0, double* %1, i64 3)
34+
; CHECK-NEXT: ret void
35+
; CHECK-NEXT: }

0 commit comments

Comments
 (0)