10 changes: 7 additions & 3 deletions Cargo.toml
@@ -44,6 +44,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
rawpointer = { version = "0.2" }

half = {version = "2.7.1", default-features = false, features = ["num-traits"], optional = true}

[dev-dependencies]
defmac = "0.2"
quickcheck = { workspace = true }
@@ -58,15 +60,17 @@ default = ["std"]
# See README for more instructions
blas = ["dep:cblas-sys", "dep:libc"]

serde = ["dep:serde"]
serde = ["dep:serde", "half?/serde"]

std = ["num-traits/std", "matrixmultiply/std"]
std = ["num-traits/std", "matrixmultiply/std", "half?/std"]
rayon = ["dep:rayon", "std"]

matrixmultiply-threading = ["matrixmultiply/threading"]

portable-atomic-critical-section = ["portable-atomic/critical-section"]

half = ["dep:half"]


[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
portable-atomic = { version = "1.6.0" }
@@ -115,6 +119,6 @@ tag-name = "{{version}}"

# Config specific to docs.rs
[package.metadata.docs.rs]
features = ["approx", "serde", "rayon"]
features = ["approx", "serde", "rayon", "half"]
# Define the configuration attribute `docsrs`
rustdoc-args = ["--cfg", "docsrs"]
6 changes: 5 additions & 1 deletion README.rst
@@ -101,6 +101,11 @@ your `Cargo.toml`.

- Whether ``portable-atomic`` should use ``critical-section``

- ``half``

- Enable support for the ``half::f16`` and ``half::bf16`` types.


How to use with cargo
---------------------

Expand Down Expand Up @@ -179,4 +184,3 @@ http://www.apache.org/licenses/LICENSE-2.0 or the MIT license
http://opensource.org/licenses/MIT, at your
option. This file may not be copied, modified, or distributed
except according to those terms.
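As a quick illustration of the feature documented in the README hunk above, here is a minimal sketch of using half::f16 arrays from a downstream crate (not part of this PR; it assumes ndarray is built with the half feature enabled, and the numeric values are arbitrary):

```rust
// Sketch only: requires `ndarray = { version = "...", features = ["half"] }`
// in the downstream Cargo.toml.
use half::f16;
use ndarray::{array, Array2};

fn main()
{
    // With the feature enabled, half::f16 satisfies ndarray's numeric bounds
    // (ScalarOperand is added by this PR; LinalgScalar comes via num-traits),
    // so the usual operations are available.
    let a: Array2<f16> = array![
        [f16::from_f32(1.0), f16::from_f32(2.0)],
        [f16::from_f32(3.0), f16::from_f32(4.0)],
    ];
    let doubled = &a * f16::from_f32(2.0); // scalar op via the new ScalarOperand impl
    let product = a.dot(&doubled);         // matrix product in f16
    println!("{}", product);
}
```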

32 changes: 31 additions & 1 deletion benches/bench1.rs
@@ -795,7 +795,7 @@ fn bench_col_iter(bench: &mut test::Bencher)
}

macro_rules! mat_mul {
($modname:ident, $ty:ident, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
($modname:ident, $ty:ty, $(($name:ident, $m:expr, $n:expr, $k:expr))+) => {
mod $modname {
use test::{black_box, Bencher};
use ndarray::Array;
@@ -814,6 +814,36 @@ macro_rules! mat_mul {
};
}

#[cfg(feature = "half")]
mat_mul! {mat_mul_f16, half::f16,
(m004, 4, 4, 4)
(m007, 7, 7, 7)
(m008, 8, 8, 8)
(m012, 12, 12, 12)
(m016, 16, 16, 16)
(m032, 32, 32, 32)
(m064, 64, 64, 64)
(m127, 127, 127, 127) // ~128x slower than f32
(mix16x4, 32, 4, 32)
(mix32x2, 32, 2, 32)
// (mix10000, 128, 10000, 128) // too slow
Comment on lines +826 to +829

Collaborator: You mention that f16 is slower in the issue and several times here. Is it faster on some operations? If not, why do you (and others) want to use it? Only to save space?

@swfsql (Author, Dec 1, 2025): Yes, on my arch (x86_64) it is indeed quite slow. They mention in their docs that aarch64 has support for all operations (for half::f16), so it may be viable in performance on that arch -- but I haven't tested it. They also have specialized methods for storing/loading the half::{f16, bf16} data to/from other types (u16, f32, f64), which could also improve performance, but I didn't leverage those operations when adding the types to ndarray (I don't really know if/how they could be leveraged).

Although it is (at least for me) a bit disappointing that it is slow, I find it useful for debugging fp16 models (for machine learning), given that some architectures behave poorly on fp16 and it is easier to debug them if everything runs on the CPU.
For wasm targets, some builds may not enable CPU features, so the performance of f32 and half::f16 should be closer -- in that case the memory savings could be meaningful, but I consider that a niche target.

I don't know whether proper SIMD support is possible for fp16; it appears that some work towards this is active (for the primitive f16).

With that being said, I still advocate for half::{f16, bf16} support in ndarray. It makes development and debugging smoother even if fp16 doesn't have proper SIMD support, given that the crux of the training happens on the GPUs. In the future, SIMD improvements could come either from the underlying half crate or from the addition of (or replacement by) a primitive f16 type.
I also understand that ndarray holds performance in high regard and may therefore opt to delay or decline inclusion of the fp16 types.

}

#[cfg(feature = "half")]
mat_mul! {mat_mul_bf16, half::bf16,
(m004, 4, 4, 4)
(m007, 7, 7, 7)
(m008, 8, 8, 8)
(m012, 12, 12, 12)
(m016, 16, 16, 16)
(m032, 32, 32, 32)
(m064, 64, 64, 64)
(m127, 127, 127, 127) // 84x slower than f32
(mix16x4, 32, 4, 32)
(mix32x2, 32, 2, 32)
// (mix10000, 128, 10000, 128) // too slow
}
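The author's reply above mentions that half provides conversions to and from f32; one way to use them (not part of this PR, just a hedged sketch assuming the half feature is enabled) is to keep storage in f16 while running the actual product through the faster f32 kernel:

```rust
// Sketch only: store in f16, widen to f32 for the multiplication, narrow back.
use half::f16;
use ndarray::Array2;

fn matmul_via_f32(a: &Array2<f16>, b: &Array2<f16>) -> Array2<f16>
{
    let a32 = a.mapv(f16::to_f32);
    let b32 = b.mapv(f16::to_f32);
    // The f32 product avoids the slow per-element f16 arithmetic observed in
    // the benchmarks above, at the cost of temporary f32 buffers.
    a32.dot(&b32).mapv(f16::from_f32)
}
```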

mat_mul! {mat_mul_f32, f32,
(m004, 4, 4, 4)
(m007, 7, 7, 7)
4 changes: 4 additions & 0 deletions crates/blas-tests/tests/oper.rs
@@ -218,6 +218,10 @@ fn mat_mut_zero_len()
}
}
});
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::f16>);
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::bf16>);
mat_mul_zero_len!(range_mat::<f32>);
mat_mul_zero_len!(range_mat::<f64>);
mat_mul_zero_len!(range_i32);
3 changes: 3 additions & 0 deletions crates/numeric-tests/Cargo.toml
@@ -21,12 +21,15 @@ rand_distr = { workspace = true }
blas-src = { optional = true, version = "0.10", default-features = false, features = ["openblas"] }
openblas-src = { optional = true, version = ">=0.10.11", default-features = false, features = ["cblas", "system"] }

half = { optional = true, version = "2.7.1", default-features = false, features = ["num-traits", "rand_distr"] }

[dev-dependencies]
num-traits = { workspace = true }
num-complex = { workspace = true }

[features]
test_blas = ["ndarray/blas", "blas-src", "openblas-src"]
half = ["dep:half", "ndarray/half"]

# Config for cargo-release
[package.metadata.release]
28 changes: 28 additions & 0 deletions crates/numeric-tests/tests/accuracy.rs
@@ -140,6 +140,20 @@ fn accurate_eye_f64()
}
}

#[test]
#[cfg(feature = "half")]
fn accurate_mul_f16_dot()
{
accurate_mul_float_general::<half::f16>(1e-2, false);
}

#[test]
#[cfg(feature = "half")]
fn accurate_mul_bf16_dot()
{
accurate_mul_float_general::<half::bf16>(1e-1, false);
}
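The helper accurate_mul_float_general is defined elsewhere in this file; independent of its actual signature, here is a self-contained sketch of the idea behind these tolerance checks: compute the product in the half-precision type, compare against a wider reference, and require the relative error to stay under roughly 1e-2 for f16 and 1e-1 for bf16.

```rust
// Sketch only, not the helper used above: measures how far an f16 product
// strays from the same product computed in f64.
use half::f16;
use ndarray::Array2;

fn max_rel_error_vs_f64(a: &Array2<f16>, b: &Array2<f16>) -> f64
{
    let c_half = a.dot(b).mapv(f16::to_f64);
    let c_ref = a.mapv(f16::to_f64).dot(&b.mapv(f16::to_f64));
    let mut worst = 0.0_f64;
    for (x, y) in c_half.iter().zip(c_ref.iter()) {
        worst = worst.max((x - y).abs() / y.abs().max(1.0));
    }
    worst
}
```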

#[test]
fn accurate_mul_f32_dot()
{
@@ -222,6 +236,20 @@ where
}
}

#[test]
#[cfg(feature = "half")]
fn accurate_mul_complex16()
{
accurate_mul_complex_general::<half::f16>(1e-2);
}

#[test]
#[cfg(feature = "half")]
fn accurate_mul_complexb16()
{
accurate_mul_complex_general::<half::bf16>(1e-1);
}

#[test]
fn accurate_mul_complex32()
{
6 changes: 5 additions & 1 deletion ndarray-rand/Cargo.toml
@@ -20,10 +20,14 @@ rand = { workspace = true }
rand_distr = { workspace = true }
quickcheck = { workspace = true, optional = true }

half = { optional = true, version = "2.7.1", default-features = false, features = ["num-traits"] }

[dev-dependencies]
rand_isaac = "0.4.0"
quickcheck = { workspace = true }

[features]
half = ["dep:half", "ndarray/half"]

[package.metadata.release]
tag-name = "ndarray-rand-{{version}}"

16 changes: 16 additions & 0 deletions ndarray-rand/benches/bench.rs
@@ -16,6 +16,22 @@ fn uniform_f32(b: &mut Bencher)
b.iter(|| Array::random((m, m), Uniform::new(-1f32, 1.).unwrap()));
}

#[bench]
#[cfg(feature = "half")]
fn norm_f16(b: &mut Bencher)
{
let m = 100;
b.iter(|| Array::random((m, m), Normal::new(half::f16::ZERO, half::f16::ONE).unwrap()));
}

#[bench]
#[cfg(feature = "half")]
fn norm_bf16(b: &mut Bencher)
{
let m = 100;
b.iter(|| Array::random((m, m), Normal::new(half::bf16::ZERO, half::bf16::ONE).unwrap()));
}

#[bench]
fn norm_f32(b: &mut Bencher)
{
46 changes: 46 additions & 0 deletions src/impl_ops.rs
@@ -45,8 +45,16 @@ impl ScalarOperand for i128 {}
impl ScalarOperand for u128 {}
impl ScalarOperand for isize {}
impl ScalarOperand for usize {}
#[cfg(feature = "half")]
impl ScalarOperand for half::f16 {}
#[cfg(feature = "half")]
impl ScalarOperand for half::bf16 {}
impl ScalarOperand for f32 {}
impl ScalarOperand for f64 {}
#[cfg(feature = "half")]
impl ScalarOperand for Complex<half::f16> {}
#[cfg(feature = "half")]
impl ScalarOperand for Complex<half::bf16> {}
impl ScalarOperand for Complex<f32> {}
impl ScalarOperand for Complex<f64> {}
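Together with the impl_scalar_lhs_op! additions further down in this diff, the new ScalarOperand impls enable scalar-with-array arithmetic in both orders. A hedged sketch of what that permits (assuming the half feature; the values are arbitrary and exactly representable in bf16):

```rust
// Sketch only: array-scalar and scalar-array forms for half::bf16.
use half::bf16;
use ndarray::array;

fn main()
{
    let v = array![bf16::from_f32(0.5), bf16::from_f32(1.5)];
    let doubled = &v * bf16::from_f32(2.0); // array (by reference) * scalar
    let shifted = bf16::from_f32(1.0) + &v; // scalar on the left-hand side
    assert_eq!(doubled, array![bf16::from_f32(1.0), bf16::from_f32(3.0)]);
    assert_eq!(shifted, array![bf16::from_f32(1.5), bf16::from_f32(2.5)]);
}
```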

@@ -468,6 +476,26 @@ mod arithmetic_ops
impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");

#[cfg(feature = "half")]
mod ops_f16 {
use super::*;
impl_scalar_lhs_op!(half::f16, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(half::f16, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(half::f16, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(half::f16, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(half::f16, Ordered, %, Rem, rem, "remainder");
}

#[cfg(feature = "half")]
mod ops_bf16 {
use super::*;
impl_scalar_lhs_op!(half::bf16, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(half::bf16, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(half::bf16, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(half::bf16, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(half::bf16, Ordered, %, Rem, rem, "remainder");
}

impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
@@ -480,6 +508,24 @@ mod arithmetic_ops
impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");

#[cfg(feature = "half")]
mod ops_complex_f16 {
use super::*;
impl_scalar_lhs_op!(Complex<half::f16>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<half::f16>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<half::f16>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<half::f16>, Ordered, /, Div, div, "division");
}

#[cfg(feature = "half")]
mod ops_complex_bf16 {
use super::*;
impl_scalar_lhs_op!(Complex<half::bf16>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<half::bf16>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<half::bf16>, Commute, *, Mul, mul, "multiplication");
impl_scalar_lhs_op!(Complex<half::bf16>, Ordered, /, Div, div, "division");
}

impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");
4 changes: 4 additions & 0 deletions tests/oper.rs
@@ -481,6 +481,10 @@ fn mat_mut_zero_len()
}
}
});
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::f16>);
#[cfg(feature = "half")]
mat_mul_zero_len!(range_mat::<half::bf16>);
mat_mul_zero_len!(range_mat::<f32>);
mat_mul_zero_len!(range_mat::<f64>);
mat_mul_zero_len!(range_i32);