1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
//! Streams of work for the device to perform.
//!
//! In CUDA, most work is performed asynchronously. Even tasks such as memory copying can be
//! scheduled by the host and performed when ready. Scheduling this work is done using a Stream.
//!
//! A stream is required for all asynchronous tasks in CUDA, such as kernel launches and
//! asynchronous memory copying. Each task in a stream is performed in the order it was scheduled,
//! and tasks within a stream cannot overlap. Tasks scheduled in multiple streams may interleave or
//! execute concurrently. Sequencing between multiple streams can be achieved using events, which
//! are not currently supported by RustaCUDA. Finally, the host can wait for all work scheduled in
//! a stream to be completed.

use crate::error::{CudaResult, DropResult, ToResult};
use crate::event::Event;
use crate::function::{BlockSize, Function, GridSize};
use cuda_sys::cuda::{self, cudaError_t, CUstream};
use std::ffi::c_void;
use std::mem;
use std::panic;
use std::ptr;

bitflags! {
    /// Bit flags for configuring a CUDA Stream.
    pub struct StreamFlags: u32 {
        /// No flags set.
        const DEFAULT = 0x00;

        /// This stream does not synchronize with the NULL stream.
        ///
        /// Note that the name is chosen to correspond to CUDA documentation, but is nevertheless
        /// misleading. All work within a single stream is ordered and asynchronous regardless
        /// of whether this flag is set. All streams in RustaCUDA may execute work concurrently,
        /// regardless of the flag. However, for legacy reasons, CUDA has a notion of a NULL stream,
        /// which is used as the default when no other stream is provided. Work on other streams
        /// may not be executed concurrently with work on the NULL stream unless this flag is set.
        /// Since RustaCUDA does not provide access to the NULL stream, this flag has no effect in
        /// most circumstances. However, it is recommended to use it anyway, as some other crate
        /// in this binary may be using the NULL stream directly.
        const NON_BLOCKING = 0x01;
    }
}

bitflags! {
    /// Bit flags for configuring a CUDA Stream waiting on an CUDA Event.
    ///
    /// Current versions of CUDA support only the default flag.
    pub struct StreamWaitEventFlags: u32 {
        /// No flags set.
        const DEFAULT = 0x0;
    }
}

