Skip to content

Commit e215049

Browse files
committed
Handle rust pointer type
1 parent e8f903e commit e215049

File tree

10 files changed

+276
-100
lines changed

10 files changed

+276
-100
lines changed

enzyme/Enzyme/ActivityAnalysis.cpp

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -104,36 +104,40 @@ static inline bool couldFunctionArgumentCapture(CallInst *CI, Value *val) {
104104
return false;
105105
}
106106

107-
const char *KnownInactiveFunctions[] = {
108-
"__assert_fail",
109-
"__cxa_guard_acquire",
110-
"__cxa_guard_release",
111-
"__cxa_guard_abort",
112-
"posix_memalign",
113-
"printf",
114-
"puts",
115-
"__enzyme_float",
116-
"__enzyme_double",
117-
"__enzyme_integer",
118-
"__enzyme_pointer",
119-
"__kmpc_for_static_init_4",
120-
"__kmpc_for_static_init_4u",
121-
"__kmpc_for_static_init_8",
122-
"__kmpc_for_static_init_8u",
123-
"__kmpc_for_static_fini",
124-
"__kmpc_dispatch_init_4",
125-
"__kmpc_dispatch_init_4u",
126-
"__kmpc_dispatch_init_8",
127-
"__kmpc_dispatch_init_8u",
128-
"__kmpc_dispatch_next_4",
129-
"__kmpc_dispatch_next_4u",
130-
"__kmpc_dispatch_next_8",
131-
"__kmpc_dispatch_next_8u",
132-
"__kmpc_dispatch_fini_4",
133-
"__kmpc_dispatch_fini_4u",
134-
"__kmpc_dispatch_fini_8",
135-
"__kmpc_dispatch_fini_8u",
136-
};
107+
const char *KnownInactiveFunctionsStartingWith[] = {"_ZN4core3fmt",
108+
"_ZN3std2io5stdio6_print"};
109+
110+
const char *KnownInactiveFunctions[] = {"__assert_fail",
111+
"__cxa_guard_acquire",
112+
"__cxa_guard_release",
113+
"__cxa_guard_abort",
114+
"posix_memalign",
115+
"printf",
116+
"puts",
117+
"__enzyme_float",
118+
"__enzyme_double",
119+
"__enzyme_integer",
120+
"__enzyme_pointer",
121+
"__kmpc_for_static_init_4",
122+
"__kmpc_for_static_init_4u",
123+
"__kmpc_for_static_init_8",
124+
"__kmpc_for_static_init_8u",
125+
"__kmpc_for_static_fini",
126+
"__kmpc_dispatch_init_4",
127+
"__kmpc_dispatch_init_4u",
128+
"__kmpc_dispatch_init_8",
129+
"__kmpc_dispatch_init_8u",
130+
"__kmpc_dispatch_next_4",
131+
"__kmpc_dispatch_next_4u",
132+
"__kmpc_dispatch_next_8",
133+
"__kmpc_dispatch_next_8u",
134+
"__kmpc_dispatch_fini_4",
135+
"__kmpc_dispatch_fini_4u",
136+
"__kmpc_dispatch_fini_8",
137+
"__kmpc_dispatch_fini_8u",
138+
"malloc_usable_size",
139+
"malloc_size",
140+
"_msize"};
137141

