Skip to content

Commit c5714ea

Browse files
Made the tests pass again
1 parent 022b09e commit c5714ea

File tree

5 files changed

+55
-29
lines changed

5 files changed

+55
-29
lines changed

examples/kaleidoscope/main.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1317,7 +1317,7 @@ pub fn main() {
13171317
if is_anonymous {
13181318
let ee = module.create_jit_execution_engine(OptimizationLevel::None).unwrap();
13191319

1320-
let maybe_fn = unsafe { ee.get_function::<extern "C" fn() -> f64>(name.as_str()) };
1320+
let maybe_fn = unsafe { ee.get_function::<unsafe extern "C" fn() -> f64>(name.as_str()) };
13211321
let compiled_fn = match maybe_fn {
13221322
Ok(f) => f,
13231323
Err(err) => {
@@ -1326,7 +1326,9 @@ pub fn main() {
13261326
}
13271327
};
13281328

1329-
println!("=> {}", compiled_fn());
1329+
unsafe {
1330+
println!("=> {}", compiled_fn());
1331+
}
13301332
}
13311333
}
13321334
}

src/execution_engine.rs

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::rc::Rc;
99
use std::ops::Deref;
1010
use std::ffi::{CStr, CString};
1111
use std::mem::{forget, uninitialized, zeroed, transmute_copy, size_of};
12+
use std::fmt::{self, Debug, Formatter};
1213

1314
#[derive(Debug, PartialEq, Eq)]
1415
pub enum FunctionLookupError {
@@ -164,6 +165,12 @@ impl ExecutionEngine {
164165
/// If a target hasn't already been initialized, spurious "function not
165166
/// found" errors may be encountered.
166167
///
168+
/// The [`UnsafeFunctionPointer`] trait is designed so only `unsafe extern
169+
/// "C"` functions can be retrieved via the `get_function()` method. If you
170+
/// get funny type errors then it's probably because you have specified the
171+
/// wrong calling convention or forgotten to specify the retrieved function
172+
/// is `unsafe`.
173+
///
167174
/// # Examples
168175
///
169176
///
@@ -200,7 +207,6 @@ impl ExecutionEngine {
200207
/// }
201208
/// ```
202209
///
203-
///
204210
/// # Safety
205211
///
206212
/// It is the caller's responsibility to ensure they call the function with
@@ -209,6 +215,8 @@ impl ExecutionEngine {
209215
/// The `Symbol` wrapper ensures a function won't accidentally outlive the
210216
/// execution engine it came from, but adding functions after calling this
211217
/// method *may* invalidate the function pointer.
218+
///
219+
/// [`UnsafeFunctionPointer`]: trait.UnsafeFunctionPointer.html
212220
pub unsafe fn get_function<F>(&self, fn_name: &str) -> Result<Symbol<F>, FunctionLookupError>
213221
where F: UnsafeFunctionPointer
214222
{
@@ -293,10 +301,24 @@ impl ExecutionEngine {
293301
}
294302
}
295303

304+
// Modules owned by the EE will be discarded by the EE so we don't
305+
// want owned modules to drop.
306+
impl Drop for ExecutionEngine {
307+
fn drop(&mut self) {
308+
forget(self.target_data.take().expect("TargetData should always exist until Drop"));
309+
310+
if Rc::strong_count(&self.execution_engine) == 1 {
311+
unsafe {
312+
LLVMDisposeExecutionEngine(*self.execution_engine);
313+
}
314+
}
315+
}
316+
}
317+
296318
/// A wrapper around a function pointer which ensures the symbol being pointed
297319
/// to doesn't accidentally outlive its execution engine.
298-
#[derive(Debug, Clone)]
299-
pub struct Symbol<F: UnsafeFunctionPointer> {
320+
#[derive(Clone)]
321+
pub struct Symbol<F> {
300322
pub(crate) execution_engine: Rc<LLVMExecutionEngineRef>,
301323
inner: F,
302324
}
@@ -309,8 +331,16 @@ impl<F: UnsafeFunctionPointer> Deref for Symbol<F> {
309331
}
310332
}
311333

334+
impl<F> Debug for Symbol<F> {
335+
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
336+
f.debug_tuple("Symbol")
337+
.field(&"<unnamed>")
338+
.finish()
339+
}
340+
}
341+
312342
/// Marker trait representing an unsafe function pointer (`unsafe extern "C" fn(A, B, ...) -> Output`).
313-
pub trait UnsafeFunctionPointer: private::Sealed {}
343+
pub trait UnsafeFunctionPointer: private::Sealed + Copy {}
314344

315345
mod private {
316346
/// A sealed trait which ensures nobody outside this crate can implement
@@ -337,17 +367,8 @@ impl_unsafe_fn!(A, B, C, D, E, F);
337367
impl_unsafe_fn!(A, B, C, D, E, F, G);
338368
impl_unsafe_fn!(A, B, C, D, E, F, G, H);
339369
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I);
370+
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J);
371+
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J, K);
372+
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J, K, L);
373+
impl_unsafe_fn!(A, B, C, D, E, F, G, H, I, J, K, L, M);
340374

