@@ -61,12 +61,12 @@ impl CublasContext {
61
61
///
62
62
/// ```
63
63
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
64
- /// # let _a = blastoff::__doctest_setup() ;
65
- /// use blastoff::context::CublasContext;
66
- /// use cust::prelude::*;
67
- /// use cust::memory::DeviceBox;
68
- /// use cust::util::SliceExt;
69
- /// let stream = Stream::new(StreamFlags::DEFAULT, None)?;
64
+ /// # let _a = cust::quick_init()? ;
65
+ /// # use blastoff::context::CublasContext;
66
+ /// # use cust::prelude::*;
67
+ /// # use cust::memory::DeviceBox;
68
+ /// # use cust::util::SliceExt;
69
+ /// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
70
70
/// let mut ctx = CublasContext::new()?;
71
71
/// let data = [0.0f32, 1.0, 0.5, 5.0].as_dbuf()?;
72
72
/// let mut result = DeviceBox::new(&0)?;
@@ -116,19 +116,19 @@ impl CublasContext {
116
116
} )
117
117
}
118
118
119
- /// Finds the index of the smallest element inside of the GPU buffer, writing the resulting
119
+ /// Finds the index of the smallest element inside of the GPU buffer by absolute value, writing the resulting
120
120
/// index into `result`. The index is 1-based, not 0-based.
121
121
///
122
122
/// # Example
123
123
///
124
124
/// ```
125
125
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
126
- /// # let _a = blastoff::__doctest_setup() ;
127
- /// use blastoff::context::CublasContext;
128
- /// use cust::prelude::*;
129
- /// use cust::memory::DeviceBox;
130
- /// use cust::util::SliceExt;
131
- /// let stream = Stream::new(StreamFlags::DEFAULT, None)?;
126
+ /// # let _a = cust::quick_init()? ;
127
+ /// # use blastoff::context::CublasContext;
128
+ /// # use cust::prelude::*;
129
+ /// # use cust::memory::DeviceBox;
130
+ /// # use cust::util::SliceExt;
131
+ /// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
132
132
/// let mut ctx = CublasContext::new()?;
133
133
/// let data = [0.0f32, 1.0, 0.5, 5.0].as_dbuf()?;
134
134
/// let mut result = DeviceBox::new(&0)?;
@@ -149,4 +149,78 @@ impl CublasContext {
149
149
) -> Result {
150
150
self . amax_strided ( stream, x, x. len ( ) , None , result)
151
151
}
152
+
153
+ /// Same as [`CublasContext::axpy`] but with an explicit stride.
154
+ ///
155
+ /// # Panics
156
+ ///
157
+ /// Panics if the buffers are not long enough for the stride and length requested.
158
+ pub fn axpy_strided < T : Level1 > (
159
+ & mut self ,
160
+ stream : & Stream ,
161
+ alpha : & impl GpuBox < T > ,
162
+ n : usize ,
163
+ x : & impl GpuBuffer < T > ,
164
+ x_stride : Option < usize > ,
165
+ y : & mut impl GpuBuffer < T > ,
166
+ y_stride : Option < usize > ,
167
+ ) -> Result {
168
+ check_stride ( x, n, x_stride) ;
169
+ check_stride ( y, n, y_stride) ;
170
+
171
+ self . with_stream ( stream, |ctx| unsafe {
172
+ Ok ( T :: axpy (
173
+ ctx. raw ,
174
+ x. len ( ) as i32 ,
175
+ alpha. as_device_ptr ( ) . as_raw ( ) ,
176
+ x. as_device_ptr ( ) . as_raw ( ) ,
177
+ x_stride. unwrap_or ( 1 ) as i32 ,
178
+ y. as_device_ptr ( ) . as_raw_mut ( ) ,
179
+ y_stride. unwrap_or ( 1 ) as i32 ,
180
+ )
181
+ . to_result ( ) ?)
182
+ } )
183
+ }
184
+
185
+ /// Multiplies `n` elements in `x` by `alpha`, then adds the result to `y`, overwriting
186
+ /// `y` with the result.
187
+ ///
188
+ /// # Panics
189
+ ///
190
+ /// Panics if `x` or `y` are not long enough for the requested length `n`.
191
+ ///
192
+ /// # Example
193
+ ///
194
+ /// ```
195
+ /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
196
+ /// # let _a = cust::quick_init()?;
197
+ /// # use blastoff::context::CublasContext;
198
+ /// # use cust::prelude::*;
199
+ /// # use cust::memory::DeviceBox;
200
+ /// # use cust::util::SliceExt;
201
+ /// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
202
+ /// let mut ctx = CublasContext::new()?;
203
+ /// let alpha = DeviceBox::new(&2.0)?;
204
+ /// let x = [1.0, 2.0, 3.0, 4.0].as_dbuf()?;
205
+ /// let mut y = [1.0; 4].as_dbuf()?;
206
+ ///
207
+ /// ctx.axpy(&stream, &alpha, x.len(), &x, &mut y)?;
208
+ ///
209
+ /// stream.synchronize()?;
210
+ ///
211
+ /// let result = y.as_host_vec()?;
212
+ /// assert_eq!(&result, &[3.0, 5.0, 7.0, 9.0]);
213
+ /// # Ok(())
214
+ /// # }
215
+ /// ```
216
+ pub fn axpy < T : Level1 > (
217
+ & mut self ,
218
+ stream : & Stream ,
219
+ alpha : & impl GpuBox < T > ,
220
+ n : usize ,
221
+ x : & impl GpuBuffer < T > ,
222
+ y : & mut impl GpuBuffer < T > ,
223
+ ) -> Result {
224
+ self . axpy_strided ( stream, alpha, n, x, None , y, None )
225
+ }
152
226
}
0 commit comments