Skip to content

Commit bc68d2f

Browse files
committed
add half::16 support
1 parent 8ff477f commit bc68d2f

File tree

3 files changed

+58
-0
lines changed

3 files changed

+58
-0
lines changed

Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ matrixmultiply = { version = "0.3.2", default-features = false, features=["cgemm
4444
serde = { version = "1.0", optional = true, default-features = false, features = ["alloc"] }
4545
rawpointer = { version = "0.2" }
4646

47+
half = {version = "2.7.1", default-features = false, features = ["num-traits"], optional = true}
48+
4749
[dev-dependencies]
4850
defmac = "0.2"
4951
quickcheck = { workspace = true }
@@ -67,6 +69,8 @@ matrixmultiply-threading = ["matrixmultiply/threading"]
6769

6870
portable-atomic-critical-section = ["portable-atomic/critical-section"]
6971

72+
half = ["dep:half"]
73+
7074

7175
[target.'cfg(not(target_has_atomic = "ptr"))'.dependencies]
7276
portable-atomic = { version = "1.6.0" }

src/impl_ops.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,16 @@ impl ScalarOperand for i128 {}
4545
impl ScalarOperand for u128 {}
4646
impl ScalarOperand for isize {}
4747
impl ScalarOperand for usize {}
48+
#[cfg(feature = "half")]
49+
impl ScalarOperand for half::f16 {}
50+
#[cfg(feature = "half")]
51+
impl ScalarOperand for half::bf16 {}
4852
impl ScalarOperand for f32 {}
4953
impl ScalarOperand for f64 {}
54+
#[cfg(feature = "half")]
55+
impl ScalarOperand for Complex<half::f16> {}
56+
#[cfg(feature = "half")]
57+
impl ScalarOperand for Complex<half::bf16> {}
5058
impl ScalarOperand for Complex<f32> {}
5159
impl ScalarOperand for Complex<f64> {}
5260

@@ -468,6 +476,26 @@ mod arithmetic_ops
468476
impl_scalar_lhs_op!(bool, Commute, |, BitOr, bitor, "bit or");
469477
impl_scalar_lhs_op!(bool, Commute, ^, BitXor, bitxor, "bit xor");
470478

479+
#[cfg(feature = "half")]
480+
mod ops_f16 {
481+
use super::*;
482+
impl_scalar_lhs_op!(half::f16, Commute, +, Add, add, "addition");
483+
impl_scalar_lhs_op!(half::f16, Ordered, -, Sub, sub, "subtraction");
484+
impl_scalar_lhs_op!(half::f16, Commute, *, Mul, mul, "multiplication");
485+
impl_scalar_lhs_op!(half::f16, Ordered, /, Div, div, "division");
486+
impl_scalar_lhs_op!(half::f16, Ordered, %, Rem, rem, "remainder");
487+
}
488+
489+
#[cfg(feature = "half")]
490+
mod ops_bf16 {
491+
use super::*;
492+
impl_scalar_lhs_op!(half::bf16, Commute, +, Add, add, "addition");
493+
impl_scalar_lhs_op!(half::bf16, Ordered, -, Sub, sub, "subtraction");
494+
impl_scalar_lhs_op!(half::bf16, Commute, *, Mul, mul, "multiplication");
495+
impl_scalar_lhs_op!(half::bf16, Ordered, /, Div, div, "division");
496+
impl_scalar_lhs_op!(half::bf16, Ordered, %, Rem, rem, "remainder");
497+
}
498+
471499
impl_scalar_lhs_op!(f32, Commute, +, Add, add, "addition");
472500
impl_scalar_lhs_op!(f32, Ordered, -, Sub, sub, "subtraction");
473501
impl_scalar_lhs_op!(f32, Commute, *, Mul, mul, "multiplication");
@@ -480,6 +508,24 @@ mod arithmetic_ops
480508
impl_scalar_lhs_op!(f64, Ordered, /, Div, div, "division");
481509
impl_scalar_lhs_op!(f64, Ordered, %, Rem, rem, "remainder");
482510

511+
#[cfg(feature = "half")]
512+
mod ops_complex_f16 {
513+
use super::*;
514+
impl_scalar_lhs_op!(Complex<half::f16>, Commute, +, Add, add, "addition");
515+
impl_scalar_lhs_op!(Complex<half::f16>, Ordered, -, Sub, sub, "subtraction");
516+
impl_scalar_lhs_op!(Complex<half::f16>, Commute, *, Mul, mul, "multiplication");
517+
impl_scalar_lhs_op!(Complex<half::f16>, Ordered, /, Div, div, "division");
518+
}
519+
520+
#[cfg(feature = "half")]
521+
mod ops_complex_bf16 {
522+
use super::*;
523+
impl_scalar_lhs_op!(Complex<half::bf16>, Commute, +, Add, add, "addition");
524+
impl_scalar_lhs_op!(Complex<half::bf16>, Ordered, -, Sub, sub, "subtraction");
525+
impl_scalar_lhs_op!(Complex<half::bf16>, Commute, *, Mul, mul, "multiplication");
526+
impl_scalar_lhs_op!(Complex<half::bf16>, Ordered, /, Div, div, "division");
527+
}
528+
483529
impl_scalar_lhs_op!(Complex<f32>, Commute, +, Add, add, "addition");
484530
impl_scalar_lhs_op!(Complex<f32>, Ordered, -, Sub, sub, "subtraction");
485531
impl_scalar_lhs_op!(Complex<f32>, Commute, *, Mul, mul, "multiplication");

src/linalg_traits.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ pub trait NdFloat:
6161
{
6262
}
6363

64+
#[cfg(all(feature = "std", feature = "half"))]
65+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
66+
impl NdFloat for half::f16 {}
67+
68+
#[cfg(all(feature = "std", feature = "half"))]
69+
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
70+
impl NdFloat for half::bf16 {}
71+
6472
#[cfg(feature = "std")]
6573
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
6674
impl NdFloat for f32 {}

0 commit comments

Comments
 (0)