Skip to content

Commit 949634d

Browse files
authored
Merge pull request rust-lang#57 from RReverser/llvm4-0
Provide safe function wrappers
2 parents 0e93d98 + 1a78203 commit 949634d

File tree

9 files changed

+119
-123
lines changed

9 files changed

+119
-123
lines changed

README.md

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -53,14 +53,14 @@ extern crate inkwell;
5353
use inkwell::OptimizationLevel;
5454
use inkwell::builder::Builder;
5555
use inkwell::context::Context;
56-
use inkwell::execution_engine::{ExecutionEngine, Symbol};
56+
use inkwell::execution_engine::{ExecutionEngine, JitFunction};
5757
use inkwell::module::Module;
5858
use inkwell::targets::{InitializationConfig, Target};
5959
use std::error::Error;
6060

6161
/// Convenience type alias for the `sum` function.
6262
///
63-
/// Calling `sum` is innately `unsafe` because there's no guarantee it doesn't
63+
/// Calling this is innately `unsafe` because there's no guarantee it doesn't
6464
/// do `unsafe` operations internally.
6565
type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;
6666

@@ -83,22 +83,21 @@ fn run() -> Result<(), Box<Error>> {
8383
let z = 3u64;
8484

8585
unsafe {
86-
println!("{} + {} + {} = {}", x, y, z, sum(x, y, z));
87-
assert_eq!(sum(x, y, z), x + y + z);
86+
println!("{} + {} + {} = {}", x, y, z, sum.call(x, y, z));
87+
assert_eq!(sum.call(x, y, z), x + y + z);
8888
}
8989

9090
Ok(())
9191
}
9292

93-
fn jit_compile_sum(
93+
fn jit_compile_sum<'engine>(
9494
context: &Context,
9595
module: &Module,
9696
builder: &Builder,
97-
execution_engine: &ExecutionEngine,
98-
) -> Option<Symbol<SumFunc>> {
97+
execution_engine: &'engine ExecutionEngine,
98+
) -> Option<JitFunction<'engine, SumFunc>> {
9999
let i64_type = context.i64_type();
100-
let fn_type_params = [i64_type.into(), i64_type.into(), i64_type.into()];
101-
let fn_type = i64_type.fn_type(&fn_type_params, false);
100+
let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false);
102101

103102
let function = module.add_function("sum", fn_type, None);
104103
let basic_block = context.append_basic_block(&function, "entry");

examples/jit.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ extern crate inkwell;
33
use inkwell::OptimizationLevel;
44
use inkwell::builder::Builder;
55
use inkwell::context::Context;
6-
use inkwell::execution_engine::{ExecutionEngine, Symbol};
6+
use inkwell::execution_engine::{ExecutionEngine, JitFunction};
77
use inkwell::module::Module;
88
use inkwell::targets::{InitializationConfig, Target};
99
use std::error::Error;
@@ -14,12 +14,12 @@ use std::error::Error;
1414
/// do `unsafe` operations internally.
1515
type SumFunc = unsafe extern "C" fn(u64, u64, u64) -> u64;
1616

17-
fn jit_compile_sum(
17+
fn jit_compile_sum<'engine>(
1818
context: &Context,
1919
module: &Module,
2020
builder: &Builder,
21-
execution_engine: &ExecutionEngine,
22-
) -> Option<Symbol<SumFunc>> {
21+
execution_engine: &'engine ExecutionEngine,
22+
) -> Option<JitFunction<'engine, SumFunc>> {
2323
let i64_type = context.i64_type();
2424
let fn_type = i64_type.fn_type(&[i64_type.into(), i64_type.into(), i64_type.into()], false);
2525

@@ -54,8 +54,8 @@ fn run() -> Result<(), Box<Error>> {
5454
let z = 3u64;
5555

5656
unsafe {
57-
println!("{} + {} + {} = {}", x, y, z, sum(x, y, z));
58-
assert_eq!(sum(x, y, z), x + y + z);
57+
println!("{} + {} + {} = {}", x, y, z, sum.call(x, y, z));
58+
assert_eq!(sum.call(x, y, z), x + y + z);
5959
}
6060

6161
Ok(())

examples/kaleidoscope/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1330,7 +1330,7 @@ pub fn main() {
13301330
};
13311331

13321332
unsafe {
1333-
println!("=> {}", compiled_fn());
1333+
println!("=> {}", compiled_fn.call());
13341334
}
13351335
}
13361336
}

