use crate::error::{CudaResult, DropResult, ToResult};
use crate::memory::device::{AsyncCopyDestination, CopyDestination, DeviceSlice};
use crate::memory::malloc::{cuda_free, cuda_malloc};
use crate::memory::DeviceCopy;
use crate::memory::DevicePointer;
use crate::stream::Stream;
use cuda_sys::cuda;
use std::mem;
use std::ops::{Deref, DerefMut};
use std::ptr;
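
/// Fixed-size device-side buffer. Provides basic access to device memory.
///
/// # Examples
///
/// A minimal sketch of round-tripping data through device memory; the paths assume the
/// public `rustacuda` crate root and an initialized CUDA context:
///
/// ```
/// # let _context = rustacuda::quick_init().unwrap();
/// use rustacuda::memory::*;
///
/// let buf = DeviceBuffer::from_slice(&[1u64, 2, 3]).unwrap();
/// let mut host = [0u64, 0, 0];
/// buf.copy_to(&mut host).unwrap();
/// assert_eq!([1u64, 2, 3], host);
/// ```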
#[derive(Debug)]
pub struct DeviceBuffer<T> {
    buf: DevicePointer<T>,
    capacity: usize,
}

impl<T> DeviceBuffer<T> {
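    /// Allocate a new device buffer large enough to hold `size` `T`'s, without
    /// initializing the contents.
    ///
    /// # Safety
    ///
    /// The contents of the buffer are uninitialized. The caller must initialize the
    /// buffer (for example with `copy_from`) before reading from it.
    ///
    /// # Errors
    ///
    /// If the allocation fails, returns the error from CUDA.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths assume the public `rustacuda` crate root and an
    /// initialized CUDA context:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    ///
    /// let mut buffer = unsafe { DeviceBuffer::uninitialized(5).unwrap() };
    /// buffer.copy_from(&[0u64, 1, 2, 3, 4]).unwrap();
    /// ```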
    pub unsafe fn uninitialized(size: usize) -> CudaResult<Self> {
        let ptr = if size > 0 && mem::size_of::<T>() > 0 {
            cuda_malloc(size)?
        } else {
            // The buffer is empty or `T` is zero-sized, so no allocation is needed;
            // use a dangling pointer instead.
            DevicePointer::wrap(ptr::NonNull::dangling().as_ptr() as *mut T)
        };
        Ok(DeviceBuffer {
            buf: ptr,
            capacity: size,
        })
    }
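
    /// Allocate a new device buffer large enough to hold `size` `T`'s, with the
    /// underlying memory zeroed byte-by-byte.
    ///
    /// # Safety
    ///
    /// All-zeroes may not be a valid bit-pattern for type `T`. The caller must either
    /// ensure that it is, or overwrite the buffer before reading from it.
    ///
    /// # Errors
    ///
    /// If the allocation or the memset fails, returns the error from CUDA.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths assume the public `rustacuda` crate root and an
    /// initialized CUDA context:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    ///
    /// let buffer = unsafe { DeviceBuffer::<u64>::zeroed(5).unwrap() };
    /// let mut host_values = [1u64, 2, 3, 4, 5];
    /// buffer.copy_to(&mut host_values).unwrap();
    /// assert_eq!([0u64, 0, 0, 0, 0], host_values);
    /// ```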
    pub unsafe fn zeroed(size: usize) -> CudaResult<Self> {
        let ptr = if size > 0 && mem::size_of::<T>() > 0 {
            let mut ptr = cuda_malloc(size)?;
            // Zero the allocation one byte at a time.
            cuda::cuMemsetD8_v2(ptr.as_raw_mut() as u64, 0, size * mem::size_of::<T>())
                .to_result()?;
            ptr
        } else {
            DevicePointer::wrap(ptr::NonNull::dangling().as_ptr() as *mut T)
        };
        Ok(DeviceBuffer {
            buf: ptr,
            capacity: size,
        })
    }
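
    /// Creates a `DeviceBuffer<T>` directly from a raw device pointer and a capacity.
    ///
    /// # Safety
    ///
    /// `ptr` must point to a device allocation large enough to hold `capacity` elements
    /// of `T`, and must not be freed or used elsewhere afterwards: the returned buffer
    /// takes ownership of the allocation and frees it when dropped.
    ///
    /// # Examples
    ///
    /// A minimal sketch; it assumes the public `rustacuda` crate root, an initialized
    /// CUDA context, and that `cuda_malloc` is re-exported from `rustacuda::memory`:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    ///
    /// unsafe {
    ///     // Allocate space for 5 u64's, then assemble a DeviceBuffer from the parts.
    ///     let ptr: DevicePointer<u64> = cuda_malloc(5).unwrap();
    ///     let buffer = DeviceBuffer::from_raw_parts(ptr, 5);
    /// }
    /// ```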
    pub unsafe fn from_raw_parts(ptr: DevicePointer<T>, capacity: usize) -> DeviceBuffer<T> {
        DeviceBuffer { buf: ptr, capacity }
    }
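
    /// Destroy the buffer, returning any CUDA error along with the un-dropped buffer.
    ///
    /// Deallocating device memory can fail, for instance due to errors from earlier
    /// asynchronous work. Unlike the `Drop` impl, which panics on failure, this
    /// function hands the error and the still-live buffer back to the caller.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths assume the public `rustacuda` crate root and an
    /// initialized CUDA context:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    ///
    /// let buffer = DeviceBuffer::from_slice(&[10u64, 20, 30]).unwrap();
    /// match DeviceBuffer::drop(buffer) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, buf)) => {
    ///         println!("Failed to destroy buffer: {:?}", e);
    ///         // Do something with buf...
    ///     }
    /// }
    /// ```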
    pub fn drop(mut dev_buf: DeviceBuffer<T>) -> DropResult<DeviceBuffer<T>> {
        if dev_buf.buf.is_null() {
            return Ok(());
        }
        if dev_buf.capacity > 0 && mem::size_of::<T>() > 0 {
            let capacity = dev_buf.capacity;
            // Null out the pointer first so that the `Drop` impl is a no-op whether or
            // not the free succeeds.
            let ptr = mem::replace(&mut dev_buf.buf, DevicePointer::null());
            unsafe {
                match cuda_free(ptr) {
                    Ok(()) => {
                        mem::forget(dev_buf);
                        Ok(())
                    }
                    // Freeing failed; reassemble the buffer and hand it back.
                    Err(e) => Err((e, DeviceBuffer::from_raw_parts(ptr, capacity))),
                }
            }
        } else {
            Ok(())
        }
    }
}

impl<T: DeviceCopy> DeviceBuffer<T> {
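    /// Allocate a new device buffer of the same size as `slice`, initialized with a
    /// clone of the data in `slice`.
    ///
    /// # Errors
    ///
    /// If the allocation fails, returns the error from CUDA.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths assume the public `rustacuda` crate root and an
    /// initialized CUDA context:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    ///
    /// let values = [0u64; 5];
    /// let buffer = DeviceBuffer::from_slice(&values).unwrap();
    /// ```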
    pub fn from_slice(slice: &[T]) -> CudaResult<Self> {
        unsafe {
            let mut uninit = DeviceBuffer::uninitialized(slice.len())?;
            uninit.copy_from(slice)?;
            Ok(uninit)
        }
    }
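
    /// Asynchronously allocate a new device buffer of the same size as `slice`,
    /// initialized with a clone of the data in `slice`, with the copy enqueued on
    /// `stream`.
    ///
    /// # Safety
    ///
    /// The copy may not have completed when this function returns. The caller must not
    /// read or drop the returned buffer until `stream` has been synchronized, and must
    /// keep `slice` alive and unmodified until then.
    ///
    /// # Examples
    ///
    /// A minimal sketch; the paths assume the public `rustacuda` crate root and an
    /// initialized CUDA context:
    ///
    /// ```
    /// # let _context = rustacuda::quick_init().unwrap();
    /// use rustacuda::memory::*;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
    /// let values = [0u64; 5];
    /// unsafe {
    ///     let buffer = DeviceBuffer::from_slice_async(&values, &stream).unwrap();
    ///     stream.synchronize().unwrap();
    /// }
    /// ```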
    pub unsafe fn from_slice_async(slice: &[T], stream: &Stream) -> CudaResult<Self> {
        let mut uninit = DeviceBuffer::uninitialized(slice.len())?;
        uninit.async_copy_from(slice, stream)?;
        Ok(uninit)
    }
}
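
// Dereferencing a `DeviceBuffer` yields a `DeviceSlice` spanning the entire allocation,
// so slice operations such as indexing and ranged copies are available directly on the
// buffer (see the slice tests below).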
impl<T> Deref for DeviceBuffer<T> {
    type Target = DeviceSlice<T>;

    fn deref(&self) -> &DeviceSlice<T> {
        unsafe {
            DeviceSlice::from_slice(::std::slice::from_raw_parts(
                self.buf.as_raw(),
                self.capacity,
            ))
        }
    }
}

impl<T> DerefMut for DeviceBuffer<T> {
    fn deref_mut(&mut self) -> &mut DeviceSlice<T> {
        unsafe {
            &mut *(::std::slice::from_raw_parts_mut(self.buf.as_raw_mut(), self.capacity)
                as *mut [T] as *mut DeviceSlice<T>)
        }
    }
}

impl<T> Drop for DeviceBuffer<T> {
    fn drop(&mut self) {
        if self.buf.is_null() {
            return;
        }
        if self.capacity > 0 && mem::size_of::<T>() > 0 {
            let ptr = mem::replace(&mut self.buf, DevicePointer::null());
            unsafe {
                // There is no way to report an error from `drop`, so panic if the
                // deallocation fails. Use `DeviceBuffer::drop` to handle the error
                // instead.
                cuda_free(ptr).expect("Failed to deallocate CUDA Device memory.");
            }
        }
        self.capacity = 0;
    }
}

#[cfg(test)]
mod test_device_buffer {
    use super::*;
    use crate::memory::device::DeviceBox;
    use crate::stream::{Stream, StreamFlags};

    #[derive(Clone, Debug)]
    struct ZeroSizedType;
    unsafe impl DeviceCopy for ZeroSizedType {}

    #[test]
    fn test_from_slice_drop() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        drop(buf);
    }

    #[test]
    fn test_copy_to_from_device() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0, 0, 0, 0, 0];
        let buf = DeviceBuffer::from_slice(&start).unwrap();
        buf.copy_to(&mut end).unwrap();
        assert_eq!(start, end);
    }

    #[test]
    fn test_async_copy_to_from_device() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0, 0, 0, 0, 0];
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&start, &stream).unwrap();
            buf.async_copy_to(&mut end, &stream).unwrap();
        }
        stream.synchronize().unwrap();
        assert_eq!(start, end);
    }

    #[test]
    fn test_slice() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0];
        let mut buf = DeviceBuffer::from_slice(&[0u64, 0, 0, 0]).unwrap();
        buf.copy_from(&start[0..4]).unwrap();
        buf[0..2].copy_to(&mut end).unwrap();
        assert_eq!(start[0..2], end);
    }

    #[test]
    fn test_async_slice() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4, 5];
        let mut end = [0u64, 0];
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 0, 0, 0], &stream).unwrap();
            buf.async_copy_from(&start[0..4], &stream).unwrap();
            buf[0..2].async_copy_to(&mut end, &stream).unwrap();
            stream.synchronize().unwrap();
            assert_eq!(start[0..2], end);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_to_d2h_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut end = [0u64, 1, 2, 3, 4];
        let _ = buf.copy_to(&mut end);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_to_d2h_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut end = [0u64, 1, 2, 3, 4];
            let _ = buf.async_copy_to(&mut end, &stream);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_from_h2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let start = [0u64, 1, 2, 3, 4];
        let mut buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let _ = buf.copy_from(&start);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_from_h2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        let start = [0u64, 1, 2, 3, 4];
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let _ = buf.async_copy_from(&start, &stream);
        }
    }

    #[test]
    fn test_copy_device_slice_to_device() {
        let _context = crate::quick_init().unwrap();
        let start = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut mid = DeviceBuffer::from_slice(&[0u64, 0, 0, 0]).unwrap();
        let mut end = DeviceBuffer::from_slice(&[0u64, 0]).unwrap();
        let mut host_end = [0u64, 0];
        start[1..5].copy_to(&mut mid).unwrap();
        end.copy_from(&mid[1..3]).unwrap();
        end.copy_to(&mut host_end).unwrap();
        assert_eq!([2u64, 3], host_end);
    }

    #[test]
    fn test_async_copy_device_slice_to_device() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let start = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut mid = DeviceBuffer::from_slice_async(&[0u64, 0, 0, 0], &stream).unwrap();
            let mut end = DeviceBuffer::from_slice_async(&[0u64, 0], &stream).unwrap();
            let mut host_end = [0u64, 0];
            start[1..5].async_copy_to(&mut mid, &stream).unwrap();
            end.async_copy_from(&mid[1..3], &stream).unwrap();
            end.async_copy_to(&mut host_end, &stream).unwrap();
            stream.synchronize().unwrap();
            assert_eq!([2u64, 3], host_end);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_to_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let mut end = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4]).unwrap();
        let _ = buf.copy_to(&mut end);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_to_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let mut end = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4], &stream).unwrap();
            let _ = buf.async_copy_to(&mut end, &stream);
        }
    }

    #[test]
    #[should_panic]
    fn test_copy_from_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let mut buf = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4, 5]).unwrap();
        let start = DeviceBuffer::from_slice(&[0u64, 1, 2, 3, 4]).unwrap();
        let _ = buf.copy_from(&start);
    }

    #[test]
    #[should_panic]
    fn test_async_copy_from_d2d_wrong_size() {
        let _context = crate::quick_init().unwrap();
        let stream = Stream::new(StreamFlags::NON_BLOCKING, None).unwrap();
        unsafe {
            let mut buf = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4, 5], &stream).unwrap();
            let start = DeviceBuffer::from_slice_async(&[0u64, 1, 2, 3, 4], &stream).unwrap();
            let _ = buf.async_copy_from(&start, &stream);
        }
    }

    #[test]
    fn test_can_create_uninitialized_non_devicecopy_buffers() {
        let _context = crate::quick_init().unwrap();
        unsafe {
            let _box: DeviceBox<Vec<u8>> = DeviceBox::uninitialized().unwrap();
            let buffer: DeviceBuffer<Vec<u8>> = DeviceBuffer::uninitialized(10).unwrap();
            let _slice = &buffer[0..5];
        }
    }

    #[test]
    fn test_allocate_correct_size() {
        use crate::context::CurrentContext;

        let _context = crate::quick_init().unwrap();
        let total_memory = CurrentContext::get_device()
            .unwrap()
            .total_memory()
            .unwrap();
        // Allocate a buffer covering 3/4 of the total device memory; an over-large
        // size computation would make this allocation fail.
        let allocation_size = (total_memory * 3) / 4 / mem::size_of::<u64>();
        unsafe {
            let _buffer = DeviceBuffer::<u64>::uninitialized(allocation_size).unwrap();
        };
    }
}