/// A stream of work for the device to perform.
///
/// See the module-level documentation for more information.
#[derive(Debug)]
pub struct Stream {
    inner: CUstream,
}
impl Stream {
    /// Create a new stream with the given flags and optional priority.
    ///
    /// By convention, `priority` follows a convention where lower numbers represent greater
    /// priorities. That is, work in a stream with a lower priority number may pre-empt work in
    /// a stream with a higher priority number. `Context::get_stream_priority_range` can be used
    /// to get the range of valid priority values; if priority is set outside that range, it will
    /// be automatically clamped to the lowest or highest number in the range.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// // With default priority
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    ///
    /// // With specific priority
    /// let priority = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn new(flags: StreamFlags, priority: Option<i32>) -> CudaResult<Self> {
        unsafe {
            let mut stream = Stream {
                inner: ptr::null_mut(),
            };
            cuda::cuStreamCreateWithPriority(
                &mut stream.inner as *mut CUstream,
                flags.bits(),
                priority.unwrap_or(0),
            )
            .to_result()?;
            Ok(stream)
        }
    }

    /// Return the flags which were used to create this stream.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    /// assert_eq!(StreamFlags::NON_BLOCKING, stream.get_flags().unwrap());
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_flags(&self) -> CudaResult<StreamFlags> {
        unsafe {
            let mut bits = 0u32;
            cuda::cuStreamGetFlags(self.inner, &mut bits as *mut u32).to_result()?;
            Ok(StreamFlags::from_bits_truncate(bits))
        }
    }

    /// Return the priority of this stream.
    ///
    /// If this stream was created without a priority, returns the default priority.
    /// If the stream was created with a priority outside the valid range, returns the clamped
    /// priority.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    /// println!("{}", stream.get_priority()?);
    /// # Ok(())
    /// # }
    /// ```
    pub fn get_priority(&self) -> CudaResult<i32> {
        unsafe {
            let mut priority = 0i32;
            cuda::cuStreamGetPriority(self.inner, &mut priority as *mut i32).to_result()?;
            Ok(priority)
        }
    }

    /// Add a callback to a stream.
    ///
    /// The callback will be executed after all previously queued
    /// items in the stream have been completed. Subsequently queued
    /// items will not execute until the callback is finished.
    ///
    /// Callbacks must not make any CUDA API calls.
    ///
    /// The callback will be passed a `CudaResult<()>` indicating the
    /// current state of the device with `Ok(())` denoting normal operation.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    ///
    /// // ... queue up some work on the stream
    ///
    /// stream.add_callback(Box::new(|status| {
    ///     println!("Device status is {:?}", status);
    /// }));
    ///
    /// // ... queue up some more work on the stream
    /// # Ok(())
    /// # }
    pub fn add_callback<T>(&self, callback: Box<T>) -> CudaResult<()>
    where
        T: FnOnce(CudaResult<()>) + Send,
    {
        unsafe {
            cuda::cuStreamAddCallback(
                self.inner,
                Some(callback_wrapper::<T>),
                Box::into_raw(callback) as *mut c_void,
                0,
            )
            .to_result()
        }
    }

    /// Wait until a stream's tasks are completed.
    ///
    /// Waits until the device has completed all operations scheduled for this stream.
    ///
    /// # Examples
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    ///
    /// // ... queue up some work on the stream
    ///
    /// // Wait for the work to be completed.
    /// stream.synchronize()?;
    /// # Ok(())
    /// # }
    /// ```
    pub fn synchronize(&self) -> CudaResult<()> {
        unsafe { cuda::cuStreamSynchronize(self.inner).to_result() }
    }

    /// Make the stream wait on an event.
    ///
    /// All future work submitted to the stream will wait for the event to
    /// complete. Synchronization is performed on the device, if possible. The
    /// event may originate from different context or device than the stream.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::quick_init;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _context = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags, StreamWaitEventFlags};
    /// use rustacuda::event::{Event, EventFlags};
    ///
    /// let stream_0 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    /// let stream_1 = Stream::new(StreamFlags::NON_BLOCKING, None)?;
    /// let event = Event::new(EventFlags::DEFAULT)?;
    ///
    /// // do some work on stream_0 ...
    ///
    /// // record an event
    /// event.record(&stream_0)?;
    ///
    /// // wait until the work on stream_0 is finished before continuing stream_1
    /// stream_1.wait_event(event, StreamWaitEventFlags::DEFAULT)?;
    /// # Ok(())
    /// }
    /// ```
    pub fn wait_event(&self, event: Event, flags: StreamWaitEventFlags) -> CudaResult<()> {
        unsafe { cuda::cuStreamWaitEvent(self.inner, event.as_inner(), flags.bits()).to_result() }
    }

    // Hidden implementation detail function. Highly unsafe. Use the `launch!` macro instead.
    #[doc(hidden)]
    pub unsafe fn launch<G, B>(
        &self,
        func: &Function,
        grid_size: G,
        block_size: B,
        shared_mem_bytes: u32,
        args: &[*mut c_void],
    ) -> CudaResult<()>
    where
        G: Into<GridSize>,
        B: Into<BlockSize>,
    {
        let grid_size: GridSize = grid_size.into();
        let block_size: BlockSize = block_size.into();

        cuda::cuLaunchKernel(
            func.to_inner(),
            grid_size.x,
            grid_size.y,
            grid_size.z,
            block_size.x,
            block_size.y,
            block_size.z,
            shared_mem_bytes,
            self.inner,
            args.as_ptr() as *mut _,
            ptr::null_mut(),
        )
        .to_result()
    }

    // Get the inner `CUstream` from the `Stream`.
    //
    // Necessary for certain CUDA functions outside of this
    // module that expect a bare `CUstream`.
    pub(crate) fn as_inner(&self) -> CUstream {
        self.inner
    }

    /// Destroy a `Stream`, returning an error.
    ///
    /// Destroying a stream can return errors from previous asynchronous work. This function
    /// destroys the given stream and returns the error and the un-destroyed stream on failure.
    ///
    /// # Example
    ///
    /// ```
    /// # use rustacuda::*;
    /// # use std::error::Error;
    /// # fn main() -> Result<(), Box<dyn Error>> {
    /// # let _ctx = quick_init()?;
    /// use rustacuda::stream::{Stream, StreamFlags};
    ///
    /// let stream = Stream::new(StreamFlags::NON_BLOCKING, 1i32.into())?;
    /// match Stream::drop(stream) {
    ///     Ok(()) => println!("Successfully destroyed"),
    ///     Err((e, stream)) => {
    ///         println!("Failed to destroy stream: {:?}", e);
    ///         // Do something with stream
    ///     },
    /// }
    /// # Ok(())
    /// # }
    /// ```
    pub fn drop(mut stream: Stream) -> DropResult<Stream> {
        if stream.inner.is_null() {
            return Ok(());
        }

        unsafe {
            let inner = mem::replace(&mut stream.inner, ptr::null_mut());
            match cuda::cuStreamDestroy_v2(inner).to_result() {
                Ok(()) => {
                    mem::forget(stream);
                    Ok(())
                }
                Err(e) => Err((e, Stream { inner })),
            }
        }
    }
}
impl Drop for Stream {
    fn drop(&mut self) {
        if self.inner.is_null() {
            return;
        }

        unsafe {
            let inner = mem::replace(&mut self.inner, ptr::null_mut());
            // No choice but to panic here.
            cuda::cuStreamDestroy_v2(inner)
                .to_result()
                .expect("Failed to destroy CUDA stream.");
        }
    }
}
unsafe extern "C" fn callback_wrapper<T>(
    _stream: CUstream,
    status: cudaError_t,
    callback: *mut c_void,
) where
    T: FnOnce(CudaResult<()>) + Send,
{
    // Stop panics from unwinding across the FFI
    let _ = panic::catch_unwind(|| {
        let callback: Box<T> = Box::from_raw(callback as *mut T);
        callback(status.to_result());
    });
}