Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 105 additions & 1 deletion ggml/src/ggml-cpu/arch/arm/quants.c
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,111 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
int ib = 0;
float sumf = 0;

#if defined __ARM_NEON
#if defined(__ARM_FEATURE_SVE)
svfloat32_t sumv0 = svdup_n_f32(0.0f);
svfloat32_t sumv1 = svdup_n_f32(0.0f);

const int vector_length = ggml_cpu_get_sve_cnt() * 8;
const svbool_t ph16 = svptrue_pat_b8(SV_VL16);
// load LUT
svint8_t lut = svld1_s8(ph16, kvalues_mxfp4);

switch (vector_length) {
case 128:
{
const svbool_t ph4 = svptrue_pat_b32(SV_VL4);

for (; ib + 1 < nb; ib += 2) {
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

// load x
const svuint8_t qx0r = svld1rq_u8(ph16, x0->qs);
const svuint8_t qx1r = svld1rq_u8(ph16, x1->qs);

// extract nibble
const svuint8_t idx0l = svand_n_u8_m(ph16, qx0r, 0x0F);
const svuint8_t idx0h = svlsr_n_u8_m(ph16, qx0r, 0x04);
const svuint8_t idx1l = svand_n_u8_m(ph16, qx1r, 0x0F);
const svuint8_t idx1h = svlsr_n_u8_m(ph16, qx1r, 0x04);

// 4-bit -> 8-bit
const svint8_t qx0l = svtbl_s8(lut, idx0l);
const svint8_t qx0h = svtbl_s8(lut, idx0h);
const svint8_t qx1l = svtbl_s8(lut, idx1l);
const svint8_t qx1h = svtbl_s8(lut, idx1h);

// load y
const svint8_t qy0h = svld1_s8(ph16, (const int8_t *) (y0->qs));
const svint8_t qy0l = svld1_s8(ph16, (const int8_t *) (y0->qs + 16));
const svint8_t qy1h = svld1_s8(ph16, (const int8_t *) (y1->qs));
const svint8_t qy1l = svld1_s8(ph16, (const int8_t *) (y1->qs + 16));

// dot product
const svint32_t dot0 = svdot_s32(
svdot_s32(svdup_n_s32(0), qx0l, qy0h),
qx0h, qy0l
);
const svint32_t dot1 = svdot_s32(
svdot_s32(svdup_n_s32(0), qx1l, qy1h),
qx1h, qy1l
);

sumv0 = svmla_n_f32_x(ph4, sumv0,
svcvt_f32_s32_x(ph4, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
sumv1 = svmla_n_f32_x(ph4, sumv1,
svcvt_f32_s32_x(ph4, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
}
sumf = svaddv_f32(svptrue_b32(), svadd_f32_x(svptrue_b32(), sumv0, sumv1));
} break;
case 256:
case 512:
{
const svbool_t ph32 = svptrue_pat_b8(SV_VL32);
const svbool_t pl16 = svnot_b_z(ph32, ph16);

for (; ib + 1 < nb; ib += 2) {
const block_mxfp4 * GGML_RESTRICT x0 = &x[ib + 0];
const block_mxfp4 * GGML_RESTRICT x1 = &x[ib + 1];
const block_q8_0 * GGML_RESTRICT y0 = &y[ib + 0];
const block_q8_0 * GGML_RESTRICT y1 = &y[ib + 1];

// load x
const svuint8_t qx0r = svld1rq_u8(ph32, x0->qs);
const svuint8_t qx1r = svld1rq_u8(ph32, x1->qs);

// extract nibble
const svuint8_t idx0 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx0r, 0x0F), 0x04);
const svuint8_t idx1 = svlsr_n_u8_m(pl16, svand_n_u8_m(ph16, qx1r, 0x0F), 0x04);

// 4-bit -> 8-bit
const svint8_t qx0 = svtbl_s8(lut, idx0);
const svint8_t qx1 = svtbl_s8(lut, idx1);

// load y
const svint8_t qy0 = svld1_s8(ph32, y0->qs);
const svint8_t qy1 = svld1_s8(ph32, y1->qs);

// dot product
const svint32_t dot0 = svdot_s32(svdup_n_s32(0), qx0, qy0);
const svint32_t dot1 = svdot_s32(svdup_n_s32(0), qx1, qy1);

sumv0 = svmla_n_f32_x(ph32, sumv0,
svcvt_f32_s32_x(ph32, dot0), GGML_CPU_FP16_TO_FP32(y0->d) * GGML_E8M0_TO_FP32_HALF(x0->e));
sumv1 = svmla_n_f32_x(ph32, sumv1,
svcvt_f32_s32_x(ph32, dot1), GGML_CPU_FP16_TO_FP32(y1->d) * GGML_E8M0_TO_FP32_HALF(x1->e));
}
sumf = svaddv_f32(ph32, svadd_f32_x(ph32, sumv0, sumv1));
} break;

default:
assert(false && "Unsupported vector length");
break;
}

#elif defined (__ARM_NEON)
const int8x16_t values = vld1q_s8(kvalues_mxfp4);
const uint8x16_t m4b = vdupq_n_u8(0x0f);
uint8x16x2_t q4bits;
Expand Down
Loading