Skip to content

Commit 53500a6

Browse files
bors[bot]vext01
andauthored
15: Move most of the control point logic into to Rust code. r=ltratt a=vext01 Co-authored-by: Edd Barrett <[email protected]>
2 parents b9b9cb8 + d730329 commit 53500a6

File tree

1 file changed

+52
-64
lines changed

1 file changed

+52
-64
lines changed

llvm/lib/Transforms/Yk/ControlPoint.cpp

Lines changed: 52 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,11 @@
8888
#define DEBUG_TYPE "yk-control-point"
8989
#define JIT_STATE_PREFIX "jit-state: "
9090

91+
// These constants mirror `ykrt::mt::JITACTION_*`.
92+
const uintptr_t JITActionNop = 1;
93+
const uintptr_t JITActionStartTracing = 2;
94+
const uintptr_t JITActionStopTracing = 3;
95+
9196
using namespace llvm;
9297

9398
/// Find the call to the dummy control point that we want to patch.
@@ -130,20 +135,34 @@ void createControlPoint(Module &Mod, Function *F, std::vector<Value *> LiveVars,
130135

131136
// Create control point blocks and setup the IRBuilder.
132137
BasicBlock *CtrlPointEntry = BasicBlock::Create(Context, "cpentry", F);
133-
BasicBlock *BBTracing = BasicBlock::Create(Context, "bbtracing", F);
134-
BasicBlock *BBNotTracing = BasicBlock::Create(Context, "bbnottracing", F);
135-
BasicBlock *BBHasTrace = BasicBlock::Create(Context, "bbhastrace", F);
136-
BasicBlock *BBExecuteTrace = BasicBlock::Create(Context, "bbhastrace", F);
137-
BasicBlock *BBHasNoTrace = BasicBlock::Create(Context, "bbhasnotrace", F);
138+
BasicBlock *BBExecuteTrace = BasicBlock::Create(Context, "bbhexectrace", F);
139+
BasicBlock *BBStartTracing = BasicBlock::Create(Context, "bbstarttracing", F);
138140
BasicBlock *BBReturn = BasicBlock::Create(Context, "bbreturn", F);
139141
BasicBlock *BBStopTracing = BasicBlock::Create(Context, "bbstoptracing", F);
140-
IRBuilder<> Builder(CtrlPointEntry);
142+
143+
// Get the type for a pointer-sized integer.
144+
DataLayout DL(&Mod);
145+
unsigned PtrBitSize = DL.getPointerSize() * 8;
146+
IntegerType *PtrSizedInteger = IntegerType::getIntNTy(Context, PtrBitSize);
141147

142148
// Some frequently used constants.
143-
ConstantInt *Int0 = ConstantInt::get(Context, APInt(8, 0));
144-
Constant *PtNull = Constant::getNullValue(Type::getInt8PtrTy(Context));
149+
ConstantInt *JActNop = ConstantInt::get(PtrSizedInteger, JITActionNop);
150+
ConstantInt *JActStartTracing =
151+
ConstantInt::get(PtrSizedInteger, JITActionStartTracing);
152+
ConstantInt *JActStopTracing =
153+
ConstantInt::get(PtrSizedInteger, JITActionStopTracing);
154+
155+
// Add definitions for __yk functions.
156+
Function *FuncTransLoc = llvm::Function::Create(
157+
FunctionType::get(PtrSizedInteger, {Type::getInt8PtrTy(Context)}, false),
158+
GlobalValue::ExternalLinkage, "__ykrt_transition_location", Mod);
159+
160+
Function *FuncSetCodePtr = llvm::Function::Create(
161+
FunctionType::get(
162+
Type::getVoidTy(Context),
163+
{Type::getInt8PtrTy(Context), Type::getInt8PtrTy(Context)}, false),
164+
GlobalValue::ExternalLinkage, "__ykrt_set_loc_code_ptr", Mod);
145165

146-
// Add definitions for __yktrace functions.
147166
Function *FuncStartTracing = llvm::Function::Create(
148167
FunctionType::get(Type::getVoidTy(Context), {Type::getInt64Ty(Context)},
149168
false),
@@ -158,82 +177,52 @@ void createControlPoint(Module &Mod, Function *F, std::vector<Value *> LiveVars,
158177
{Type::getInt8PtrTy(Context)}, false),
159178
GlobalValue::ExternalLinkage, "__yktrace_irtrace_compile", Mod);
160179

161-
// Generate global variables to hold the state of the JIT.
162-
GlobalVariable *GVTracing = new GlobalVariable(
163-
Mod, Type::getInt8Ty(Context), false, GlobalVariable::InternalLinkage,
164-
Int0, "tracing", (GlobalVariable *)nullptr);
165-
166-
GlobalVariable *GVCompiledTrace = new GlobalVariable(
167-
Mod, Type::getInt8PtrTy(Context), false, GlobalVariable::InternalLinkage,
168-
PtNull, "compiled_trace", (GlobalVariable *)nullptr);
169-
170-
GlobalVariable *GVStartLoc = new GlobalVariable(
171-
Mod, YkLocTy, false, GlobalVariable::InternalLinkage,
172-
Constant::getNullValue(YkLocTy), "start_loc", (GlobalVariable *)nullptr);
173-
174-
// Create control point entry block. Checks if we are currently tracing.
175-
Value *GVTracingVal = Builder.CreateLoad(Type::getInt8Ty(Context), GVTracing);
176-
Value *IsTracing =
177-
Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ, GVTracingVal, Int0);
178-
Builder.CreateCondBr(IsTracing, BBNotTracing, BBTracing);
179-
180-
// Create block for "not tracing" case. Checks if we already compiled a trace.
181-
Builder.SetInsertPoint(BBNotTracing);
182-
Value *GVCompiledTraceVal =
183-
Builder.CreateLoad(Type::getInt8PtrTy(Context), GVCompiledTrace);
184-
Value *HasTrace = Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ,
185-
GVCompiledTraceVal, PtNull);
186-
Builder.CreateCondBr(HasTrace, BBHasNoTrace, BBHasTrace);
187-
188-
// Create block that starts tracing.
189-
Builder.SetInsertPoint(BBHasNoTrace);
180+
// Populate the entry block. This calls `__ykrt_transition_location()` to
181+
// decide what to do next.
182+
IRBuilder<> Builder(CtrlPointEntry);
183+
Value *CastLoc =
184+
Builder.CreateBitCast(F->getArg(0), Type::getInt8PtrTy(Context));
185+
Value *JITAction = Builder.CreateCall(FuncTransLoc->getFunctionType(),
186+
FuncTransLoc, {CastLoc});
187+
SwitchInst *ActionSw = Builder.CreateSwitch(JITAction, BBExecuteTrace, 3);
188+
ActionSw->addCase(JActNop, BBReturn);
189+
ActionSw->addCase(JActStartTracing, BBStartTracing);
190+
ActionSw->addCase(JActStopTracing, BBStopTracing);
191+
192+
// Populate the block that starts tracing.
193+
Builder.SetInsertPoint(BBStartTracing);
190194
createJITStatePrint(Builder, &Mod, "start-tracing");
191195
Builder.CreateCall(FuncStartTracing->getFunctionType(), FuncStartTracing,
192196
{ConstantInt::get(Context, APInt(64, 1))});
193-
Builder.CreateStore(ConstantInt::get(Context, APInt(8, 1)), GVTracing);
194-
Builder.CreateStore(F->getArg(0), GVStartLoc);
195197
Builder.CreateBr(BBReturn);
196198