src/execution_engine.rs

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use std::rc::Rc;
1111
use std::ops::Deref;
1212
use std::ffi::CString;
1313
use std::fmt::{self, Debug, Display, Formatter};
14+
use std::marker::PhantomData;
1415
use std::mem::{forget, zeroed, transmute_copy, size_of};
1516

1617
#[derive(Debug, PartialEq, Eq)]
@@ -270,7 +271,7 @@ impl ExecutionEngine {
270271
/// // fetch our JIT'd function and execute it
271272
/// unsafe {
272273
/// let test_fn = ee.get_function::<unsafe extern "C" fn() -> f64>("test_fn").unwrap();
273-
/// let return_value = test_fn();
274+
/// let return_value = test_fn.call();
274275
/// assert_eq!(return_value, 64.0);
275276
/// }
276277
/// ```
@@ -280,12 +281,12 @@ impl ExecutionEngine {
280281
/// It is the caller's responsibility to ensure they call the function with
281282
/// the correct signature and calling convention.
282283
///
283-
/// The `Symbol` wrapper ensures a function won't accidentally outlive the
284+
/// The `JitFunction` wrapper ensures a function won't accidentally outlive the
284285
/// execution engine it came from, but adding functions after calling this
285286
/// method *may* invalidate the function pointer.
286287
///
287288
/// [`UnsafeFunctionPointer`]: trait.UnsafeFunctionPointer.html
288-
pub unsafe fn get_function<F>(&self, fn_name: &str) -> Result<Symbol<F>, FunctionLookupError>
289+
pub unsafe fn get_function<'engine, F>(&'engine self, fn_name: &str) -> Result<JitFunction<'engine, F>, FunctionLookupError>
289290
where
290291
F: UnsafeFunctionPointer,
291292
{
@@ -313,8 +314,8 @@ impl ExecutionEngine {
313314
assert_eq!(size_of::<F>(), size_of::<usize>(),
314315
"The type `F` must have the same size as a function pointer");
315316

316-
Ok(Symbol {
317-
_execution_engine: self.execution_engine.clone(),
317+
Ok(JitFunction {
318+
_execution_engine: PhantomData,
318319
inner: transmute_copy(&address),
319320
})
320321
}
@@ -432,59 +433,58 @@ impl Deref for ExecEngineInner {
432433
}
433434
}
434435

435-
/// A wrapper around a function pointer which ensures the symbol being pointed
436+
/// A wrapper around a function pointer which ensures the function being pointed
436437
/// to doesn't accidentally outlive its execution engine.
437438
#[derive(Clone)]
438-
pub struct Symbol<F> {
439-
_execution_engine: ExecEngineInner,
439+
pub struct JitFunction<'engine, F> {
440+
_execution_engine: PhantomData<&'engine ExecutionEngine>,
440441
inner: F,
441442
}
442443

443-
impl<F: UnsafeFunctionPointer> Deref for Symbol<F> {
444-
type Target = F;
445-
446-
fn deref(&self) -> &Self::Target {
447-
&self.inner
448-
}
449-
}
450-
451-
impl<F> Debug for Symbol<F> {
444+
impl<'engine, F> Debug for JitFunction<'engine, F> {
452445
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
453-
f.debug_tuple("Symbol")
446+
f.debug_tuple("JitFunction")
454447
.field(&"<unnamed>")
455448
.finish()
456449
}
457450
}
458451

459452
/// Marker trait representing an unsafe function pointer (`unsafe extern "C" fn(A, B, ...) -> Output`).
460-
pub trait UnsafeFunctionPointer: private::Sealed + Copy {}
453+
pub trait UnsafeFunctionPointer: private::SealedUnsafeFunctionPointer {}
461454

462455
mod private {
463456
/// A sealed trait which ensures nobody outside this crate can implement
464457
/// `UnsafeFunctionPointer`.
465458
///
466459
/// See https://rust-lang-nursery.github.io/api-guidelines/future-proofing.html
467-
pub trait Sealed {}
460+
pub trait SealedUnsafeFunctionPointer: Copy {}
468461
}
469462

463+
impl<F: private::SealedUnsafeFunctionPointer> UnsafeFunctionPointer for F {}
464+
470465
macro_rules! impl_unsafe_fn {
466+
(@recurse $first:ident $( , $rest:ident )*) => {
467+
impl_unsafe_fn!($( $rest ),*);
468+
};
469+
470+
(@recurse) => {};
471+
471472
($( $param:ident ),*) => {
472-
impl<Output, $( $param ),*> private::Sealed for unsafe extern "C" fn($( $param ),*) -> Output {}
473-
impl<Output, $( $param ),*> UnsafeFunctionPointer for unsafe extern "C" fn($( $param ),*) -> Output {}
473+
impl<'engine, Output, $( $param ),*> private::SealedUnsafeFunctionPointer for unsafe extern "C" fn($( $param ),*) -> Output {}
474+
475+
impl<'engine, Output, $( $param ),*> JitFunction<'engine, unsafe extern "C" fn($( $param ),*) -> Output> {
476+
/// This method allows you to call the underlying function while making
477+
/// sure that the backing storage is not dropped too early and
478+
/// preserves the `unsafe` marker for any calls.
479+
#[allow(non_snake_case)]
480+
#[inline(always)]
481+
pub unsafe fn call(&self, $( $param: $param ),*) -> Output {
482+
(self.inner)($( $param ),*)
483+
}
484+
}
485+
486+
impl_unsafe_fn!(@recurse $( $param ),*);
474487
};
475488
}
476489

477-
impl_unsafe_fn!();
478-
impl_unsafe_fn!(A);
479-
impl_unsafe_fn!(A, B);
480-
impl_unsafe_fn!(A, B, C);
481-
impl_unsafe_fn!(A, B, C, D);
482-
impl_unsafe_fn!(A, B, C, D, E);
483-
impl_unsafe_fn!(A, B, C, D, E, F);
484-
impl_unsafe_fn!(A, B, C, D, E, F, G);
485-
impl_unsafe_fn!(A, B, C, D, E, F, G, H);
486-
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I);
487-
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J);
488-
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J, K);
489-
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J, K, L);
490490
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M);

