diff --git a/crates/cust/src/memory/device/device_slice.rs b/crates/cust/src/memory/device/device_slice.rs index 4c3b93f2..ea9cbb81 100644 --- a/crates/cust/src/memory/device/device_slice.rs +++ b/crates/cust/src/memory/device/device_slice.rs @@ -13,6 +13,7 @@ use std::ops::{ Index, IndexMut, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive, }; use std::os::raw::c_void; +use std::slice; use std::ptr::{slice_from_raw_parts, slice_from_raw_parts_mut}; /// Fixed-size device-side slice. @@ -35,14 +36,6 @@ impl Debug for DeviceSlice { } } -impl DeviceSlice { - pub fn as_host_vec(&self) -> CudaResult> { - let mut vec = vec![T::default(); self.len()]; - self.copy_to(&mut vec)?; - Ok(vec) - } -} - // This works by faking a regular slice out of the device raw-pointer and the length and transmuting // I have no idea if this is safe or not. Probably not, though I can't imagine how the compiler // could possibly know that the pointer is not de-referenceable. I'm banking that we get proper @@ -94,6 +87,17 @@ impl DeviceSlice { DevicePointer::from_raw(self as *const _ as *const () as usize as u64) } + pub fn as_host_vec(&self) -> CudaResult> { + let mut vec = Vec::with_capacity(self.len()); + // SAFETY: The slice points to uninitialized memory, but we only write to it. Once it is + // written, all values are valid, so we can (and must) change the length of the vector. + unsafe { + self.copy_to(slice::from_raw_parts_mut(vec.as_mut_ptr(), self.len()))?; + vec.set_len(self.len()) + } + Ok(vec) + } + /* TODO (AL): keep these? /// Divides one DeviceSlice into two at a given index. ///