197-
// Create block that checks if we've reached the same location again so we
198-
// can execute a compiled trace.
199-
Builder.SetInsertPoint(BBHasTrace);
200-
Value *ValStartLoc = Builder.CreateLoad(YkLocTy, GVStartLoc);
201-
Value *ExecTraceCond = Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ,
202-
ValStartLoc, F->getArg(0));
203-
Builder.CreateCondBr(ExecTraceCond, BBExecuteTrace, BBReturn);
204-
205-
// Create block that executes a compiled trace.
199+
// Populate the block that calls a compiled trace. If execution gets into
200+
// this block then `JITAction` is a pointer to a compiled trace.
206201
Builder.SetInsertPoint(BBExecuteTrace);
207202
std::vector<Type *> TypeParams;
208203
for (Value *LV : LiveVars) {
209204
TypeParams.push_back(LV->getType());
210205
}
211206
FunctionType *FType =
212207
FunctionType::get(YkCtrlPointStruct, {YkCtrlPointStruct}, false);
213-
Value *CastTrace =
214-
Builder.CreateBitCast(GVCompiledTraceVal, FType->getPointerTo());
208+
Value *JITActionPtr =
209+
Builder.CreateIntToPtr(JITAction, Type::getInt8PtrTy(Context));
210+
Value *CastTrace = Builder.CreateBitCast(JITActionPtr, FType->getPointerTo());
215211
createJITStatePrint(Builder, &Mod, "enter-jit-code");
216212
CallInst *CTResult = Builder.CreateCall(FType, CastTrace, F->getArg(1));
217213
createJITStatePrint(Builder, &Mod, "exit-jit-code");
218214
CTResult->setTailCall(true);
219215
Builder.CreateBr(BBExecuteTrace);
220216