src/targets.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
#[llvm_versions(7.0 => latest)]
12
use either::Either;
23
use llvm_sys::target::{LLVMTargetDataRef, LLVMCopyStringRepOfTargetData, LLVMSizeOfTypeInBits, LLVMCreateTargetData, LLVMByteOrder, LLVMPointerSize, LLVMByteOrdering, LLVMStoreSizeOfType, LLVMABISizeOfType, LLVMABIAlignmentOfType, LLVMCallFrameAlignmentOfType, LLVMPreferredAlignmentOfType, LLVMPreferredAlignmentOfGlobal, LLVMElementAtOffset, LLVMOffsetOfElement, LLVMDisposeTargetData, LLVMPointerSizeForAS, LLVMIntPtrType, LLVMIntPtrTypeForAS, LLVMIntPtrTypeInContext, LLVMIntPtrTypeForASInContext};
34
use llvm_sys::target_machine::{LLVMGetFirstTarget, LLVMTargetRef, LLVMGetNextTarget, LLVMGetTargetFromName, LLVMGetTargetFromTriple, LLVMGetTargetName, LLVMGetTargetDescription, LLVMTargetHasJIT, LLVMTargetHasTargetMachine, LLVMTargetHasAsmBackend, LLVMTargetMachineRef, LLVMDisposeTargetMachine, LLVMGetTargetMachineTarget, LLVMGetTargetMachineTriple, LLVMSetTargetMachineAsmVerbosity, LLVMCreateTargetMachine, LLVMGetTargetMachineCPU, LLVMGetTargetMachineFeatureString, LLVMGetDefaultTargetTriple, LLVMAddAnalysisPasses, LLVMCodeGenOptLevel, LLVMCodeModel, LLVMRelocMode, LLVMCodeGenFileType, LLVMTargetMachineEmitToMemoryBuffer, LLVMTargetMachineEmitToFile};

tests/all/test_builder.rs

Lines changed: 58 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ use self::inkwell::{AddressSpace, OptimizationLevel};
44
use self::inkwell::context::Context;
55
use self::inkwell::builder::Builder;
66
use self::inkwell::targets::{InitializationConfig, Target};
7-
use self::inkwell::execution_engine::Symbol;
87

