Skip to content

Commit f9c7b63

Browse files
authored
Fix typeanalysis recusion bug and noarg gep (rust-lang#784)
* Fix typeanalysis recusion bug and noarg gep * Update TypeAnalysis.cpp
1 parent ca49313 commit f9c7b63

File tree

4 files changed

+69
-2
lines changed

4 files changed

+69
-2
lines changed

enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,7 @@ void getConstantAnalysis(Constant *Val, TypeAnalyzer &TA,
476476
}
477477
if (!isa<StructType>(GV->getValueType()) ||
478478
!cast<StructType>(GV->getValueType())->isOpaque()) {
479-
auto globalSize = DL.getTypeSizeInBits(GV->getValueType()) / 8;
479+
auto globalSize = (DL.getTypeSizeInBits(GV->getValueType()) + 7) / 8;
480480
// Since halfs are 16bit (2 byte) and pointers are >=32bit (4 byte) any
481481
// Single byte object must be integral
482482
if (globalSize == 1) {
@@ -679,7 +679,7 @@ void TypeAnalyzer::updateAnalysis(Value *Val, TypeTree Data, Value *Origin) {
679679

680680
if (auto GV = dyn_cast<GlobalVariable>(Val)) {
681681
if (GV->getValueType()->isSized()) {
682-
auto Size = DL.getTypeSizeInBits(GV->getValueType()) / 8;
682+
auto Size = (DL.getTypeSizeInBits(GV->getValueType()) + 7) / 8;
683683
Data = analysis[Val].Lookup(Size, DL).Only(-1);
684684
Data.insert({-1}, BaseType::Pointer);
685685
analysis[Val] = Data;
@@ -1392,6 +1392,14 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
13921392
}
13931393
}
13941394

1395+
if (gep.indices().begin() == gep.indices().end()) {
1396+
if (direction & DOWN)
1397+
updateAnalysis(&gep, getAnalysis(gep.getPointerOperand()), &gep);
1398+
if (direction & UP)
1399+
updateAnalysis(gep.getPointerOperand(), getAnalysis(&gep), &gep);
1400+
return;
1401+
}
1402+
13951403
auto &DL = fntypeinfo.Function->getParent()->getDataLayout();
13961404

13971405
auto pointerAnalysis = getAnalysis(gep.getPointerOperand());

enzyme/test/TypeAnalysis/infglobal.ll

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=mainloop -o /dev/null | FileCheck %s
2+
3+
@timeron = internal unnamed_addr global i1 false, align 4
4+
5+
define void @mainloop() {
6+
entry:
7+
%a3 = load i1, i1* @timeron, align 4
8+
%c3 = load i1, i1* @timeron, align 4
9+
br i1 %a3, label %a4, label %a5
10+
11+
a4:
12+
br label %a5
13+
14+
a5:
15+
ret void
16+
}
17+
18+
; CHECK: mainloop - {} |
19+
; CHECK-NEXT: entry
20+
; CHECK-NEXT: %a3 = load i1, i1* @timeron, align 4: {[-1]:Integer}
21+
; CHECK-NEXT: %c3 = load i1, i1* @timeron, align 4: {[-1]:Integer}
22+
; CHECK-NEXT: br i1 %a3, label %a4, label %a5: {}
23+
; CHECK-NEXT: a4
24+
; CHECK-NEXT: br label %a5: {}
25+
; CHECK-NEXT: a5
26+
; CHECK-NEXT: ret void: {}
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=mainloop -o /dev/null | FileCheck %s
2+
3+
@timeron = internal unnamed_addr global float 0.000000e+00, align 4
4+
5+
define void @mainloop() {
6+
entry:
7+
%a3 = load float, float* @timeron, align 4
8+
%c3 = load float, float* @timeron, align 4
9+
%d = fadd float %a3, %c3
10+
%r = load float, float* @timeron, align 4
11+
ret void
12+
}
13+
14+
; CHECK: mainloop - {} |
15+
; CHECK-NEXT: entry
16+
; CHECK-NEXT: %a3 = load float, float* @timeron, align 4: {[-1]:Float@float}
17+
; CHECK-NEXT: %c3 = load float, float* @timeron, align 4: {[-1]:Float@float}
18+
; CHECK-NEXT: %d = fadd float %a3, %c3: {[-1]:Float@float}
19+
; CHECK-NEXT: %r = load float, float* @timeron, align 4: {[-1]:Float@float}
20+
; CHECK-NEXT: ret void: {}

enzyme/test/TypeAnalysis/noarggep.ll

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=square -o /dev/null | FileCheck %s
2+
3+
define internal void @square(double* %in) {
4+
entry:
5+
%out = getelementptr inbounds double, double* %in
6+
ret void
7+
}
8+
9+
; CHECK: square - {} |{[-1]:Pointer, [-1,-1]:Float@double}:{}
10+
; CHECK-NEXT: double* %in: {[-1]:Pointer, [-1,-1]:Float@double}
11+
; CHECK-NEXT: entry
12+
; CHECK-NEXT: %out = getelementptr inbounds double, double* %in: {[-1]:Pointer, [-1,-1]:Float@double}
13+
; CHECK-NEXT: ret void: {}

0 commit comments

Comments
 (0)