Skip to content

Commit 5bf2105

Browse files
timkalerwsmoses
authored andcommitted
Selective cachereads (#21)
* Enable cache reads by default, which is needed for correctness. * Selectively omit caching for reads whose value does not change after the load instruction. Loads that are modified after the load instruction are called "uncacheable" in the code. * Propagate the uncacheable status of pointer arguments to calls. * readwriteread C test illustrates behavior of code.
1 parent 189a8ff commit 5bf2105

21 files changed

+639
-169
lines changed

enzyme/Enzyme/Enzyme.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ void HandleAutoDiff(CallInst *CI, TargetLibraryInfo &TLI, AAResults &AA) {//, Lo
155155

156156
bool differentialReturn = cast<Function>(fn)->getReturnType()->isFPOrFPVectorTy();
157157

158-
auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr);//, LI, DT);
158+
std::set<unsigned> volatile_args;
159+
auto newFunc = CreatePrimalAndGradient(cast<Function>(fn), constants, TLI, AA, /*should return*/false, differentialReturn, /*topLevel*/true, /*addedType*/nullptr, volatile_args);//, LI, DT);
159160

160161
if (differentialReturn)
161162
args.push_back(ConstantFP::get(cast<Function>(fn)->getReturnType(), 1.0));

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 411 additions & 36 deletions
Large diffs are not rendered by default.

enzyme/Enzyme/EnzymeLogic.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,6 @@ extern llvm::cl::opt<bool> enzyme_print;
3636
//! return structtype if recursive function
3737
std::pair<llvm::Function*,llvm::StructType*> CreateAugmentedPrimal(llvm::Function* todiff, llvm::AAResults &AA, const std::set<unsigned>& constant_args, llvm::TargetLibraryInfo &TLI, bool differentialReturn);
3838

39-
llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set<unsigned>& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg);
39+
llvm::Function* CreatePrimalAndGradient(llvm::Function* todiff, const std::set<unsigned>& constant_args, llvm::TargetLibraryInfo &TLI, llvm::AAResults &AA, bool returnValue, bool differentialReturn, bool topLevel, llvm::Type* additionalArg, std::set<unsigned> volatile_args);
4040

4141
#endif

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,13 @@ PHINode* canonicalizeIVs(fake::SCEVExpander &e, Type *Ty, Loop *L, DominatorTree
164164

165165
Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI) {
166166
static std::map<Function*,Function*> cache;
167-
if (cache.find(F) != cache.end()) return cache[F];
168-
167+
static std::map<Function*, BasicAAResult*> cache_AA;
168+
llvm::errs() << "Before cache lookup for " << F->getName() << "\n";
169+
if (cache.find(F) != cache.end()) {
170+
AA.addAAResult(*(cache_AA[F]));
171+
return cache[F];
172+
}
173+
llvm::errs() << "Did not do cache lookup for " << F->getName() << "\n";
169174
Function *NewF = Function::Create(F->getFunctionType(), F->getLinkage(), "preprocess_" + F->getName(), F->getParent());
170175

171176
ValueToValueMapTy VMap;
@@ -439,7 +444,7 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
439444
FunctionAnalysisManager AM;
440445
AM.registerPass([] { return AAManager(); });
441446
AM.registerPass([] { return ScalarEvolutionAnalysis(); });
442-
AM.registerPass([] { return AssumptionAnalysis(); });
447+
//AM.registerPass([] { return AssumptionAnalysis(); });
443448
AM.registerPass([] { return TargetLibraryAnalysis(); });
444449
AM.registerPass([] { return TargetIRAnalysis(); });
445450
AM.registerPass([] { return LoopAnalysis(); });
@@ -458,13 +463,22 @@ Function* preprocessForClone(Function *F, AAResults &AA, TargetLibraryInfo &TLI)
458463
MAM.registerPass([&] { return FunctionAnalysisManagerModuleProxy(AM); });
459464

460465
//Alias analysis is necessary to ensure can query whether we can move a forward pass function
461-
BasicAA ba;
462-
auto baa = new BasicAAResult(ba.run(*NewF, AM));
466+
//BasicAA ba;
467+
//auto baa = new BasicAAResult(ba.run(*NewF, AM));
468+
AssumptionCache* AC = new AssumptionCache(*NewF);
469+
TargetLibraryInfo* TLI = new TargetLibraryInfo(AM.getResult<TargetLibraryAnalysis>(*NewF));
470+
auto baa = new BasicAAResult(NewF->getParent()->getDataLayout(),
471+
*NewF,
472+
*TLI,
473+
*AC,
474+
&AM.getResult<DominatorTreeAnalysis>(*NewF),
475+
AM.getCachedResult<LoopAnalysis>(*NewF),
476+
AM.getCachedResult<PhiValuesAnalysis>(*NewF));
477+
cache_AA[F] = baa;
463478
AA.addAAResult(*baa);
464-
465-
ScopedNoAliasAA sa;
466-
auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM));
467-
AA.addAAResult(*saa);
479+
//ScopedNoAliasAA sa;
480+
//auto saa = new ScopedNoAliasAAResult(sa.run(*NewF, AM));
481+
//AA.addAAResult(*saa);
468482