98
use std::ffi::CString;
109
use std::ptr::null;
@@ -137,17 +136,17 @@ fn test_null_checked_ptr_ops() {
137136
let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None).unwrap();
138137

139138
unsafe {
140-
let check_null_index1: Symbol<unsafe extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index1").unwrap();
139+
let check_null_index1 = execution_engine.get_function::<unsafe extern "C" fn(*const i8) -> i8>("check_null_index1").unwrap();
141140

142141
let array = &[100i8, 42i8];
143142

144-
assert_eq!(check_null_index1(null()), -1i8);
145-
assert_eq!(check_null_index1(array.as_ptr()), 42i8);
143+
assert_eq!(check_null_index1.call(null()), -1i8);
144+
assert_eq!(check_null_index1.call(array.as_ptr()), 42i8);
146145

147-
let check_null_index2: Symbol<unsafe extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index2").unwrap();
146+
let check_null_index2 = execution_engine.get_function::<unsafe extern "C" fn(*const i8) -> i8>("check_null_index2").unwrap();
148147

149-
assert_eq!(check_null_index2(null()), -1i8);
150-
assert_eq!(check_null_index2(array.as_ptr()), 42i8);
148+
assert_eq!(check_null_index2.call(null()), -1i8);
149+
assert_eq!(check_null_index2.call(array.as_ptr()), 42i8);
151150
}
152151
}
153152

@@ -216,24 +215,24 @@ fn test_binary_ops() {
216215
unsafe {
217216
type BoolFunc = unsafe extern "C" fn(bool, bool) -> bool;
218217

219-
let and: Symbol<BoolFunc> = execution_engine.get_function("and").unwrap();
220-
let or: Symbol<BoolFunc> = execution_engine.get_function("or").unwrap();
221-
let xor: Symbol<BoolFunc> = execution_engine.get_function("xor").unwrap();
218+
let and = execution_engine.get_function::<BoolFunc>("and").unwrap();
219+
let or = execution_engine.get_function::<BoolFunc>("or").unwrap();
220+
let xor = execution_engine.get_function::<BoolFunc>("xor").unwrap();
222221

223-
assert!(!and(false, false));
224-
assert!(!and(true, false));
225-
assert!(!and(false, true));
226-
assert!(and(true, true));
222+
assert!(!and.call(false, false));
223+
assert!(!and.call(true, false));
224+
assert!(!and.call(false, true));
225+
assert!(and.call(true, true));
227226

228-
assert!(!or(false, false));
229-
assert!(or(true, false));
230-
assert!(or(false, true));
231-
assert!(or(true, true));
227+
assert!(!or.call(false, false));
228+
assert!(or.call(true, false));
229+
assert!(or.call(false, true));
230+
assert!(or.call(true, true));
232231

233-
assert!(!xor(false, false));
234-
assert!(xor(true, false));
235-
assert!(xor(false, true));
236-
assert!(!xor(true, true));
232+
assert!(!xor.call(false, false));
233+
assert!(xor.call(true, false));
234+
assert!(xor.call(false, true));
235+
assert!(!xor.call(true, true));
237236
}
238237
}
239238

@@ -287,13 +286,13 @@ fn test_switch() {
287286
builder.build_return(Some(&double));
288287

289288
unsafe {
290-
let switch: Symbol<unsafe extern "C" fn(u8) -> u8> = execution_engine.get_function("switch").unwrap();
289+
let switch = execution_engine.get_function::<unsafe extern "C" fn(u8) -> u8>("switch").unwrap();
291290

292-
assert_eq!(switch(0), 1);
293-
assert_eq!(switch(1), 2);
294-
assert_eq!(switch(3), 6);
295-
assert_eq!(switch(10), 20);
296-
assert_eq!(switch(42), 255);
291+
assert_eq!(switch.call(0), 1);
292+
assert_eq!(switch.call(1), 2);
293+
assert_eq!(switch.call(3), 6);
294+
assert_eq!(switch.call(10), 20);
295+
assert_eq!(switch.call(42), 255);
297296
}
298297
}
299298

