Skip to content

Commit a6a92b4

Browse files
authored
Add custom anonymous type for tape (rust-lang#859)
* Add custom anonymous type for tape * Fix cast
1 parent 5cfe075 commit a6a92b4

File tree

5 files changed

+26
-22
lines changed

5 files changed

+26
-22
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12063,7 +12063,9 @@ class AdjointGenerator
1206312063
shouldFree()) {
1206412064
assert(tape);
1206512065
auto tapep = BuilderZ.CreatePointerCast(
12066-
tape, PointerType::getUnqual(fnandtapetype->tapeType));
12066+
tape, PointerType::get(
12067+
fnandtapetype->tapeType,
12068+
cast<PointerType>(tape->getType())->getAddressSpace()));
1206712069
#if LLVM_VERSION_MAJOR > 7
1206812070
auto truetape =
1206912071
BuilderZ.CreateLoad(fnandtapetype->tapeType, tapep, "tapeld");

enzyme/Enzyme/EnzymeLogic.cpp

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2355,17 +2355,8 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
23552355
gutils->newFunc->getParent()->getDataLayout().getTypeAllocSizeInBits(
23562356
tapeType);
23572357
if (size != 0) {
2358-
auto i64 = Type::getInt64Ty(gutils->newFunc->getContext());
2359-
BasicBlock *BB = BasicBlock::Create(gutils->newFunc->getContext(),
2360-
"entry", gutils->newFunc);
2361-
IRBuilder<> B(BB);
2362-
2363-
CallInst *malloccall;
2364-
CreateAllocation(B, tapeType, ConstantInt::get(i64, 1), "tapemem",
2365-
&malloccall, nullptr);
23662358
RetTypes[returnMapping.find(AugmentedStruct::Tape)->second] =
2367-
malloccall->getType();
2368-
BB->eraseFromParent();
2359+
getDefaultAnonymousTapeType(gutils->newFunc->getContext());
23692360
}
23702361
}
23712362
}
@@ -2485,13 +2476,13 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
24852476
if (size != 0) {
24862477
CallInst *malloccall = nullptr;
24872478
Instruction *zero = nullptr;
2488-
tapeMemory =
2489-
CreateAllocation(ib, tapeType, ConstantInt::get(i64, 1), "tapemem",
2490-
&malloccall, EnzymeZeroCache ? &zero : nullptr);
2479+
tapeMemory = CreateAllocation(
2480+
ib, tapeType, ConstantInt::get(i64, 1), "tapemem", &malloccall,
2481+
EnzymeZeroCache ? &zero : nullptr, /*isDefault*/ true);
24912482
memory = malloccall;
24922483
} else {
2493-
memory =
2494-
ConstantPointerNull::get(Type::getInt8PtrTy(NewF->getContext()));
2484+
memory = ConstantPointerNull::get(
2485+
getDefaultAnonymousTapeType(NewF->getContext()));
24952486
}
24962487
Value *Idxs[] = {
24972488
ib.getInt32(0),

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1942,7 +1942,7 @@ FunctionType *getFunctionTypeForClone(
19421942
returnValue == ReturnType::TapeAndReturn ||
19431943
returnValue == ReturnType::Tape) {
19441944
RetTypes.clear();
1945-
RetTypes.push_back(Type::getInt8PtrTy(FTy->getContext()));
1945+
RetTypes.push_back(getDefaultAnonymousTapeType(FTy->getContext()));
19461946
if (returnValue == ReturnType::TapeAndTwoReturns) {
19471947
RetTypes.push_back(FTy->getReturnType());
19481948
RetTypes.push_back(

enzyme/Enzyme/Utils.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,13 @@ void (*CustomErrorHandler)(const char *, LLVMValueRef, ErrorType,
4848
void *) = nullptr;
4949
LLVMValueRef (*CustomAllocator)(LLVMBuilderRef, LLVMTypeRef,
5050
/*Count*/ LLVMValueRef,
51-
/*Align*/ LLVMValueRef) = nullptr;
51+
/*Align*/ LLVMValueRef, uint8_t) = nullptr;
5252
LLVMValueRef (*CustomDeallocator)(LLVMBuilderRef, LLVMValueRef) = nullptr;
5353
void (*CustomRuntimeInactiveError)(LLVMBuilderRef, LLVMValueRef,
5454
LLVMValueRef) = nullptr;
5555
LLVMValueRef (*EnzymePostCacheStore)(LLVMValueRef, LLVMBuilderRef,
5656
LLVMValueRef *) = nullptr;
57+
LLVMTypeRef (*EnzymeDefaultTapeType)(LLVMContextRef) = nullptr;
5758
}
5859

5960
llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
@@ -70,6 +71,12 @@ llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
7071
return res;
7172
}
7273

74+
llvm::PointerType *getDefaultAnonymousTapeType(llvm::LLVMContext &C) {
75+
if (EnzymeDefaultTapeType)
76+
return cast<PointerType>(unwrap(EnzymeDefaultTapeType(wrap(&C))));
77+
return Type::getInt8PtrTy(C);
78+
}
79+
7380
Function *getOrInsertExponentialAllocator(Module &M, Function *newFunc,
7481
bool ZeroInit, llvm::Type *RT) {
7582
bool custom = true;
@@ -239,15 +246,16 @@ llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev,
239246
}
240247

241248
Value *CreateAllocation(IRBuilder<> &Builder, llvm::Type *T, Value *Count,
242-
Twine Name, CallInst **caller, Instruction **ZeroMem) {
249+
Twine Name, CallInst **caller, Instruction **ZeroMem,
250+
bool isDefault) {
243251
Value *res;
244252
auto &M = *Builder.GetInsertBlock()->getParent()->getParent();
245253
auto AlignI = M.getDataLayout().getTypeAllocSizeInBits(T) / 8;
246254
auto Align = ConstantInt::get(Count->getType(), AlignI);
247255
CallInst *malloccall = nullptr;
248256
if (CustomAllocator) {
249-
res = unwrap(
250-
CustomAllocator(wrap(&Builder), wrap(T), wrap(Count), wrap(Align)));
257+
res = unwrap(CustomAllocator(wrap(&Builder), wrap(T), wrap(Count),
258+
wrap(Align), isDefault));
251259
if (auto I = dyn_cast<Instruction>(res))
252260
I->setName(Name);
253261

enzyme/Enzyme/Utils.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,8 @@ llvm::SmallVector<llvm::Instruction *, 2> PostCacheStore(llvm::StoreInst *SI,
8686
llvm::Value *CreateAllocation(llvm::IRBuilder<> &B, llvm::Type *T,
8787
llvm::Value *Count, llvm::Twine Name = "",
8888
llvm::CallInst **caller = nullptr,
89-
llvm::Instruction **ZeroMem = nullptr);
89+
llvm::Instruction **ZeroMem = nullptr,
90+
bool isDefault = false);
9091
llvm::CallInst *CreateDealloc(llvm::IRBuilder<> &B, llvm::Value *ToFree);
9192

9293
llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev,
@@ -95,6 +96,8 @@ llvm::Value *CreateReAllocation(llvm::IRBuilder<> &B, llvm::Value *prev,
9596
llvm::CallInst **caller = nullptr,
9697
bool ZeroMem = false);
9798

99+
llvm::PointerType *getDefaultAnonymousTapeType(llvm::LLVMContext &C);
100+
98101
extern std::map<std::string, std::function<llvm::Value *(
99102
llvm::IRBuilder<> &, llvm::CallInst *,
100103
llvm::ArrayRef<llvm::Value *>)>>

0 commit comments

Comments
 (0)