469483
}
470484

enzyme/Enzyme/GradientUtils.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,8 @@ Value* GradientUtils::invertPointerM(Value* val, IRBuilder<>& BuilderM) {
351351
return invertedPointers[val] = cs;
352352
} else if (auto fn = dyn_cast<Function>(val)) {
353353
//! Todo allow tape propagation
354-
auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr);
354+
std::set<unsigned> uncacheable_args;
355+
auto newf = CreatePrimalAndGradient(fn, /*constant_args*/{}, TLI, AA, /*returnValue*/false, /*differentialReturn*/fn->getReturnType()->isFPOrFPVectorTy(), /*topLevel*/false, /*additionalArg*/nullptr, uncacheable_args);
355356
return BuilderM.CreatePointerCast(newf, fn->getType());
356357
} else if (auto arg = dyn_cast<CastInst>(val)) {
357358
auto result = BuilderM.CreateCast(arg->getOpcode(), invertPointerM(arg->getOperand(0), BuilderM), arg->getDestTy(), arg->getName()+"'ipc");
@@ -824,10 +825,12 @@ Value* GradientUtils::lookupM(Value* val, IRBuilder<>& BuilderM) {
824825
}
825826
}
826827

827-
if (!shouldRecompute(inst, available)) {
828-
auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true);
829-
assert(op);
830-
return op;
828+
if (!(*(this->can_modref_map))[inst]) {
829+
if (!shouldRecompute(inst, available)) {
830+
auto op = unwrapM(inst, BuilderM, available, /*lookupIfAble*/true);
831+
assert(op);
832+
return op;
833+
}
831834
}
832835
/*
833836
if (!inLoop) {

enzyme/Enzyme/GradientUtils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,9 @@ class GradientUtils {
8989
ValueToValueMapTy scopeFrees;
9090
ValueToValueMapTy originalToNewFn;
9191

92+
std::map<Instruction*, bool>* can_modref_map;
93+
94+
9295
Value* getNewFromOriginal(Value* originst) {
9396
assert(originst);
9497
auto f = originalToNewFn.find(originst);
@@ -507,7 +510,7 @@ class GradientUtils {
507510
}
508511
assert(lastScopeAlloc.find(malloc) == lastScopeAlloc.end());
509512
cast<Instruction>(malloc)->replaceAllUsesWith(ret);
510-
auto n = malloc->getName();
513+
std::string n = malloc->getName().str();
511514
erase(cast<Instruction>(malloc));
512515
ret->setName(n);
513516
}

enzyme/functional_tests_c/Makefile

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ OBJ := $(wildcard *.c)
1818

1919
all: $(patsubst %.c,build/%-enzyme0,$(OBJ)) $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ))
2020

21-
POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true
21+
POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg
2222

2323
#all: $(patsubst %.c,build/%-enzyme1,$(OBJ)) $(patsubst %.c,build/%-enzyme2,$(OBJ)) $(patsubst %.c,build/%-enzyme3,$(OBJ))
2424
#clean:
@@ -31,7 +31,7 @@ POST_ENZYME_FLAGS := -mem2reg -sroa -adce -simplifycfg -enzyme_cachereads=true
3131

3232
#EXTRA_FLAGS = -indvars -loop-simplify -loop-rotate
3333

34-
# NOTE(TFK): Optimization level 0 is broken right now.
34+
# /efs/home/tfk/valgrind-3.12.0/vg-in-place
3535
build/%-enzyme0: %.c
3636
@./setup.sh $(CLANG_BIN_PATH)/clang -std=c11 -O1 $(patsubst %.c,%,$<).c -S -emit-llvm -o $@.ll
3737
@./setup.sh $(CLANG_BIN_PATH)/opt $@.ll $(EXTRA_FLAGS) -load=$(ENZYME_PLUGIN) -enzyme $(POST_ENZYME_FLAGS) -o $@.bc

enzyme/functional_tests_c/insertsort_sum.c

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,8 @@ float* unsorted_array_init(int N) {
1616
return arr;
1717
}
1818

19-
// sums the first half of a sorted array.
20-
void insertsort_sum (float* array, int N, float* ret) {
19+
void insertsort_sum (float*__restrict array, int N, float*__restrict ret) {
2120
float sum = 0;
22-
//qsort(array, N, sizeof(float), cmp);
2321

2422
for (int i = 1; i < N; i++) {
2523
int j = i;
@@ -31,30 +29,16 @@ void insertsort_sum (float* array, int N, float* ret) {
3129
}
3230
}
3331

34-
3532
for (int i = 0; i < N/2; i++) {
36-
printf("Val: %f\n", array[i]);
33+
//printf("Val: %f\n", array[i]);
3734
sum += array[i];
3835
}
36+
3937
*ret = sum;
4038
}
4139

4240

43-
44-
4541
int main(int argc, char** argv) {
46-
47-
48-
49-
float a = 2.0;
50-
float b = 3.0;
51-
52-
53-
54-
float da = 0;
55-
float db = 0;
56-
57-
5842
float ret = 0;
5943
float dret = 1.0;
6044

@@ -71,18 +55,15 @@ int main(int argc, char** argv) {
7155
printf("%d:%f\n", i, array[i]);
7256
}
7357

74-
//insertsort_sum(array, N, &ret);
58+
__builtin_autodiff(insertsort_sum, array, d_array, N, &ret, &dret);
59+
60+
printf("The total sum is %f\n", ret);
7561

7662
printf("Array after sorting:\n");
7763
for (int i = 0; i < N; i++) {
7864
printf("%d:%f\n", i, array[i]);
7965
}
8066

81-
82-
printf("The total sum is %f\n", ret);
83-
84-
__builtin_autodiff(insertsort_sum, array, d_array, N, &ret, &dret);
85-
8667
for (int i = 0; i < N; i++) {
8768
printf("Diffe for index %d is %f\n", i, d_array[i]);
8869
if (i%2 == 0) {
@@ -91,13 +72,5 @@ int main(int argc, char** argv) {
9172
assert(d_array[i] == 1.0);
9273
}
9374
}
94-
95-
//__builtin_autodiff(compute_loops, &a, &da, &b, &db, &ret, &dret);
96-
97-
98-
//assert(da == 100*1.0f);
99-
//assert(db == 100*1.0f);
100-
101-
//printf("hello! %f, res2 %f, da: %f, db: %f\n", ret, ret, da,db);
10275
return 0;
10376
}

enzyme/functional_tests_c/insertsort_sum_alt.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ void insertion_sort_inner(float* array, int i) {
3535
}
3636

3737
// sums the first half of a sorted array.
38-
void insertsort_sum (float* array, int N, float* ret) {
38+
void insertsort_sum (float*__restrict array, int N, float*__restrict ret) {
3939
float sum = 0;
4040
//qsort(array, N, sizeof(float), cmp);
4141

@@ -45,7 +45,7 @@ void insertsort_sum (float* array, int N, float* ret) {
4545

4646

4747
for (int i = 0; i < N/2; i++) {
48-
printf("Val: %f\n", array[i]);
48+
//printf("Val: %f\n", array[i]);
4949
sum += array[i];
5050
}
5151
*ret = sum;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
#include <stdio.h>
2+
#include <stdlib.h>
3+
#include <math.h>
4+
#include <assert.h>
5+
#define __builtin_autodiff __enzyme_autodiff
6+
double __enzyme_autodiff(void*, ...);
7+
8+
double f_read(double* x) {
9+
double product = (*x) * (*x);
10+
return product;
11+
}
12+
13+
void g_write(double* x, double product) {
14+
*x = (*x) * product;
15+
}
16+
17+
double h_read(double* x) {
18+
return *x;
19+
}
20+
21+
double readwriteread_helper(double* x) {
22+
double product = f_read(x);
23+
g_write(x, product);
24+
double ret = h_read(x);
25+
return ret;
26+
}
27+
28+
void readwriteread(double*__restrict x, double*__restrict ret) {
29+
*ret = readwriteread_helper(x);
30+
}
31+
32+
int main(int argc, char** argv) {
33+
double ret = 0;
34+
double dret = 1.0;
35+
double* x = (double*) malloc(sizeof(double));
36+
double* dx = (double*) malloc(sizeof(double));
37+
*x = 2.0;
38+
*dx = 0.0;
39+
40+
__builtin_autodiff(readwriteread, x, dx, &ret, &dret);
41+
42+
43+
printf("dx is %f ret is %f\n", *dx, ret);
44+
assert(*dx == 3*2.0*2.0);
45+
return 0;
46+
}

enzyme/functional_tests_c/setup.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
#!/bin/bash
22

33
# NOTE(TFK): Uncomment for local testing.
4-
export CLANG_BIN_PATH=./../../build-dbg/bin
5-
export ENZYME_PLUGIN=./../mkdebug/Enzyme/LLVMEnzyme-7.so
4+
export CLANG_BIN_PATH=./../../llvm/build/bin/
5+
export ENZYME_PLUGIN=./../build/Enzyme/LLVMEnzyme-7.so
66

77
mkdir -p build
88
$@
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
; RUN: cd %desired_wd
2+
; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme
3+
; RUN: make build/readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
4+
; RUN: build/readwriteread-enzyme0
5+
; RUN: make clean-readwriteread-enzyme0 ENZYME_PLUGIN=%loadEnzyme
6+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
; RUN: cd %desired_wd
2+
; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme
3+
; RUN: make build/readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
4+
; RUN: build/readwriteread-enzyme1
5+
; RUN: make clean-readwriteread-enzyme1 ENZYME_PLUGIN=%loadEnzyme
6+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
; RUN: cd %desired_wd
2+
; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme
3+
; RUN: make build/readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
4+
; RUN: build/readwriteread-enzyme2
5+
; RUN: make clean-readwriteread-enzyme2 ENZYME_PLUGIN=%loadEnzyme
6+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
; RUN: cd %desired_wd
2+
; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme
3+
; RUN: make build/readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme CLANG_BIN_PATH=%clangBinPath
4+
; RUN: build/readwriteread-enzyme3
5+
; RUN: make clean-readwriteread-enzyme3 ENZYME_PLUGIN=%loadEnzyme
6+

enzyme/test/Enzyme/badcall.ll

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,12 @@ attributes #1 = { noinline nounwind uwtable }
4242

4343
; CHECK: define internal {{(dso_local )?}}{} @diffef(double* nocapture %x, double* %"x'")
4444
; CHECK-NEXT: entry:
45-
; CHECK-NEXT: %0 = call { { {} } } @augmented_subf(double* %x, double* %"x'")
46-
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
47-
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
48-
; CHECK-NEXT: %1 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {} } undef)
49-
; CHECK-NEXT: ret {} undef
45+
; CHECK-NEXT: %0 = call { { {}, double } } @augmented_subf(double* %x, double* %"x'")
46+
; CHECK-NEXT: %1 = extractvalue { { {}, double } } %0, 0
47+
; CHECK-NEXT: store double 2.000000e+00, double* %x, align 8
48+
; CHECK-NEXT: store double 0.000000e+00, double* %"x'"
49+
; CHECK-NEXT: %2 = call {} @diffesubf(double* nonnull %x, double* %"x'", { {}, double } %1)
50+
; CHECK-NEXT: ret {} undef
5051
; CHECK-NEXT: }
5152

5253
; CHECK: define internal {{(dso_local )?}}{ {} } @augmented_metasubf(double* nocapture %x, double* %"x'")
@@ -56,16 +57,21 @@ attributes #1 = { noinline nounwind uwtable }
5657
; CHECK-NEXT: ret { {} } undef
5758
; CHECK-NEXT: }
5859

59-
; CHECK: define internal {{(dso_local )?}}{ { {} } } @augmented_subf(double* nocapture %x, double* %"x'")
60+
; CHECK: define internal {{(dso_local )?}}{ { {}, double } } @augmented_subf(double* nocapture %x, double* %"x'")
6061
; CHECK-NEXT: entry:
61-
; CHECK-NEXT: %0 = load double, double* %x, align 8
62-
; CHECK-NEXT: %mul = fmul fast double %0, 2.000000e+00
63-
; CHECK-NEXT: store double %mul, double* %x, align 8
64-
; CHECK-NEXT: %1 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
65-
; CHECK-NEXT: ret { { {} } } undef
62+
; CHECK-NEXT: %0 = alloca { { {}, double } }
63+
; CHECK-NEXT: %1 = getelementptr { { {}, double } }, { { {}, double } }* %0, i32 0, i32 0
64+
; CHECK-NEXT: %2 = load double, double* %x, align 8
65+
; CHECK-NEXT: %3 = getelementptr { {}, double }, { {}, double }* %1, i32 0, i32 1
66+
; CHECK-NEXT: store double %2, double* %3
67+
; CHECK-NEXT: %mul = fmul fast double %2, 2.000000e+00
68+
; CHECK-NEXT: store double %mul, double* %x, align 8
69+
; CHECK-NEXT: %4 = call { {} } @augmented_metasubf(double* %x, double* %"x'")
70+
; CHECK-NEXT: %5 = load { { {}, double } }, { { {}, double } }* %0
71+
; CHECK-NEXT: ret { { {}, double } } %5
6672
; CHECK-NEXT: }
6773

68-
; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {} } %tapeArg)
74+
; CHECK: define internal {{(dso_local )?}}{} @diffesubf(double* nocapture %x, double* %"x'", { {}, double } %tapeArg)
6975
; CHECK-NEXT: entry:
7076
; CHECK-NEXT: %0 = call {} @diffemetasubf(double* %x, double* %"x'", {} undef)
7177
; CHECK-NEXT: %1 = load double, double* %"x'"

0 commit comments

Comments
 (0)