341-
// Modules owned by the EE will be discarded by the EE so we don't
342-
// want owned modules to drop.
343-
impl Drop for ExecutionEngine {
344-
fn drop(&mut self) {
345-
forget(self.target_data.take().expect("TargetData should always exist until Drop"));
346-
347-
if Rc::strong_count(&self.execution_engine) == 1 {
348-
unsafe {
349-
LLVMDisposeExecutionEngine(*self.execution_engine);
350-
}
351-
}
352-
}
353-
}

tests/test_builder.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,14 +129,14 @@ fn test_null_checked_ptr_ops() {
129129
let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None).unwrap();
130130

131131
unsafe {
132-
let check_null_index1: Symbol<extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index1").unwrap();
132+
let check_null_index1: Symbol<unsafe extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index1").unwrap();
133133

134134
let array = &[100i8, 42i8];
135135

136136
assert_eq!(check_null_index1(null()), -1i8);
137137
assert_eq!(check_null_index1(array.as_ptr()), 42i8);
138138

139-
let check_null_index2: Symbol<extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index2").unwrap();
139+
let check_null_index2: Symbol<unsafe extern "C" fn(*const i8) -> i8> = execution_engine.get_function("check_null_index2").unwrap();
140140

141141
assert_eq!(check_null_index2(null()), -1i8);
142142
assert_eq!(check_null_index2(array.as_ptr()), 42i8);
@@ -279,7 +279,7 @@ fn test_switch() {
279279
builder.build_return(Some(&double));
280280

281281
unsafe {
282-
let switch: Symbol<extern "C" fn(u8) -> u8> = execution_engine.get_function("switch").unwrap();
282+
let switch: Symbol<unsafe extern "C" fn(u8) -> u8> = execution_engine.get_function("switch").unwrap();
283283

284284
assert_eq!(switch(0), 1);
285285
assert_eq!(switch(1), 2);
@@ -349,9 +349,9 @@ fn test_bit_shifts() {
349349
builder.build_return(Some(&shift));
350350

351351
unsafe {
352-
let left_shift: Symbol<extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("left_shift").unwrap();
353-
let right_shift: Symbol<extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("right_shift").unwrap();
354-
let right_shift_sign_extend: Symbol<extern "C" fn(i8, u8) -> i8> = execution_engine.get_function("right_shift_sign_extend").unwrap();
352+
let left_shift: Symbol<unsafe extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("left_shift").unwrap();
353+
let right_shift: Symbol<unsafe extern "C" fn(u8, u8) -> u8> = execution_engine.get_function("right_shift").unwrap();
354+
let right_shift_sign_extend: Symbol<unsafe extern "C" fn(i8, u8) -> i8> = execution_engine.get_function("right_shift_sign_extend").unwrap();
355355

356356
assert_eq!(left_shift(0, 0), 0);
357357
assert_eq!(left_shift(0, 4), 0);

tests/test_execution_engine.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ use self::inkwell::context::Context;
55
use self::inkwell::execution_engine::FunctionLookupError;
66
use self::inkwell::targets::{InitializationConfig, Target};
77

8+
type Thunk = unsafe extern "C" fn();
9+
810
#[test]
911
fn test_get_function_address() {
1012
let context = Context::create();
@@ -24,7 +26,7 @@ fn test_get_function_address() {
2426
let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None).unwrap();
2527

2628
unsafe {
27-
assert_eq!(execution_engine.get_function::<fn()>("errors").unwrap_err(),
29+
assert_eq!(execution_engine.get_function::<Thunk>("errors").unwrap_err(),
2830
FunctionLookupError::FunctionNotFound);
2931
}
3032

@@ -38,10 +40,10 @@ fn test_get_function_address() {
3840
let execution_engine = module.create_jit_execution_engine(OptimizationLevel::None).unwrap();
3941

4042
unsafe {
41-
assert_eq!(execution_engine.get_function::<fn()>("errors").unwrap_err(),
43+
assert_eq!(execution_engine.get_function::<Thunk>("errors").unwrap_err(),
4244
FunctionLookupError::FunctionNotFound);
4345

44-
assert!(execution_engine.get_function::<fn()>("func").is_ok());
46+
assert!(execution_engine.get_function::<Thunk>("func").is_ok());
4547
}
4648
}
4749

tests/test_tari_example.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ fn test_tari_example() {
3333
builder.build_return(Some(&sum));
3434

3535
unsafe {
36-
let sum: Symbol<extern "C" fn(u64, u64, u64) -> u64> = execution_engine.get_function("sum").unwrap();
36+
type Sum = unsafe extern "C" fn(u64, u64, u64) -> u64;
37+
let sum: Symbol<Sum> = execution_engine.get_function("sum").unwrap();
3738

3839
let x = 1u64;
3940
let y = 2u64;

0 commit comments

Comments
 (0)