138142
/// Is the use of value val as an argument of call CI known to be inactive
139143
/// This tool can only be used when in DOWN mode
@@ -151,6 +155,12 @@ bool ActivityAnalyzer::isFunctionArgumentConstant(CallInst *CI, Value *val) {
151155
// of arguments
152156
if (isAllocationFunction(*F, TLI) || isDeallocationFunction(*F, TLI))
153157
return true;
158+
159+
for (auto FuncName : KnownInactiveFunctionsStartingWith) {
160+
if (Name.startswith(FuncName)) {
161+
return true;
162+
}
163+
}
154164
for (auto FuncName : KnownInactiveFunctions) {
155165
if (Name == FuncName)
156166
return true;
@@ -1122,7 +1132,11 @@ bool ActivityAnalyzer::isInstructionInactiveFromOrigin(TypeResults &TR,
11221132
called->getName() == "_ZdlPvm" || called->getName() == "munmap") {
11231133
return true;
11241134
}
1125-
1135+
for (auto FuncName : KnownInactiveFunctionsStartingWith) {
1136+
if (called->getName().startswith(FuncName)) {
1137+
return true;
1138+
}
1139+
}
11261140
for (auto FuncName : KnownInactiveFunctions) {
11271141
if (called->getName() == FuncName)
11281142
return true;
@@ -1256,19 +1270,20 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults &TR,
12561270
<< "\n";
12571271

12581272
bool seenuse = false;
1259-
1260-
std::deque<User *> todo;
1273+
// user, predecessor
1274+
std::deque<std::pair<User *, Value *>> todo;
12611275
for (const auto a : val->users()) {
1262-
todo.push_back(a);
1276+
todo.push_back(std::make_pair(a, val));
12631277
}
1264-
std::set<Value *> done = {val};
1278+
std::set<std::pair<User *, Value *>> done = {};
12651279

12661280
while (todo.size()) {
1267-
User *a = todo.front();
1281+
auto pair = todo.front();
12681282
todo.pop_front();
1269-
if (done.count(a))
1283+
if (done.count(pair))
12701284
continue;
1271-
done.insert(a);
1285+
done.insert(pair);
1286+
User *a = pair.first;
12721287

12731288
if (printconst)
12741289
llvm::errs() << " considering use of " << *val << " - " << *a
@@ -1315,7 +1330,7 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults &TR,
13151330
}
13161331

13171332
if (auto call = dyn_cast<CallInst>(a)) {
1318-
bool ConstantArg = isFunctionArgumentConstant(call, val);
1333+
bool ConstantArg = isFunctionArgumentConstant(call, pair.second);
13191334
if (ConstantArg) {
13201335
if (printconst) {
13211336
llvm::errs() << "Value found constant callinst use:" << *val
@@ -1334,7 +1349,7 @@ bool ActivityAnalyzer::isValueInactiveFromUsers(TypeResults &TR,
13341349
continue;
13351350
}
13361351
for (auto u : I->users()) {
1337-
todo.push_back(u);
1352+
todo.push_back(std::make_pair(u, (Value *)I));
13381353
}
13391354
continue;
13401355
}

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,6 @@ class AdjointGenerator
221221
assert(placeholder->getType() == type);
222222
gutils->invertedPointers.erase(&LI);
223223

224-
// TODO consider optimizing when you know it isnt a pointer and thus don't
225-
// need to store
226224
if (!constantval) {
227225
IRBuilder<> BuilderZ(placeholder);
228226
Value *newip = nullptr;
@@ -2108,7 +2106,9 @@ class AdjointGenerator
21082106
}
21092107

21102108
if (called &&
2111-
(called->getName() == "printf" || called->getName() == "puts")) {
2109+
(called->getName() == "printf" || called->getName() == "puts" ||
2110+
called->getName().startswith("_ZN3std2io5stdio6_print") ||
2111+
called->getName().startswith("_ZN4core3fmt"))) {
21122112
if (Mode == DerivativeMode::Reverse) {
21132113
eraseIfUnused(*orig, /*erase*/ true, /*check*/ false);
21142114
}
@@ -2441,7 +2441,8 @@ class AdjointGenerator
24412441
auto argType = argi->getType();
24422442

24432443
if (!argType->isFPOrFPVectorTy() &&
2444-
TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer()) {
2444+
(TR.query(orig->getArgOperand(i)).Inner0().isPossiblePointer() ||
2445+
foreignFunction)) {
24452446
DIFFE_TYPE ty = DIFFE_TYPE::DUP_ARG;
24462447
if (argType->isPointerTy()) {
24472448
#if LLVM_VERSION_MAJOR >= 12
@@ -2473,11 +2474,24 @@ class AdjointGenerator
24732474
assert(whatType(argType) == DIFFE_TYPE::DUP_ARG ||
24742475
whatType(argType) == DIFFE_TYPE::CONSTANT);
24752476
} else {
2477+
if (foreignFunction)
2478+
assert(!argType->isIntOrIntVectorTy());
24762479
argsInverted.push_back(DIFFE_TYPE::OUT_DIFF);
24772480
assert(whatType(argType) == DIFFE_TYPE::OUT_DIFF ||
24782481
whatType(argType) == DIFFE_TYPE::CONSTANT);
24792482
}
24802483
}
2484+
if (called) {
2485+
if (orig->getNumArgOperands() !=
2486+
cast<Function>(called)->getFunctionType()->getNumParams()) {
2487+
llvm::errs() << *gutils->oldFunc << "\n";
2488+
llvm::errs() << *orig << "\n";
2489+
}
2490+
assert(orig->getNumArgOperands() ==
2491+
cast<Function>(called)->getFunctionType()->getNumParams());
2492+
assert(argsInverted.size() ==
2493+
cast<Function>(called)->getFunctionType()->getNumParams());
2494+
}
24812495

24822496
DIFFE_TYPE subretType;
24832497
if (gutils->isConstantValue(orig)) {

enzyme/Enzyme/CacheUtility.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ void CacheUtility::erase(Instruction *I) {
105105
SE.eraseValueFromMap(I);
106106

107107
if (!I->use_empty()) {
108+
llvm::errs() << *newFunc->getParent() << "\n";
108109
llvm::errs() << *newFunc << "\n";
109110
llvm::errs() << *I << "\n";
110111
}

enzyme/Enzyme/FunctionUtils.cpp

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -368,11 +368,56 @@ OldAllocationSize(Value *Ptr, CallInst *Loc, Function *NewF, IntegerType *T,
368368
if (success)
369369
continue;
370370
}
371-
llvm::errs() << *NewF->getParent() << "\n";
372-
llvm::errs() << *NewF << "\n";
371+
372+
// llvm::errs() << *NewF->getParent() << "\n";
373+
// llvm::errs() << *NewF << "\n";
373374
EmitFailure("DynamicReallocSize", Loc->getDebugLoc(), Loc,
374375
"could not statically determine size of realloc ", *Loc,
375376
" - because of - ", *next.first);
377+
378+
std::string allocName;
379+
switch (llvm::Triple(NewF->getParent()->getTargetTriple()).getOS()) {
380+
case llvm::Triple::Linux:
381+
case llvm::Triple::FreeBSD:
382+
case llvm::Triple::NetBSD:
383+
case llvm::Triple::OpenBSD:
384+
case llvm::Triple::Fuchsia:
385+
allocName = "malloc_usable_size";
386+
break;
387+
388+
case llvm::Triple::Darwin:
389+
case llvm::Triple::IOS:
390+
case llvm::Triple::MacOSX:
391+
case llvm::Triple::WatchOS:
392+
case llvm::Triple::TvOS:
393+
allocName = "malloc_size";
394+
break;
395+
396+
case llvm::Triple::Win32:
397+
allocName = "_msize";
398+
break;
399+
400+
default:
401+
llvm_unreachable("unknown reallocation for OS");
402+
}
403+
404+
AttributeList list;
405+
list = list.addAttribute(NewF->getContext(), AttributeList::FunctionIndex,
406+
Attribute::ReadOnly);
407+
list = list.addParamAttribute(NewF->getContext(), 0, Attribute::ReadNone);
408+
list = list.addParamAttribute(NewF->getContext(), 0, Attribute::NoCapture);
409+
auto allocSize = NewF->getParent()->getOrInsertFunction(
410+
allocName,
411+
FunctionType::get(
412+
IntegerType::get(NewF->getContext(), 8 * sizeof(size_t)),
413+
{Type::getInt8PtrTy(NewF->getContext())}, /*isVarArg*/ false),
414+
list);
415+
416+
B.SetInsertPoint(Loc);
417+
Value *sz = B.CreateZExtOrTrunc(B.CreateCall(allocSize, {Ptr}), T);
418+
B.CreateStore(sz, AI);
419+
return AI;
420+
376421
llvm_unreachable("DynamicReallocSize");
377422
}
378423
return AI;
@@ -490,6 +535,11 @@ static void ForceRecursiveInlining(Function *NewF, size_t Limit) {
490535
continue;
491536
if (CI->getCalledFunction()->empty())
492537
continue;
538+
if (CI->getCalledFunction()->getName().startswith(
539+
"_ZN3std2io5stdio6_print"))
540+
continue;
541+
if (CI->getCalledFunction()->getName().startswith("_ZN4core3fmt"))
542+
continue;
493543
if (CI->getCalledFunction()->hasFnAttribute(
494544
Attribute::ReturnsTwice) ||
495545
CI->getCalledFunction()->hasFnAttribute(Attribute::NoInline))

0 commit comments

Comments
 (0)