Skip to content

Commit d2dffa3

Browse files
committed
Feat: add axpy to blastoff and improve docs
1 parent 4f8bc92 commit d2dffa3

File tree

6 files changed

+122
-23
lines changed

6 files changed

+122
-23
lines changed

.cargo/config.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[alias]
2-
xtask = "run -p xtask --bin xtask --"
2+
xtask = "run -p xtask --bin xtask --"

crates/blastoff/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,6 @@ bitflags = "1.3.2"
1010
cublas_sys = { version = "0.1", path = "../cublas_sys" }
1111
cust = { version = "0.2", path = "../cust", features = ["num-complex"] }
1212
num-complex = "0.4.0"
13+
14+
[package.metadata.docs.rs]
15+
rustdoc-args = ["--html-in-header", "katex-header.html"]

crates/blastoff/src/context.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ bitflags::bitflags! {
4545
/// cuBLAS contexts hold internal memory allocations required by the library, and will free those allocations on drop. They will
4646
/// also synchronize the entire device when dropping the context. Therefore, you should minimize both the number of contexts, and the
4747
/// number of context drops. You should generally allocate all the contexts at once, and drop them all at once.
48+
///
49+
/// # Methods
50+
///
51+
/// ## Level 1 Methods (Scalar/Vector-based operations)
52+
/// - [Index of smallest element by absolute value <span style="float:right;">`amin`</span>](CublasContext::amin)
53+
/// - [Index of largest element by absolute value <span style="float:right;">`amax`</span>](CublasContext::amax)
54+
/// - [$\alpha \boldsymbol{x} + \boldsymbol{y}$ <span style="float:right;">`axpy`</span>](CublasContext::axpy)
4855
#[derive(Debug)]
4956
pub struct CublasContext {
5057
pub(crate) raw: sys::v2::cublasHandle_t,

crates/blastoff/src/level1.rs

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,12 @@ impl CublasContext {
6161
///
6262
/// ```
6363
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
64-
/// # let _a = blastoff::__doctest_setup();
65-
/// use blastoff::context::CublasContext;
66-
/// use cust::prelude::*;
67-
/// use cust::memory::DeviceBox;
68-
/// use cust::util::SliceExt;
69-
/// let stream = Stream::new(StreamFlags::DEFAULT, None)?;
64+
/// # let _a = cust::quick_init()?;
65+
/// # use blastoff::context::CublasContext;
66+
/// # use cust::prelude::*;
67+
/// # use cust::memory::DeviceBox;
68+
/// # use cust::util::SliceExt;
69+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
7070
/// let mut ctx = CublasContext::new()?;
7171
/// let data = [0.0f32, 1.0, 0.5, 5.0].as_dbuf()?;
7272
/// let mut result = DeviceBox::new(&0)?;
@@ -116,19 +116,19 @@ impl CublasContext {
116116
})
117117
}
118118

119-
/// Finds the index of the smallest element inside of the GPU buffer, writing the resulting
119+
/// Finds the index of the largest element inside of the GPU buffer by absolute value, writing the resulting
120120
/// index into `result`. The index is 1-based, not 0-based.
121121
///
122122
/// # Example
123123
///
124124
/// ```
125125
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
126-
/// # let _a = blastoff::__doctest_setup();
127-
/// use blastoff::context::CublasContext;
128-
/// use cust::prelude::*;
129-
/// use cust::memory::DeviceBox;
130-
/// use cust::util::SliceExt;
131-
/// let stream = Stream::new(StreamFlags::DEFAULT, None)?;
126+
/// # let _a = cust::quick_init()?;
127+
/// # use blastoff::context::CublasContext;
128+
/// # use cust::prelude::*;
129+
/// # use cust::memory::DeviceBox;
130+
/// # use cust::util::SliceExt;
131+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
132132
/// let mut ctx = CublasContext::new()?;
133133
/// let data = [0.0f32, 1.0, 0.5, 5.0].as_dbuf()?;
134134
/// let mut result = DeviceBox::new(&0)?;
@@ -149,4 +149,78 @@ impl CublasContext {
149149
) -> Result {
150150
self.amax_strided(stream, x, x.len(), None, result)
151151
}
152+
153+
/// Same as [`CublasContext::axpy`] but with an explicit stride.
154+
///
155+
/// # Panics
156+
///
157+
/// Panics if the buffers are not long enough for the stride and length requested.
158+
pub fn axpy_strided<T: Level1>(
159+
&mut self,
160+
stream: &Stream,
161+
alpha: &impl GpuBox<T>,
162+
n: usize,
163+
x: &impl GpuBuffer<T>,
164+
x_stride: Option<usize>,
165+
y: &mut impl GpuBuffer<T>,
166+
y_stride: Option<usize>,
167+
) -> Result {
168+
check_stride(x, n, x_stride);
169+
check_stride(y, n, y_stride);
170+
171+
self.with_stream(stream, |ctx| unsafe {
172+
Ok(T::axpy(
173+
ctx.raw,
174+
x.len() as i32,
175+
alpha.as_device_ptr().as_raw(),
176+
x.as_device_ptr().as_raw(),
177+
x_stride.unwrap_or(1) as i32,
178+
y.as_device_ptr().as_raw_mut(),
179+
y_stride.unwrap_or(1) as i32,
180+
)
181+
.to_result()?)
182+
})
183+
}
184+
185+
/// Multiplies `n` elements in `x` by `alpha`, then adds the result to `y`, overwriting
186+
/// `y` with the result.
187+
///
188+
/// # Panics
189+
///
190+
/// Panics if `x` or `y` are not long enough for the requested length `n`.
191+
///
192+
/// # Example
193+
///
194+
/// ```
195+
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
196+
/// # let _a = cust::quick_init()?;
197+
/// # use blastoff::context::CublasContext;
198+
/// # use cust::prelude::*;
199+
/// # use cust::memory::DeviceBox;
200+
/// # use cust::util::SliceExt;
201+
/// # let stream = Stream::new(StreamFlags::DEFAULT, None)?;
202+
/// let mut ctx = CublasContext::new()?;
203+
/// let alpha = DeviceBox::new(&2.0)?;
204+
/// let x = [1.0, 2.0, 3.0, 4.0].as_dbuf()?;
205+
/// let mut y = [1.0; 4].as_dbuf()?;
206+
///
207+
/// ctx.axpy(&stream, &alpha, x.len(), &x, &mut y)?;
208+
///
209+
/// stream.synchronize()?;
210+
///
211+
/// let result = y.as_host_vec()?;
212+
/// assert_eq!(&result, &[3.0, 5.0, 7.0, 9.0]);
213+
/// # Ok(())
214+
/// # }
215+
/// ```
216+
pub fn axpy<T: Level1>(
217+
&mut self,
218+
stream: &Stream,
219+
alpha: &impl GpuBox<T>,
220+
n: usize,
221+
x: &impl GpuBuffer<T>,
222+
y: &mut impl GpuBuffer<T>,
223+
) -> Result {
224+
self.axpy_strided(stream, alpha, n, x, None, y, None)
225+
}
152226
}