@@ -357,37 +356,37 @@ fn test_bit_shifts() {
357356
builder.build_return(Some(&shift));
358357

359358
unsafe {
360-
let left_shift: Symbol<unsafe extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("left_shift").unwrap();
361-
let right_shift: Symbol<unsafe extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("right_shift").unwrap();
362-
let right_shift_sign_extend: Symbol<unsafe extern "C" fn(i8, u8) -> i8> = execution_engine.get_function("right_shift_sign_extend").unwrap();
363-
364-
assert_eq!(left_shift(0, 0), 0);
365-
assert_eq!(left_shift(0, 4), 0);
366-
assert_eq!(left_shift(1, 0), 1);
367-
assert_eq!(left_shift(1, 1), 2);
368-
assert_eq!(left_shift(1, 2), 4);
369-
assert_eq!(left_shift(1, 3), 8);
370-
assert_eq!(left_shift(64, 1), 128);
371-
372-
assert_eq!(right_shift(128, 1), 64);
373-
assert_eq!(right_shift(8, 3), 1);
374-
assert_eq!(right_shift(4, 2), 1);
375-
assert_eq!(right_shift(2, 1), 1);
376-
assert_eq!(right_shift(1, 0), 1);
377-
assert_eq!(right_shift(0, 4), 0);
378-
assert_eq!(right_shift(0, 0), 0);
379-
380-
assert_eq!(right_shift_sign_extend(8, 3), 1);
381-
assert_eq!(right_shift_sign_extend(4, 2), 1);
382-
assert_eq!(right_shift_sign_extend(2, 1), 1);
383-
assert_eq!(right_shift_sign_extend(1, 0), 1);
384-
assert_eq!(right_shift_sign_extend(0, 4), 0);
385-
assert_eq!(right_shift_sign_extend(0, 0), 0);
386-
assert_eq!(right_shift_sign_extend(-127, 1), -64);
387-
assert_eq!(right_shift_sign_extend(-127, 8), -1);
388-
assert_eq!(right_shift_sign_extend(-65, 3), -9);
389-
assert_eq!(right_shift_sign_extend(-64, 3), -8);
390-
assert_eq!(right_shift_sign_extend(-63, 3), -8);
359+
let left_shift = execution_engine.get_function::<unsafe extern "C" fn(u8, u8) -> u8>("left_shift").unwrap();
360+
let right_shift = execution_engine.get_function::<unsafe extern "C" fn(u8, u8) -> u8>("right_shift").unwrap();
361+
let right_shift_sign_extend = execution_engine.get_function::<unsafe extern "C" fn(i8, u8) -> i8>("right_shift_sign_extend").unwrap();
362+
363+
assert_eq!(left_shift.call(0, 0), 0);
364+
assert_eq!(left_shift.call(0, 4), 0);
365+
assert_eq!(left_shift.call(1, 0), 1);
366+
assert_eq!(left_shift.call(1, 1), 2);
367+
assert_eq!(left_shift.call(1, 2), 4);
368+
assert_eq!(left_shift.call(1, 3), 8);
369+
assert_eq!(left_shift.call(64, 1), 128);
370+
371+
assert_eq!(right_shift.call(128, 1), 64);
372+
assert_eq!(right_shift.call(8, 3), 1);
373+
assert_eq!(right_shift.call(4, 2), 1);
374+
assert_eq!(right_shift.call(2, 1), 1);
375+
assert_eq!(right_shift.call(1, 0), 1);
376+
assert_eq!(right_shift.call(0, 4), 0);
377+
assert_eq!(right_shift.call(0, 0), 0);
378+
379+
assert_eq!(right_shift_sign_extend.call(8, 3), 1);
380+
assert_eq!(right_shift_sign_extend.call(4, 2), 1);
381+
assert_eq!(right_shift_sign_extend.call(2, 1), 1);
382+
assert_eq!(right_shift_sign_extend.call(1, 0), 1);
383+
assert_eq!(right_shift_sign_extend.call(0, 4), 0);
384+
assert_eq!(right_shift_sign_extend.call(0, 0), 0);
385+
assert_eq!(right_shift_sign_extend.call(-127, 1), -64);
386+
assert_eq!(right_shift_sign_extend.call(-127, 8), -1);
387+
assert_eq!(right_shift_sign_extend.call(-65, 3), -9);
388+
assert_eq!(right_shift_sign_extend.call(-64, 3), -8);
389+
assert_eq!(right_shift_sign_extend.call(-63, 3), -8);
391390
}
392391
}
393392

0 commit comments

Comments
 (0)