221-
// Create block that decides when to stop tracing.
222-
Builder.SetInsertPoint(BBTracing);
223-
Value *ValStartLoc2 = Builder.CreateLoad(YkLocTy, GVStartLoc);
224-
Value *StopTracingCond = Builder.CreateICmp(CmpInst::Predicate::ICMP_EQ,
225-
ValStartLoc2, F->getArg(0));
226-
Builder.CreateCondBr(StopTracingCond, BBStopTracing, BBReturn);
227-
228217
// Create block that stops tracing, compiles a trace, and stores it in a
229218
// global variable.
230219
Builder.SetInsertPoint(BBStopTracing);
231220
Value *TR =
232221
Builder.CreateCall(FuncStopTracing->getFunctionType(), FuncStopTracing);
233222
Value *CT = Builder.CreateCall(FuncCompileTrace->getFunctionType(),
234223
FuncCompileTrace, {TR});
235-
Builder.CreateStore(CT, GVCompiledTrace);
236-
Builder.CreateStore(ConstantInt::get(Context, APInt(8, 0)), GVTracing);
224+
Builder.CreateCall(FuncSetCodePtr->getFunctionType(), FuncSetCodePtr,
225+
{CastLoc, CT});
237226
createJITStatePrint(Builder, &Mod, "stop-tracing");
238227
Builder.CreateBr(BBReturn);
239228

@@ -242,10 +231,9 @@ void createControlPoint(Module &Mod, Function *F, std::vector<Value *> LiveVars,
242231
// which contains the changed interpreter state.
243232
Builder.SetInsertPoint(BBReturn);
244233
Value *YkCtrlPointVars = F->getArg(1);
245-
PHINode *Phi = Builder.CreatePHI(YkCtrlPointStruct, 3);
246-
Phi->addIncoming(YkCtrlPointVars, BBHasTrace);
247-
Phi->addIncoming(YkCtrlPointVars, BBTracing);
248-
Phi->addIncoming(YkCtrlPointVars, BBHasNoTrace);
234+
PHINode *Phi = Builder.CreatePHI(YkCtrlPointStruct, 2);
235+
Phi->addIncoming(YkCtrlPointVars, CtrlPointEntry);
236+
Phi->addIncoming(YkCtrlPointVars, BBStartTracing);
249237
Phi->addIncoming(YkCtrlPointVars, BBStopTracing);
250238
Builder.CreateRet(Phi);
251239
}

0 commit comments

Comments
 (0)