crates/blastoff/src/lib.rs

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,18 @@
55
//!
66
//! **blastoff uses 1-based indexing, reflecting cuBLAS' behavior. This means
77
//! you will likely need to do some math to any returned indices. For example,
8-
//! [`amin`](crate::context::CublasContext::amin) returns a 1-based index.
8+
//! [`amin`](crate::context::CublasContext::amin) returns a 1-based index.**
99
1010
pub use cublas_sys as sys;
1111
use num_complex::{Complex32, Complex64};
1212

13-
pub mod context;
13+
pub use context::*;
14+
15+
mod context;
1416
pub mod error;
1517
mod level1;
1618
pub mod raw;
1719

18-
#[doc(hidden)]
19-
pub fn __doctest_setup() -> (cust::context::Context, cust::stream::Stream) {
20-
let ctx = cust::quick_init().unwrap();
21-
let stream = cust::stream::Stream::new(cust::stream::StreamFlags::DEFAULT, None).unwrap();
22-
(ctx, stream)
23-
}
24-
2520
pub trait BlasDatatype: private::Sealed + cust::memory::DeviceCopy {
2621
/// The corresponding float type. For complex numbers this means their backing
2722
/// precision, and for floats it is just themselves.

katex-header.html

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.css"
2+
integrity="sha384-9eLZqc9ds8eNjO3TmqPeYcDj8n+Qfa4nuSiGYa6DjLNcv9BtN69ZIulL9+8CqC9Y" crossorigin="anonymous">
3+
<script src="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/katex.min.js"
4+
integrity="sha384-K3vbOmF2BtaVai+Qk37uypf7VrgBubhQreNQe9aGsz9lB63dIFiQVlJbr92dw2Lx"
5+
crossorigin="anonymous"></script>
6+
<script src="https://cdn.jsdelivr.net/npm/katex@0.10.0/dist/contrib/auto-render.min.js"
7+
integrity="sha384-kmZOZB5ObwgQnS/DuDg6TScgOiWWBiVt0plIRkZCmE6rDZGrEOQeHM5PcHi+nyqe"
8+
crossorigin="anonymous"></script>
9+
<script>
10+
document.addEventListener("DOMContentLoaded", function () {
11+
renderMathInElement(document.body, {
12+
delimiters: [
13+
{ left: "$$", right: "$$", display: true },
14+
{ left: "\\(", right: "\\)", display: false },
15+
{ left: "$", right: "$", display: false },
16+
{ left: "\\[", right: "\\]", display: true }
17+
]
18+
});
19+
});
20+
</script>

0 commit comments

Comments
 (0)