From e26fd70dce5fc88dfa04e355ade91ab526f007b1 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Mon, 12 Aug 2024 05:54:21 -0700 Subject: [PATCH 01/12] Introduce Q4_0 and Q8_0 quantizations with BF16 delta values --- ggml/include/ggml.h | 6 + ggml/src/ggml-impl.h | 2 + ggml/src/ggml-quants.c | 550 +++++++++++++++++++++++++++++++++++ ggml/src/ggml-quants.h | 8 + ggml/src/ggml.c | 51 ++++ ggml/src/llamafile/sgemm.cpp | 412 ++++++++++++++++++++++++++ gguf-py/gguf/constants.py | 34 +++ src/llama.cpp | 25 +- 8 files changed, 1077 insertions(+), 11 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 15602a96df7ad..fc25aa7101245 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -346,6 +346,7 @@ extern "C" { // google brain half-precision bfloat16 typedef struct { uint16_t bits; } ggml_bf16_t; + GGML_API ggml_bf16_t ggml_make_bf16(uint16_t val); GGML_API ggml_bf16_t ggml_fp32_to_bf16(float); GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16 GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t); @@ -432,9 +433,14 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors +<<<<<<< HEAD GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors +======= + GGML_FTYPE_MOSTLY_Q4_0_B16 = 25, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_0_B16 = 26, // except 1d tensors +>>>>>>> ed837022 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) }; // available tensor operations: diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 190af081031da..3f7963f660281 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -20,11 +20,13 @@ #if defined(_MSC_VER) #define m512bh(p) p +#define m128bh(p) p #define m512i(p) p #else #define m512bh(p) (__m512bh)(p) +#define m128bh(p) (__m128bh)(p) #define m512i(p) (__m512i)(p) #endif diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d5b91c2dbc0c1..041a5cfd69aa9 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -699,6 +699,48 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { quantize_row_q4_0_ref(x, y, k); } +// reference implementation for deterministic creation of model files +void quantize_row_q4_0_b16_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + float max = 0.0f; + + for (int j = 0; j < qk; j++) { + const float v = x[i*qk + j]; + if (amax < fabsf(v)) { + amax = fabsf(v); + max = v; + } + } + + const float d = max / -8; + const float id = d ? 
1.0f/d : 0.0f; + + y[i].d = (GGML_FP32_TO_BF16(d)).bits; + + for (int j = 0; j < qk/2; ++j) { + const float x0 = x[i*qk + 0 + j]*id; + const float x1 = x[i*qk + qk/2 + j]*id; + + const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f)); + const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f)); + + y[i].qs[j] = xi0; + y[i].qs[j] |= xi1 << 4; + } + } +} + +void quantize_row_q4_0_b16(const float * restrict x, void * restrict y, int64_t k) { + quantize_row_q4_0_b16_reference(x, y, k); +} + void quantize_row_q4_1_ref(const float * restrict x, block_q4_1 * restrict y, int64_t k) { const int qk = QK4_1; @@ -1532,6 +1574,27 @@ void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int6 } } +void dequantize_row_q4_0_b16(const block_q4_0 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK4_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + const float d = GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)); + + for (int j = 0; j < qk/2; ++j) { + const int x0 = (x[i].qs[j] & 0x0F) - 8; + const int x1 = (x[i].qs[j] >> 4) - 8; + + y[i*qk + j + 0 ] = x0*d; + y[i*qk + j + qk/2] = x1*d; + } + } +} + + void dequantize_row_q4_1(const block_q4_1 * restrict x, float * restrict y, int64_t k) { static const int qk = QK4_1; @@ -1622,6 +1685,24 @@ void dequantize_row_q8_0(const block_q8_0 * restrict x, float * restrict y, int6 } } +void dequantize_row_q8_0_b16(const block_q8_0 * restrict x, float * restrict y, int64_t k) { + static const int qk = QK8_0; + + assert(k % qk == 0); + + const int nb = k / qk; + + for (int i = 0; i < nb; i++) { + + const float d = GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)); + + for (int j = 0; j < qk; ++j) { + y[i*qk + j] = x[i].qs[j]*d; + } + } +} + + // // 2-6 bit quantization in super-blocks // @@ -3132,6 +3213,34 @@ static void quantize_row_q4_0_impl(const float * restrict x, block_q4_0 * restri } } +static void quantize_row_q4_0_b16_impl(const float * restrict x, block_q4_0 * restrict y, int64_t n_per_row, const float * quant_weights) { + static_assert(QK4_0 == 32, "QK4_0 must be 32"); + + if (!quant_weights) { + quantize_row_q4_0_b16_reference(x, y, n_per_row); + return; + } + + float weight[QK4_0]; + int8_t L[QK4_0]; + + float sum_x2 = 0; + for (int j = 0; j < n_per_row; ++j) sum_x2 += x[j]*x[j]; + float sigma2 = sum_x2/n_per_row; + + const int64_t nb = n_per_row/QK4_0; + for (int ib = 0; ib < nb; ++ib) { + const float * xb = x + QK4_0 * ib; + const float * qw = quant_weights + QK4_0 * ib; + for (int j = 0; j < QK4_0; ++j) weight[j] = qw[j] * sqrtf(sigma2 + xb[j]*xb[j]); + float d = make_qx_quants(QK4_0, 8, xb, L, 1, weight); + y[ib].d = (GGML_FP32_TO_BF16(d)).bits; + for (int j = 0; j < 16; ++j) { + y[ib].qs[j] = L[j] | (L[j+16] << 4); + } + } +} + size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { quantize_row_q4_0_ref(src, dst, (int64_t)nrow*n_per_row); @@ -3147,6 +3256,21 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +size_t quantize_q4_0_b16(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + if (!quant_weights) { + quantize_row_q4_0_b16_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * ggml_row_size(GGML_TYPE_Q4_0_B16, n_per_row); + } + size_t row_size = ggml_row_size(GGML_TYPE_Q4_0_B16, n_per_row); + char * qrow = (char *)dst; + for (int64_t row = 0; row < nrow; ++row) { + 
quantize_row_q4_0_b16_impl(src, (block_q4_0*)qrow, n_per_row, quant_weights); + src += n_per_row; + qrow += row_size; + } + return nrow * row_size; +} + static void quantize_row_q4_1_impl(const float * restrict x, block_q4_1 * restrict y, int64_t n_per_row, const float * quant_weights) { static_assert(QK4_1 == 32, "QK4_1 must be 32"); @@ -3306,6 +3430,13 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr return nrow * row_size; } +size_t quantize_q8_0_b16(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { + (void)quant_weights; // not used + const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0_B16, n_per_row); + quantize_row_q8_0_b16_reference(src, dst, (int64_t)nrow*n_per_row); + return nrow * row_size; +} + // ====================== "True" 2-bit (de)-quantization void dequantize_row_iq2_xxs(const block_iq2_xxs * restrict x, float * restrict y, int64_t k) { @@ -4208,6 +4339,278 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r *s = sumf; } +void ggml_vec_dot_q4_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q4_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + + // Initialize accumulator with zeros +#if defined(__AVX512BF16__) + __m256 acc = _mm256_setzero_ps(); + __m128 zerovec = _mm_setzero_ps(); + const __m256i off = _mm256_set1_epi8( 8 ); + int nbmod = nb - (nb % 4); + for (int i = 0; i < nbmod; i+=4) { + // Compute combined scale for set of four blocks + uint64_t x_delta = ((uint64_t)x[i+3].d << 48) | ((uint64_t)x[i+2].d << 32) | ((uint64_t)x[i+1].d << 16) | (x[i].d); + uint64_t y_delta = ((uint64_t)y[i+3].d << 48) | ((uint64_t)y[i+2].d << 32) | ((uint64_t)y[i+1].d << 16) | (y[i].d); + + __m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta))); + __m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta))); + + __m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd)); + d = _mm256_permute2f128_ps(d ,d, 0); + + __m256i qx0 = bytes_from_nibbles_32(x[i].qs); + qx0 = _mm256_sub_epi8( qx0, off ); + __m256i qy0 = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q0 = mul_sum_i8_pairs_float(qx0, qy0); + + __m256i qx1 = bytes_from_nibbles_32(x[i + 1].qs); + qx1 = _mm256_sub_epi8( qx1, off ); + __m256i qy1 = _mm256_loadu_si256((const __m256i *)y[i + 1].qs); + + const __m256 q1 = mul_sum_i8_pairs_float(qx1, qy1); + + __m256i qx2 = bytes_from_nibbles_32(x[i + 2].qs); + qx2 = _mm256_sub_epi8( qx2, off ); + __m256i qy2 = _mm256_loadu_si256((const __m256i *)y[i + 2].qs); + + const __m256 q2 = mul_sum_i8_pairs_float(qx2, qy2); + + __m256i qx3 = bytes_from_nibbles_32(x[i + 3].qs); + qx3 = _mm256_sub_epi8( qx3, off ); + __m256i qy3 = _mm256_loadu_si256((const __m256i *)y[i + 3].qs); + + const __m256 q3 = mul_sum_i8_pairs_float(qx3, qy3); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 0), q0, acc ); + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 85), q1, acc ); + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 170), q2, acc ); + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 255), q3, acc ); + } + for(int i = nbmod; i < nb; i++) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( 
GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d)) ); + __m256i qx = bytes_from_nibbles_32(x[i].qs); + qx = _mm256_sub_epi8( qx, off ); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( d, q, acc ); + } + *s = hsum_float_8(acc); +#elif defined(__AVX2__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + // Main loop + for (int i = 0; i < nb; ++i) { + /* Compute combined scale for the block */ + const __m256 d = _mm256_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d)) ); + + __m256i qx = bytes_from_nibbles_32(x[i].qs); + + // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. + const __m256i off = _mm256_set1_epi8( 8 ); + qx = _mm256_sub_epi8( qx, off ); + + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + /* Multiply q with scale and accumulate */ + acc = _mm256_fmadd_ps( d, q, acc ); + } + *s = hsum_float_8(acc); +#elif defined(__AVX__) + // Initialize accumulator with zeros + __m256 acc = _mm256_setzero_ps(); + + // Main loop + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d))); + + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + const __m128i tmp = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + bx_0 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp, 4)); + by_0 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_0, by_0); + + // Convert int32_t to float + __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I(i32_0, i32_1)); + + // Apply the scale, and accumulate + acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc); + } + + *s = hsum_float_8(acc); +#elif defined(__SSSE3__) + // set constants + const __m128i lowMask = _mm_set1_epi8(0xF); + const __m128i off = _mm_set1_epi8(8); + + // Initialize accumulator with zeros + __m128 acc_0 = _mm_setzero_ps(); + __m128 acc_1 = _mm_setzero_ps(); + __m128 acc_2 = _mm_setzero_ps(); + __m128 acc_3 = _mm_setzero_ps(); + + // First round without accumulation + { + _mm_prefetch(&x[0] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d))); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[0].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[0].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[1] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[1] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined 
scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[1].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[1].d)) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + acc_0 = _mm_mul_ps( d_0_1, p0 ); + acc_1 = _mm_mul_ps( d_0_1, p1 ); + acc_2 = _mm_mul_ps( d_2_3, p2 ); + acc_3 = _mm_mul_ps( d_2_3, p3 ); + } + + assert(nb % 2 == 0); // TODO: handle odd nb + + // Main loop + for (int i = 2; i < nb; i+=2) { + _mm_prefetch(&x[i] + sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 0 and 1 + const __m128 d_0_1 = _mm_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d)) ); + + const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[i].qs); + + __m128i bx_0 = _mm_and_si128(lowMask, tmp_0_1); + __m128i by_0 = _mm_loadu_si128((const __m128i *)y[i].qs); + bx_0 = _mm_sub_epi8(bx_0, off); + const __m128i i32_0 = mul_sum_i8_pairs(bx_0, by_0); + + __m128i bx_1 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_0_1, 4)); + __m128i by_1 = _mm_loadu_si128((const __m128i *)(y[i].qs + 16)); + bx_1 = _mm_sub_epi8(bx_1, off); + const __m128i i32_1 = mul_sum_i8_pairs(bx_1, by_1); + + _mm_prefetch(&x[i] + 2 * sizeof(block_q4_0), _MM_HINT_T0); + _mm_prefetch(&y[i] + 2 * sizeof(block_q8_0), _MM_HINT_T0); + + // Compute combined scale for the block 2 and 3 + const __m128 d_2_3 = _mm_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i + 1].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i + 1].d)) ); + + const __m128i tmp_2_3 = _mm_loadu_si128((const __m128i *)x[i + 1].qs); + + __m128i bx_2 = _mm_and_si128(lowMask, tmp_2_3); + __m128i by_2 = _mm_loadu_si128((const __m128i *)y[i + 1].qs); + bx_2 = _mm_sub_epi8(bx_2, off); + const __m128i i32_2 = mul_sum_i8_pairs(bx_2, by_2); + + __m128i bx_3 = _mm_and_si128(lowMask, _mm_srli_epi64(tmp_2_3, 4)); + __m128i by_3 = _mm_loadu_si128((const __m128i *)(y[i + 1].qs + 16)); + bx_3 = _mm_sub_epi8(bx_3, off); + const __m128i i32_3 = mul_sum_i8_pairs(bx_3, by_3); + + // Convert int32_t to float + __m128 p0 = _mm_cvtepi32_ps(i32_0); + __m128 p1 = _mm_cvtepi32_ps(i32_1); + __m128 p2 = _mm_cvtepi32_ps(i32_2); + __m128 p3 = _mm_cvtepi32_ps(i32_3); + + // Apply the scale + __m128 p0_d = _mm_mul_ps( d_0_1, p0 ); + __m128 p1_d = _mm_mul_ps( d_0_1, p1 ); + __m128 p2_d = _mm_mul_ps( d_2_3, p2 ); + __m128 p3_d = _mm_mul_ps( d_2_3, p3 ); + + // Acummulate + acc_0 = _mm_add_ps(p0_d, acc_0); + acc_1 = _mm_add_ps(p1_d, acc_1); + acc_2 = _mm_add_ps(p2_d, acc_2); + acc_3 = _mm_add_ps(p3_d, acc_3); + } + + *s = hsum_float_4x4(acc_0, acc_1, acc_2, acc_3); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk/2; ++j) { + const int v0 = (x[i].qs[j] & 0x0F) - 8; + const int v1 = (x[i].qs[j] >> 4) - 8; + + sumi += (v0 * y[i].qs[j]) + (v1 * y[i].qs[j + 
qk/2]); + } + + sumf += sumi*GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d))*GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d)); + } + + *s = sumf; +#endif +} + void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { const int qk = QK8_1; const int nb = n / qk; @@ -5470,6 +5873,115 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * r *s = sumf; } +void ggml_vec_dot_q8_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { + const int qk = QK8_0; + const int nb = n / qk; + + assert(n % qk == 0); + assert(nrc == 1); + + UNUSED(nrc); + UNUSED(bx); + UNUSED(by); + UNUSED(bs); + + const block_q8_0 * restrict x = vx; + const block_q8_0 * restrict y = vy; + + +#if defined(__AVX512BF16__) + __m256 acc = _mm256_setzero_ps(); + __m128 zerovec = _mm_setzero_ps(); + int nbmod = nb - (nb % 4); + for (int i = 0; i < nbmod; i+=4) { + // Compute combined scale for set of four blocks + + uint64_t x_delta = ((uint64_t)x[i+3].d << 48) | ((uint64_t)x[i+2].d << 32) | ((uint64_t)x[i+1].d << 16) | (x[i].d); + uint64_t y_delta = ((uint64_t)y[i+3].d << 48) | ((uint64_t)y[i+2].d << 32) | ((uint64_t)y[i+1].d << 16) | (y[i].d); + + __m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta))); + __m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta))); + + __m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd)); + d = _mm256_permute2f128_ps(d ,d, 0); + + __m256i qx0 = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i qy0 = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q0 = mul_sum_i8_pairs_float(qx0, qy0); + + __m256i qx1 = _mm256_loadu_si256((const __m256i *)x[i + 1].qs); + __m256i qy1 = _mm256_loadu_si256((const __m256i *)y[i + 1].qs); + + const __m256 q1 = mul_sum_i8_pairs_float(qx1, qy1); + + __m256i qx2 = _mm256_loadu_si256((const __m256i *)x[i + 2].qs); + __m256i qy2 = _mm256_loadu_si256((const __m256i *)y[i + 2].qs); + + const __m256 q2 = mul_sum_i8_pairs_float(qx2, qy2); + + __m256i qx3 = _mm256_loadu_si256((const __m256i *)x[i + 3].qs); + __m256i qy3 = _mm256_loadu_si256((const __m256i *)y[i + 3].qs); + + const __m256 q3 = mul_sum_i8_pairs_float(qx3, qy3); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 0), q0, acc ); + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 85), q1, acc ); + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 170), q2, acc ); + acc = _mm256_fmadd_ps( _mm256_shuffle_ps(d, d, 255), q3, acc ); + } + for(int i = nbmod; i < nb; i++) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d)) ); + __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( d, q, acc ); + } + + *s = hsum_float_8(acc); + +#elif defined(__AVX2__) + // Main loop + __m256 acc = _mm256_setzero_ps(); + for (int i = 0; i < nb; ++i) { + // Compute combined scale for the block + const __m256 d = _mm256_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d)) ); + + __m256i qx = _mm256_loadu_si256((const __m256i *)x[i].qs); + __m256i qy = _mm256_loadu_si256((const __m256i *)y[i].qs); + + const __m256 q = 
mul_sum_i8_pairs_float(qx, qy); + + // Multiply q with scale and accumulate + acc = _mm256_fmadd_ps( d, q, acc ); + } + + *s = hsum_float_8(acc); +#else + // scalar + float sumf = 0.0; + + for (int i = 0; i < nb; i++) { + int sumi = 0; + + for (int j = 0; j < qk; j++) { + sumi += x[i].qs[j]*y[i].qs[j]; + } + + sumf += sumi*(GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d))*GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d))); + } + + *s = sumf; +#endif +} + + +#if QK_K == 256 void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); @@ -14602,6 +15114,28 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { return true; } +static bool isinf_bf16(ggml_half f) { + return (f & 0x7fff) == 0x7f80; +} + +static bool isnan_bf16(ggml_half f) { + return (f & 0x7fff) > 0x7f80; +} + +static bool validate_bf16(ggml_half f, size_t i) { + if (isinf_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found inf value at block %zu\n", i); + return false; + } + + if (isnan_fp16(f)) { + fprintf(stderr, "ggml_validate_row_data: found nan value at block %zu\n", i); + return false; + } + + return true; +} + #define VALIDATE_ROW_DATA_D_F16_IMPL(type, data, nb) \ const type * q = (const type *) (data); \ for (size_t i = 0; i < (nb); ++i) { \ @@ -14618,6 +15152,7 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { } \ } +<<<<<<< HEAD #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ const type * q = (const type *) (data); \ for (size_t i = 0; i < (nb); ++i) { \ @@ -14625,6 +15160,13 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) { if (!validate_fp16(q[i].d[j], i)) { \ return false; \ } \ +======= +#define VALIDATE_ROW_DATA_D_B16_IMPL(type, data, nb) \ + const type * q = (const type *) (data); \ + for (size_t i = 0; i < (nb); ++i) { \ + if (!validate_bf16((q[i].d), i)) { \ + return false; \ +>>>>>>> ed837022 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) } \ } @@ -14755,6 +15297,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q4_0, data, nb); } break; + case GGML_TYPE_Q4_0_B16: + { + VALIDATE_ROW_DATA_D_B16_IMPL(block_q4_0, data, nb); + } break; case GGML_TYPE_Q4_1: { VALIDATE_ROW_DATA_DM_F16_IMPL(block_q4_1, data, nb, d, m); @@ -14771,6 +15317,10 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte { VALIDATE_ROW_DATA_D_F16_IMPL(block_q8_0, data, nb); } break; + case GGML_TYPE_Q8_0_B16: + { + VALIDATE_ROW_DATA_D_B16_IMPL(block_q8_0, data, nb); + } break; case GGML_TYPE_Q2_K: { VALIDATE_ROW_DATA_DM_F16_IMPL(block_q2_K, data, nb, d, dmin); diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 525d5ee30d8de..23d21c42d3833 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -33,10 +33,12 @@ void quantize_row_iq3_s_ref (const float * GGML_RESTRICT x, block_iq3_s * GGM void quantize_row_iq2_s_ref (const float * GGML_RESTRICT x, block_iq2_s * GGML_RESTRICT y, int64_t k); void quantize_row_q4_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_b16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q5_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void 
quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_b16(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k); @@ -54,10 +56,12 @@ void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, // Dequantization void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q4_0_b16(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q5_1(const block_q5_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q8_0(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); +void dequantize_row_q8_0_b16(const block_q8_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); //void dequantize_row_q8_1(const block_q8_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); void dequantize_row_q2_K(const block_q2_K * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k); @@ -79,10 +83,12 @@ void dequantize_row_iq3_s (const block_iq3_s * GGML_RESTRICT x, float * GGML_ // Dot product void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q4_0_b16_q8_0_b16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); +void ggml_vec_dot_q8_0_b16_q8_0_b16(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q2_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, size_t bx, const void * GGML_RESTRICT vy, size_t by, int nrc); @@ -117,10 +123,12 @@ size_t quantize_q4_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, size_t quantize_q5_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q6_K(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q4_0_b16(const float * 
GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q4_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q5_1(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); size_t quantize_q8_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); +size_t quantize_q8_0_b16(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix); void iq2xs_init_impl(enum ggml_type type); void iq2xs_free_impl(enum ggml_type type); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 38990e3a05a3f..60c96c25878fb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -428,6 +428,12 @@ float ggml_bf16_to_fp32(ggml_bf16_t x) { return GGML_BF16_TO_FP32(x); // it just left shifts } +ggml_bf16_t ggml_make_bf16(uint16_t x) { + ggml_bf16_t bf16_value; + bf16_value.bits = x; + return bf16_value; +} + ggml_bf16_t ggml_fp32_to_bf16(float x) { #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml return GGML_FP32_TO_BF16(x); @@ -3304,6 +3310,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { enum ggml_type wtype = GGML_TYPE_COUNT; switch (ftype) { +<<<<<<< HEAD case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; @@ -3331,6 +3338,34 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; +======= + case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; + case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; + case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; + case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_0_B16: wtype = GGML_TYPE_Q4_0_B16; break; + case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; + case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; + case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; + case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q8_0_B16: wtype = GGML_TYPE_Q8_0_B16; break; + case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; + case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; + case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; + case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; + case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; + case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; + case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; + case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; + case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; + case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; + case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; + case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; + case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; + case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; + case 
GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; + case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; +>>>>>>> ed837022 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) } GGML_ASSERT(wtype != GGML_TYPE_COUNT); @@ -9561,10 +9596,12 @@ static void ggml_compute_forward_add( } } break; case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -9938,10 +9975,12 @@ static void ggml_compute_forward_add1( } } break; case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -10066,10 +10105,12 @@ static void ggml_compute_forward_acc( case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -12909,10 +12950,12 @@ static void ggml_compute_forward_out_prod( switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: case GGML_TYPE_Q4_K: @@ -13096,10 +13139,12 @@ static void ggml_compute_forward_set( case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -13358,10 +13403,12 @@ static void ggml_compute_forward_get_rows( switch (src0->type) { case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -13947,10 +13994,12 @@ static void ggml_compute_forward_clamp( case GGML_TYPE_F16: case GGML_TYPE_BF16: case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_0_B16: case GGML_TYPE_Q4_1: case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_1: case GGML_TYPE_Q8_0: + case GGML_TYPE_Q8_0_B16: case GGML_TYPE_Q8_1: case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: @@ -20645,10 +20694,12 @@ size_t ggml_quantize_chunk( switch (type) { case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q4_0_B16: result = quantize_q4_0_b16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_0: result = quantize_q5_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q5_1: result = quantize_q5_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q8_0: result = quantize_q8_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; + case GGML_TYPE_Q8_0_B16: result = quantize_q8_0_b16(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q2_K: result = quantize_q2_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, 
imatrix); break; case GGML_TYPE_Q3_K: result = quantize_q3_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; case GGML_TYPE_Q4_K: result = quantize_q4_K(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break; diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index 6626ceb26213f..8cbfdcb15d6e7 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -72,6 +72,10 @@ inline float unhalf(ggml_fp16_t d) { return GGML_FP16_TO_FP32(d); } +inline float bf16_unhalf(ggml_bf16_t d) { + return GGML_BF16_TO_FP32(d); +} + //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED ARITHMETIC OPERATIONS @@ -813,6 +817,382 @@ class tinyBLAS_Q0_AVX { }; #endif // __AVX__ +#if defined(__AVX2__) || defined(__AVX512F__) +template +class tinyBLAS_Q0_B16_AVX { + public: + tinyBLAS_Q0_B16_AVX(int64_t k, + const TA *A, int64_t lda, + const TB *B, int64_t ldb, + TC *C, int64_t ldc, + int ith, int nth) + : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { + } + + void matmul(int64_t m, int64_t n, int task) { + if (task == GGML_TASK_TYPE_COMPUTE) + mnpack(0, m, 0, n); + } + + private: + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t mc, nc, mp, np; + switch ((MIN(m - m0, 4) << 4) | MIN(n - n0, 4)) { +#if VECTOR_REGISTERS == 32 + case 0x44: + mc = 4; + nc = 4; +#if defined(__AVX512BF16__) + gemm4xN<4>(m0, m, n0, n); +#else + gemm<4, 4>(m0, m, n0, n); +#endif + break; + case 0x43: + mc = 4; + nc = 3; +#if defined(__AVX512BF16__) + gemm4xN<3>(m0, m, n0, n); +#else + gemm<4, 3>(m0, m, n0, n); +#endif + break; + case 0x34: + mc = 3; + nc = 4; +#if defined(__AVX512BF16__) + gemmMx4<3>(m0, m, n0, n); +#else + gemm<3, 4>(m0, m, n0, n); +#endif + break; + case 0x33: + mc = 3; + nc = 3; + gemm<3, 3>(m0, m, n0, n); + break; + case 0x42: + mc = 4; + nc = 2; +#if defined(__AVX512BF16__) + gemm4xN<2>(m0, m, n0, n); +#else + gemm<4, 2>(m0, m, n0, n); +#endif + break; + case 0x24: + mc = 2; + nc = 4; +#if defined(__AVX512BF16__) + gemmMx4<2>(m0, m, n0, n); +#else + gemm<2, 4>(m0, m, n0, n); +#endif + break; +#else + case 0x44: + case 0x43: + case 0x42: + mc = 4; + nc = 2; +#if defined(__AVX512BF16__) + gemm4xN<2>(m0, m, n0, n); +#else + gemm<4, 2>(m0, m, n0, n); +#endif + break; + case 0x34: + case 0x24: + mc = 2; + nc = 4; +#if defined(__AVX512BF16__) + gemmMx4<2>(m0, m, n0, n); +#else + gemm<2, 4>(m0, m, n0, n); +#endif + break; + case 0x33: +#endif + case 0x32: + mc = 3; + nc = 2; + gemm<3, 2>(m0, m, n0, n); + break; + case 0x23: + mc = 2; + nc = 3; + gemm<2, 3>(m0, m, n0, n); + break; + case 0x41: + mc = 4; + nc = 1; +#if defined(__AVX512BF16__) + gemm4xN<1>(m0, m, n0, n); +#else + gemm<4, 1>(m0, m, n0, n); +#endif + break; + case 0x22: + mc = 2; + nc = 2; + gemm<2, 2>(m0, m, n0, n); + break; + case 0x14: + mc = 1; + nc = 4; +#if defined(__AVX512BF16__) + gemmMx4<1>(m0, m, n0, n); +#else + gemm<1, 4>(m0, m, n0, n); +#endif + break; + case 0x31: + mc = 3; + nc = 1; + gemm<3, 1>(m0, m, n0, n); + break; + case 0x13: + mc = 1; + nc = 3; + gemm<1, 3>(m0, m, n0, n); + break; + case 0x21: + mc = 2; + nc = 1; + gemm<2, 1>(m0, m, n0, n); + break; + case 0x12: + mc = 1; + nc = 2; + gemm<1, 2>(m0, m, n0, n); + break; + case 0x11: + mc = 1; + nc = 1; + gemm<1, 1>(m0, m, n0, n); + break; + default: + return; + } + mp = m0 + (m - m0) / mc * mc; + np = n0 + (n - n0) / nc * nc; + mnpack(mp, m, n0, np); + mnpack(m0, m, np, n); + } + 
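+    // NOTE (explanatory sketch, not called anywhere): the AVX512BF16 kernels below
+    // keep the per-block deltas as raw bf16 bit patterns and multiply them with
+    // _mm_dpbf16_ps instead of first converting each delta to fp32. With four
+    // 16-bit deltas packed into a uint64_t (as gemm4xN does for A, and as
+    // ggml_vec_dot_q4_0_b16_q8_0_b16 does for both operands), the idea is:
+    //
+    //   static inline __m256 bf16_delta_products(uint64_t da_bits, uint64_t db_bits) {
+    //       __m128bh da = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, da_bits)));
+    //       __m128bh db = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, db_bits)));
+    //       __m128   p  = _mm_dpbf16_ps(_mm_setzero_ps(), da, db); // d_a[i] * d_b[i], i = 0..3
+    //       __m256   d  = _mm256_castps128_ps256(p);
+    //       return _mm256_permute2f128_ps(d, d, 0);                // duplicate into both 128-bit halves
+    //   }
+    //
+    // bf16_delta_products is a hypothetical helper shown only for illustration;
+    // the kernels below inline the same sequence directly.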
+#if defined(__AVX512BF16__) + template + NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / 4; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * 4; + int64_t jj = n0 + job % xtiles * RN; + __m256 Cv[RN][4] = {}; + __m128 zerovec = _mm_setzero_ps(); + for (int64_t l = 0; l < k; ++l) { + uint64_t a_delta = ((uint64_t)A[lda * (ii + 3) + l].d << 48) | ((uint64_t)A[lda * (ii + 2) + l].d << 32) | ((uint64_t)A[lda * (ii + 1) + l].d << 16) | (A[lda * (ii + 0) + l].d); + __m128bh da = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, a_delta))); + __m256i avec0 = load(A + lda * (ii + 0) + l); + __m256i avec1 = load(A + lda * (ii + 1) + l); + __m256i avec2 = load(A + lda * (ii + 2) + l); + __m256i avec3 = load(A + lda * (ii + 3) + l); + for (int64_t j = 0; j < RN; ++j) { + __m128bh db = m128bh(_mm_set1_epi16(B[ldb * (jj + j) + l].d)); + __m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db)); + dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); + Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0), + updot(_mm256_sign_epi8(avec0, avec0), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec0)), + Cv[j][0]); + Cv[j][1] = madd(_mm256_shuffle_ps(dvec, dvec, 85), + updot(_mm256_sign_epi8(avec1, avec1), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec1)), + Cv[j][1]); + Cv[j][2] = madd(_mm256_shuffle_ps(dvec, dvec, 170), + updot(_mm256_sign_epi8(avec2, avec2), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec2)), + Cv[j][2]); + Cv[j][3] = madd(_mm256_shuffle_ps(dvec, dvec, 255), + updot(_mm256_sign_epi8(avec3, avec3), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), avec3)), + Cv[j][3]); + } + } + + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < 4; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + template + NOINLINE void gemmMx4(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / 4; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * 4; + __m256 Cv[4][RM] = {}; + __m128 zerovec = _mm_setzero_ps(); + for (int64_t l = 0; l < k; ++l) { + uint64_t b_delta = ((uint64_t)B[ldb * (jj + 3) + l].d << 48) | ((uint64_t)B[ldb * (jj + 2) + l].d << 32) | ((uint64_t)B[ldb * (jj + 1) + l].d << 16) | (B[ldb * (jj + 0) + l].d); + __m128bh db = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, b_delta))); + __m256i bvec0 = load(B + ldb * (jj + 0) + l); + __m256i bvec1 = load(B + ldb * (jj + 1) + l); + __m256i bvec2 = load(B + ldb * (jj + 2) + l); + __m256i bvec3 = load(B + ldb * (jj + 3) + l); + for (int64_t i = 0; i < RM; ++i) { + __m128bh da = m128bh(_mm_set1_epi16((A[lda * (ii + i) + l].d))); + __m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db)); + dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); + Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec0, load(A + lda * (ii + i) + l))), + Cv[0][i]); + Cv[1][i] = madd(_mm256_shuffle_ps(dvec, dvec, 85), + updot(_mm256_sign_epi8(load(A + lda * (ii 
+ i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec1, load(A + lda * (ii + i) + l))), + Cv[1][i]); + Cv[2][i] = madd(_mm256_shuffle_ps(dvec, dvec, 170), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec2, load(A + lda * (ii + i) + l))), + Cv[2][i]); + Cv[3][i] = madd(_mm256_shuffle_ps(dvec, dvec, 255), + updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(bvec3, load(A + lda * (ii + i) + l))), + Cv[3][i]); + } + } + for (int64_t j = 0; j < 4; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } +#endif + + template + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + int64_t ytiles = (m - m0) / RM; + int64_t xtiles = (n - n0) / RN; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) + end = tiles; + for (int64_t job = start; job < end; ++job) { + int64_t ii = m0 + job / xtiles * RM; + int64_t jj = n0 + job % xtiles * RN; + __m256 Cv[RN][RM] = {}; + for (int64_t l = 0; l < k; ++l) + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) { +#if defined(__AVX2__) + __m256 udTmp = updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), + load(A + lda * (ii + i) + l)), + _mm256_sign_epi8(load(B + ldb * (jj + j) + l), + load(A + lda * (ii + i) + l))); +#else + __m128i ali0 = load0(A + lda * (ii + i) + l); + __m128i ali1 = load1(A + lda * (ii + i) + l); + __m128i blj0 = load0(B + ldb * (jj + j) + l); + __m128i blj1 = load1(B + ldb * (jj + j) + l); + + __m128i sepAA0 = _mm_sign_epi8(ali0, ali0); + __m128i sepAA1 = _mm_sign_epi8(ali1, ali1); + __m128i sepBA0 = _mm_sign_epi8(blj0, ali0); + __m128i sepBA1 = _mm_sign_epi8(blj1, ali1); + + // updot + const __m128i oneFill = _mm_set1_epi16(1); + __m128i mad0 = _mm_maddubs_epi16(sepAA0, sepBA0); + __m128i mad1 = _mm_maddubs_epi16(sepAA1, sepBA1); + __m256 udTmp = _mm256_cvtepi32_ps(MM256_SET_M128I(_mm_madd_epi16(oneFill, mad1), _mm_madd_epi16(oneFill, mad0))); +#endif + + Cv[j][i] = madd(_mm256_set1_ps(bf16_unhalf(ggml_make_bf16(A[lda * (ii + i) + l].d)) * + bf16_unhalf(ggml_make_bf16(B[ldb * (jj + j) + l].d))), + udTmp, + Cv[j][i]); + } + for (int64_t j = 0; j < RN; ++j) + for (int64_t i = 0; i < RM; ++i) + C[ldc * (jj + j) + (ii + i)] = hsum(Cv[j][i]); + } + } + + inline __m256i load(const block_q8_0 *b) { + return _mm256_loadu_si256((const __m256i *)b->qs); + } + + inline __m128i load0(const block_q8_0 *b) { + return _mm_loadu_si128((const __m128i *)b->qs); + } + + inline __m128i load1(const block_q8_0 *b) { + return _mm_loadu_si128(((const __m128i *)b->qs) + 1); + } + + inline __m256i load(const block_q4_0 *b) { + return _mm256_sub_epi8(denibble(b->qs), _mm256_set1_epi8(8)); + } + + inline __m128i load0(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), x), _mm_set1_epi8(8)); + } + + inline __m128i load1(const block_q4_0 *b) { + const __m128i x = _mm_loadu_si128((const __m128i *)(b->qs)); + return _mm_sub_epi8(_mm_and_si128(_mm_set1_epi8(15), _mm_srli_epi16(x, 4)), _mm_set1_epi8(8)); + } + + inline __m256 updot(__m256i u, __m256i s) { + __m256i res; +#if defined(__AVXVNNI__) || (defined(__AVX512VNNI__) && defined(__AVX512VL__)) + res = _mm256_dpbusd_epi32(_mm256_setzero_si256(), u, s); +#else + res = _mm256_madd_epi16(_mm256_set1_epi16(1), 
_mm256_maddubs_epi16(u, s)); +#endif + return _mm256_cvtepi32_ps(res); + } + + static inline __m256i denibble(const uint8_t *p) { + __m128i x = _mm_loadu_si128((const __m128i *)p); + return _mm256_and_si256(_mm256_set1_epi8(15), + _mm256_insertf128_si256(_mm256_castsi128_si256(x), + _mm_srli_epi16(x, 4), 1)); + } + + const TA *const A; + const TB *const B; + TC *const C; + const int64_t k; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; +#endif // __AVX2__ + } // namespace /** @@ -1006,6 +1386,38 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda #endif } + case GGML_TYPE_Q8_0_B16: { + if (Btype != GGML_TYPE_Q8_0_B16) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + tinyBLAS_Q0_B16_AVX tb{ + k, (const block_q8_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + + case GGML_TYPE_Q4_0_B16: { + if (Btype != GGML_TYPE_Q8_0_B16) + return false; +#if defined(__AVX2__) || defined(__AVX512F__) + tinyBLAS_Q0_B16_AVX tb{ + k, (const block_q4_0 *)A, lda, + (const block_q8_0 *)B, ldb, + (float *)C, ldc, + ith, nth}; + tb.matmul(m, n, task); + return true; +#else + return false; +#endif + } + default: return false; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f63ec450a4e09..741a31e455667 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1254,6 +1254,7 @@ def get_type(val: Any) -> GGUFValueType: # Items here are (block size, type size) QK_K = 256 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { +<<<<<<< HEAD GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F16: (1, 2), GGMLQuantizationType.Q4_0: (32, 2 + 16), @@ -1286,6 +1287,39 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), +======= + GGMLQuantizationType.F32: (1, 4), + GGMLQuantizationType.F16: (1, 2), + GGMLQuantizationType.Q4_0: (32, 2 + 16), + GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), + GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), + GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), + GGMLQuantizationType.Q8_0: (32, 2 + 32), + GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), + GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), + GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), + GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), + GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), + GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), + GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), + GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), + GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), + GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), + GGMLQuantizationType.IQ4_NL: (32, 2 + 16), + GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), + GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), + GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), + GGMLQuantizationType.I8: (1, 1), + GGMLQuantizationType.I16: (1, 2), + GGMLQuantizationType.I32: (1, 4), + GGMLQuantizationType.I64: (1, 8), + GGMLQuantizationType.F64: (1, 8), + GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), + GGMLQuantizationType.BF16: (1, 
2), + GGMLQuantizationType.Q4_0_B16: (32, 2 + 16), + GGMLQuantizationType.Q8_0_B16: (32, 2 + 32), +>>>>>>> 2f13a1e6 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) } diff --git a/src/llama.cpp b/src/llama.cpp index aaf8db496ecbd..10c96924d87c4 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7757,6 +7757,7 @@ static int llama_model_load(const std::string & fname, llama_model & model, llam model.ftype == LLAMA_FTYPE_MOSTLY_F16 || model.ftype == LLAMA_FTYPE_MOSTLY_BF16 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || + model.ftype == LLAMA_FTYPE_MOSTLY_Q4_0_B16 || model.ftype == LLAMA_FTYPE_MOSTLY_Q4_1 ) )) { @@ -15479,7 +15480,7 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n ftype == LLAMA_FTYPE_MOSTLY_IQ1_M) { new_type = GGML_TYPE_Q5_K; } - else if (new_type != GGML_TYPE_Q8_0) { + else if ((new_type != GGML_TYPE_Q8_0) && (new_type != GGML_TYPE_Q8_0_B16)) { new_type = GGML_TYPE_Q6_K; } } @@ -15620,12 +15621,12 @@ static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type n else if (ftype == LLAMA_FTYPE_MOSTLY_Q4_K_S && arch != LLM_ARCH_FALCON && i_layer < n_layer/8) { new_type = GGML_TYPE_Q5_K; } - else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0) + else if ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0 || ftype == LLAMA_FTYPE_MOSTLY_Q5_0 || ftype == LLAMA_FTYPE_MOSTLY_Q4_0_B16) && qs.has_imatrix && i_layer < n_layer/8) { // Guard against craziness in the first few ffn_down layers that can happen even with imatrix for Q4_0/Q5_0. // We only do it when an imatrix is provided because a) we want to make sure that one can always get the // same quantization as before imatrix stuff, and b) Q4_1/Q5_1 do go crazy on ffn_down without an imatrix. - new_type = ftype == LLAMA_FTYPE_MOSTLY_Q4_0 ? GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; + new_type = ((ftype == LLAMA_FTYPE_MOSTLY_Q4_0) || (ftype == LLAMA_FTYPE_MOSTLY_Q4_0_B16)) ? 
GGML_TYPE_Q4_1 : GGML_TYPE_Q5_1; } ++qs.i_ffn_down; } else if (name.find("attn_output.weight") != std::string::npos) { @@ -15781,14 +15782,16 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s llama_ftype ftype = params->ftype; switch (params->ftype) { - case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; - case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; - case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; - case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; - case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; - case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; - case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; - case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; + case LLAMA_FTYPE_MOSTLY_Q4_0: default_type = GGML_TYPE_Q4_0; break; + case LLAMA_FTYPE_MOSTLY_Q4_0_B16: default_type = GGML_TYPE_Q4_0_B16; break; + case LLAMA_FTYPE_MOSTLY_Q4_1: default_type = GGML_TYPE_Q4_1; break; + case LLAMA_FTYPE_MOSTLY_Q5_0: default_type = GGML_TYPE_Q5_0; break; + case LLAMA_FTYPE_MOSTLY_Q5_1: default_type = GGML_TYPE_Q5_1; break; + case LLAMA_FTYPE_MOSTLY_Q8_0: default_type = GGML_TYPE_Q8_0; break; + case LLAMA_FTYPE_MOSTLY_Q8_0_B16: default_type = GGML_TYPE_Q8_0_B16; break; + case LLAMA_FTYPE_MOSTLY_F16: default_type = GGML_TYPE_F16; break; + case LLAMA_FTYPE_MOSTLY_BF16: default_type = GGML_TYPE_BF16; break; + case LLAMA_FTYPE_ALL_F32: default_type = GGML_TYPE_F32; break; // K-quants case LLAMA_FTYPE_MOSTLY_Q2_K_S: From 983b03ab6a07de83c79173b70be5915ff7a90774 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Thu, 23 May 2024 08:49:41 -0700 Subject: [PATCH 02/12] Add additional comments --- ggml/src/ggml-quants.c | 2 ++ ggml/src/llamafile/sgemm.cpp | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 041a5cfd69aa9..883d7ba39045a 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -4368,6 +4368,7 @@ void ggml_vec_dot_q4_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const __m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta))); __m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta))); + // Computes product of delta values from four corresponding blocks __m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd)); d = _mm256_permute2f128_ps(d ,d, 0); @@ -5902,6 +5903,7 @@ void ggml_vec_dot_q8_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const __m128bh xd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, x_delta))); __m128bh yd = m128bh(_mm_cvtepu16_epi32(_mm_set_epi64x(0, y_delta))); + // Computes product of delta values from four corresponding blocks __m256 d = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, xd, yd)); d = _mm256_permute2f128_ps(d ,d, 0); diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index 8cbfdcb15d6e7..f55944eaee58c 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -981,6 +981,7 @@ class tinyBLAS_Q0_B16_AVX { } #if defined(__AVX512BF16__) +// Templated functions for gemm of dimesnions 4xN template NOINLINE void gemm4xN(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / 4; @@ -1005,6 +1006,7 @@ class tinyBLAS_Q0_B16_AVX { __m256i avec3 = load(A + lda * (ii + 3) + l); for (int64_t j = 0; j < RN; ++j) { __m128bh db = m128bh(_mm_set1_epi16(B[ldb * (jj + j) + l].d)); + // Computation of product of delta values for four 
blocks __m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db)); dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); Cv[j][0] = madd(_mm256_shuffle_ps(dvec, dvec, 0), @@ -1056,7 +1058,8 @@ class tinyBLAS_Q0_B16_AVX { __m256i bvec3 = load(B + ldb * (jj + 3) + l); for (int64_t i = 0; i < RM; ++i) { __m128bh da = m128bh(_mm_set1_epi16((A[lda * (ii + i) + l].d))); - __m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db)); + // Computation of product of delta values for four blocks + __m256 dvec = _mm256_castps128_ps256(_mm_dpbf16_ps(zerovec, da, db)); dvec = _mm256_permute2f128_ps(dvec ,dvec, 0); Cv[0][i] = madd(_mm256_shuffle_ps(dvec, dvec, 0), updot(_mm256_sign_epi8(load(A + lda * (ii + i) + l), From 9e5174ce5d72aacf93946a5f16ea2916567e1911 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Fri, 24 May 2024 07:23:03 -0700 Subject: [PATCH 03/12] Remove additional ifdef conditions --- ggml/src/ggml-quants.c | 1 - 1 file changed, 1 deletion(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 883d7ba39045a..d4acf953e811a 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -5983,7 +5983,6 @@ void ggml_vec_dot_q8_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const } -#if QK_K == 256 void ggml_vec_dot_q2_K_q8_K(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { assert(nrc == 1); UNUSED(nrc); From c480818d9770b11c6be94304cf0cb9b6147e597f Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Fri, 24 May 2024 10:33:54 -0700 Subject: [PATCH 04/12] Fix issues with SSE3 version for vec_dot_q4_0_b16_q8_0_b16 --- ggml/src/ggml-quants.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index d4acf953e811a..674551da1a726 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -4486,7 +4486,7 @@ void ggml_vec_dot_q4_0_b16_q8_0_b16(int n, float * restrict s, size_t bs, const _mm_prefetch(&y[0] + sizeof(block_q8_0), _MM_HINT_T0); // Compute combined scale for the block 0 and 1 - const __m128 d_0_1 = _mm_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[i].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[i].d))); + const __m128 d_0_1 = _mm_set1_ps( GGML_BF16_TO_FP32(ggml_make_bf16(x[0].d)) * GGML_BF16_TO_FP32(ggml_make_bf16(y[0].d))); const __m128i tmp_0_1 = _mm_loadu_si128((const __m128i *)x[0].qs); From 5a6a235ac7918420f044948b74cae1af8257b26d Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Sun, 7 Jul 2024 19:29:04 -0700 Subject: [PATCH 05/12] Fix build issues in sgemm.cpp post rebase --- ggml/src/llamafile/sgemm.cpp | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/ggml/src/llamafile/sgemm.cpp b/ggml/src/llamafile/sgemm.cpp index f55944eaee58c..669c99c546104 100644 --- a/ggml/src/llamafile/sgemm.cpp +++ b/ggml/src/llamafile/sgemm.cpp @@ -829,9 +829,8 @@ class tinyBLAS_Q0_B16_AVX { : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { } - void matmul(int64_t m, int64_t n, int task) { - if (task == GGML_TASK_TYPE_COMPUTE) - mnpack(0, m, 0, n); + void matmul(int64_t m, int64_t n) { + mnpack(0, m, 0, n); } private: @@ -1398,7 +1397,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda (const block_q8_0 *)B, ldb, (float *)C, ldc, ith, nth}; - tb.matmul(m, n, task); + tb.matmul(m, n); return true; #else return false; @@ -1414,7 +1413,7 @@ bool llamafile_sgemm(int64_t m, int64_t n, int64_t k, const void *A, int64_t lda (const block_q8_0 *)B, ldb, (float 
*)C, ldc, ith, nth}; - tb.matmul(m, n, task); + tb.matmul(m, n); return true; #else return false; From db6657eeaf21db05f25cd03f3556174dec59f423 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Tue, 30 Jul 2024 09:46:22 -0700 Subject: [PATCH 06/12] Fix more conflicts in quantize.cpp --- ggml/include/ggml.h | 9 ++- ggml/src/ggml-quants.c | 136 +++++++++++++++++++++++++++++++++++++++-- ggml/src/ggml-quants.h | 2 + ggml/src/ggml.c | 26 +++++++- include/llama.h | 2 + src/llama.cpp | 2 + 6 files changed, 166 insertions(+), 11 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fc25aa7101245..b55fd1b8b6f4f 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -392,6 +392,8 @@ extern "C" { GGML_TYPE_Q4_0_4_4 = 31, GGML_TYPE_Q4_0_4_8 = 32, GGML_TYPE_Q4_0_8_8 = 33, + GGML_TYPE_Q4_0_B16 = 34, + GGML_TYPE_Q8_0_B16 = 35, GGML_TYPE_COUNT, }; @@ -433,14 +435,11 @@ extern "C" { GGML_FTYPE_MOSTLY_IQ4_XS = 22, // except 1d tensors GGML_FTYPE_MOSTLY_IQ1_M = 23, // except 1d tensors GGML_FTYPE_MOSTLY_BF16 = 24, // except 1d tensors -<<<<<<< HEAD GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors -======= - GGML_FTYPE_MOSTLY_Q4_0_B16 = 25, // except 1d tensors - GGML_FTYPE_MOSTLY_Q8_0_B16 = 26, // except 1d tensors ->>>>>>> ed837022 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) + GGML_FTYPE_MOSTLY_Q4_0_B16 = 28, // except 1d tensors + GGML_FTYPE_MOSTLY_Q8_0_B16 = 29, // except 1d tensors }; // available tensor operations: diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 674551da1a726..66db1e46d067b 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -700,7 +700,7 @@ void quantize_row_q4_0(const float * restrict x, void * restrict y, int64_t k) { } // reference implementation for deterministic creation of model files -void quantize_row_q4_0_b16_reference(const float * restrict x, block_q4_0 * restrict y, int64_t k) { +void quantize_row_q4_0_b16_ref(const float * restrict x, block_q4_0 * restrict y, int64_t k) { static const int qk = QK4_0; assert(k % qk == 0); @@ -738,7 +738,7 @@ void quantize_row_q4_0_b16_reference(const float * restrict x, block_q4_0 * rest } void quantize_row_q4_0_b16(const float * restrict x, void * restrict y, int64_t k) { - quantize_row_q4_0_b16_reference(x, y, k); + quantize_row_q4_0_b16_ref(x, y, k); } @@ -1190,6 +1190,132 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k) #endif } +void quantize_row_q8_0_b16_ref(const float * restrict x, block_q8_0 * restrict y, int64_t k) { + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + for (int i = 0; i < nb; i++) { + float amax = 0.0f; // absolute max + + for (int j = 0; j < QK8_0; j++) { + const float v = x[i*QK8_0 + j]; + amax = MAX(amax, fabsf(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = d ?
1.0f/d : 0.0f; + + y[i].d = (GGML_FP32_TO_BF16(d)).bits; + + for (int j = 0; j < QK8_0; ++j) { + const float x0 = x[i*QK8_0 + j]*id; + + y[i].qs[j] = roundf(x0); + } + } +} + +void quantize_row_q8_0_b16(const float * restrict x, void * restrict vy, int64_t k) { + assert(QK8_0 == 32); + assert(k % QK8_0 == 0); + const int nb = k / QK8_0; + + block_q8_0 * restrict y = vy; + +#if defined(__AVX2__) || defined(__AVX__) + for (int i = 0; i < nb; i++) { + // Load elements into 4 AVX vectors + __m256 v0 = _mm256_loadu_ps( x ); + __m256 v1 = _mm256_loadu_ps( x + 8 ); + __m256 v2 = _mm256_loadu_ps( x + 16 ); + __m256 v3 = _mm256_loadu_ps( x + 24 ); + x += 32; + + // Compute max(abs(e)) for the block + const __m256 signBit = _mm256_set1_ps( -0.0f ); + __m256 maxAbs = _mm256_andnot_ps( signBit, v0 ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v1 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v2 ) ); + maxAbs = _mm256_max_ps( maxAbs, _mm256_andnot_ps( signBit, v3 ) ); + + __m128 max4 = _mm_max_ps( _mm256_extractf128_ps( maxAbs, 1 ), _mm256_castps256_ps128( maxAbs ) ); + max4 = _mm_max_ps( max4, _mm_movehl_ps( max4, max4 ) ); + max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4 ) ); + const float maxScalar = _mm_cvtss_f32( max4 ); + + // Quantize these floats + const float d = maxScalar / 127.f; + + y[i].d = (GGML_FP32_TO_BF16(d)).bits; + + const float id = ( maxScalar != 0.0f ) ? 127.f / maxScalar : 0.0f; + const __m256 mul = _mm256_set1_ps( id ); + + // Apply the multiplier + v0 = _mm256_mul_ps( v0, mul ); + v1 = _mm256_mul_ps( v1, mul ); + v2 = _mm256_mul_ps( v2, mul ); + v3 = _mm256_mul_ps( v3, mul ); + + // Round to nearest integer + v0 = _mm256_round_ps( v0, _MM_ROUND_NEAREST ); + v1 = _mm256_round_ps( v1, _MM_ROUND_NEAREST ); + v2 = _mm256_round_ps( v2, _MM_ROUND_NEAREST ); + v3 = _mm256_round_ps( v3, _MM_ROUND_NEAREST ); + + // Convert floats to integers + __m256i i0 = _mm256_cvtps_epi32( v0 ); + __m256i i1 = _mm256_cvtps_epi32( v1 ); + __m256i i2 = _mm256_cvtps_epi32( v2 ); + __m256i i3 = _mm256_cvtps_epi32( v3 ); + +#if defined(__AVX2__) + // Convert int32 to int16 + i0 = _mm256_packs_epi32( i0, i1 ); // 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15 + i2 = _mm256_packs_epi32( i2, i3 ); // 16, 17, 18, 19, 24, 25, 26, 27, 20, 21, 22, 23, 28, 29, 30, 31 + // Convert int16 to int8 + i0 = _mm256_packs_epi16( i0, i2 ); // 0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27, 4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31 + + // We got our precious signed bytes, but the order is now wrong + // These AVX2 pack instructions process 16-byte pieces independently + // The following instruction is fixing the order + const __m256i perm = _mm256_setr_epi32( 0, 4, 1, 5, 2, 6, 3, 7 ); + i0 = _mm256_permutevar8x32_epi32( i0, perm ); + + _mm256_storeu_si256((__m256i *)y[i].qs, i0); +#else + // Since we don't have in AVX some necessary functions, + // we split the registers in half and call AVX2 analogs from SSE + __m128i ni0 = _mm256_castsi256_si128( i0 ); + __m128i ni1 = _mm256_extractf128_si256( i0, 1); + __m128i ni2 = _mm256_castsi256_si128( i1 ); + __m128i ni3 = _mm256_extractf128_si256( i1, 1); + __m128i ni4 = _mm256_castsi256_si128( i2 ); + __m128i ni5 = _mm256_extractf128_si256( i2, 1); + __m128i ni6 = _mm256_castsi256_si128( i3 ); + __m128i ni7 = _mm256_extractf128_si256( i3, 1); + + // Convert int32 to int16 + ni0 = _mm_packs_epi32( ni0, ni1 ); + ni2 = _mm_packs_epi32( ni2, ni3 ); + ni4 = _mm_packs_epi32( ni4, ni5 ); + ni6 = _mm_packs_epi32( ni6, 
ni7 ); + // Convert int16 to int8 + ni0 = _mm_packs_epi16( ni0, ni2 ); + ni4 = _mm_packs_epi16( ni4, ni6 ); + + _mm_storeu_si128((__m128i *)(y[i].qs + 0), ni0); + _mm_storeu_si128((__m128i *)(y[i].qs + 16), ni4); +#endif + } +#else + GGML_UNUSED(nb); + // scalar + quantize_row_q8_0_b16_ref(x, y, k); +#endif +} + + // reference implementation for deterministic creation of model files void quantize_row_q8_1_ref(const float * restrict x, block_q8_1 * restrict y, int64_t k) { assert(QK8_1 == 32); @@ -3217,7 +3343,7 @@ static void quantize_row_q4_0_b16_impl(const float * restrict x, block_q4_0 * re static_assert(QK4_0 == 32, "QK4_0 must be 32"); if (!quant_weights) { - quantize_row_q4_0_b16_reference(x, y, n_per_row); + quantize_row_q4_0_b16_ref(x, y, n_per_row); return; } @@ -3258,7 +3384,7 @@ size_t quantize_q4_0(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_q4_0_b16(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { if (!quant_weights) { - quantize_row_q4_0_b16_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q4_0_b16_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * ggml_row_size(GGML_TYPE_Q4_0_B16, n_per_row); } size_t row_size = ggml_row_size(GGML_TYPE_Q4_0_B16, n_per_row); @@ -3433,7 +3559,7 @@ size_t quantize_q8_0(const float * restrict src, void * restrict dst, int64_t nr size_t quantize_q8_0_b16(const float * restrict src, void * restrict dst, int64_t nrow, int64_t n_per_row, const float * quant_weights) { (void)quant_weights; // not used const size_t row_size = ggml_row_size(GGML_TYPE_Q8_0_B16, n_per_row); - quantize_row_q8_0_b16_reference(src, dst, (int64_t)nrow*n_per_row); + quantize_row_q8_0_b16_ref(src, dst, (int64_t)nrow*n_per_row); return nrow * row_size; } diff --git a/ggml/src/ggml-quants.h b/ggml/src/ggml-quants.h index 23d21c42d3833..69e11fa22565d 100644 --- a/ggml/src/ggml-quants.h +++ b/ggml/src/ggml-quants.h @@ -13,10 +13,12 @@ extern "C" { // Quantization void quantize_row_q4_0_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q4_0_b16_ref(const float * GGML_RESTRICT x, block_q4_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q4_1_ref(const float * GGML_RESTRICT x, block_q4_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_0_ref(const float * GGML_RESTRICT x, block_q5_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q5_1_ref(const float * GGML_RESTRICT x, block_q5_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_0_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); +void quantize_row_q8_0_b16_ref(const float * GGML_RESTRICT x, block_q8_0 * GGML_RESTRICT y, int64_t k); void quantize_row_q8_1_ref(const float * GGML_RESTRICT x, block_q8_1 * GGML_RESTRICT y, int64_t k); void quantize_row_q2_K_ref(const float * GGML_RESTRICT x, block_q2_K * GGML_RESTRICT y, int64_t k); diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 60c96c25878fb..b8f5075bb1c4a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1033,7 +1033,31 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .ncols = 8, .gemv = ggml_gemv_q4_0_8x8_q8_0, .gemm = ggml_gemm_q4_0_8x8_q8_0, - } + }, + [GGML_TYPE_Q4_0_B16] = { + .type_name = "q4_0_b16", + .blck_size = QK4_0, + .type_size = sizeof(block_q4_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q4_0_b16, + .from_float = quantize_row_q4_0_b16, + .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_b16_ref, + .vec_dot = 
ggml_vec_dot_q4_0_b16_q8_0_b16, + .vec_dot_type = GGML_TYPE_Q8_0_B16, + .nrows = 1, + }, + [GGML_TYPE_Q8_0_B16] = { + .type_name = "q8_0_b16", + .blck_size = QK8_0, + .type_size = sizeof(block_q8_0), + .is_quantized = true, + .to_float = (ggml_to_float_t) dequantize_row_q8_0_b16, + .from_float = quantize_row_q8_0_b16, + .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_b16_ref, + .vec_dot = ggml_vec_dot_q8_0_b16_q8_0_b16, + .vec_dot_type = GGML_TYPE_Q8_0_B16, + .nrows = 1, + }, }; // For internal test use diff --git a/include/llama.h b/include/llama.h index ce07f4fac8f10..19d014c99399c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -166,6 +166,8 @@ extern "C" { LLAMA_FTYPE_MOSTLY_Q4_0_4_4 = 33, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0_4_8 = 34, // except 1d tensors LLAMA_FTYPE_MOSTLY_Q4_0_8_8 = 35, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q4_0_B16 = 36, // except 1d tensors + LLAMA_FTYPE_MOSTLY_Q8_0_B16 = 37, // except 1d tensors LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file }; diff --git a/src/llama.cpp b/src/llama.cpp index 10c96924d87c4..5d1b877008b24 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3791,6 +3791,8 @@ struct llama_model_loader { case GGML_TYPE_Q4_0_4_4: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_4; break; case GGML_TYPE_Q4_0_4_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_4_8; break; case GGML_TYPE_Q4_0_8_8: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_8_8; break; + case GGML_TYPE_Q4_0_B16: ftype = LLAMA_FTYPE_MOSTLY_Q4_0_B16; break; + case GGML_TYPE_Q8_0_B16: ftype = LLAMA_FTYPE_MOSTLY_Q8_0_B16; break; default: { LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max)); From f7ce1322580f16a5442b0b50b8dddaca432c29ec Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Sun, 14 Jul 2024 23:19:28 -0700 Subject: [PATCH 07/12] Add changes to fix compiler issues --- ggml/src/ggml-quants.c | 6 +++--- ggml/src/ggml.c | 35 ++++------------------------------- 2 files changed, 7 insertions(+), 34 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index 66db1e46d067b..1c669d186eddc 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -15279,7 +15279,6 @@ static bool validate_bf16(ggml_half f, size_t i) { } \ } -<<<<<<< HEAD #define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \ const type * q = (const type *) (data); \ for (size_t i = 0; i < (nb); ++i) { \ @@ -15287,13 +15286,14 @@ static bool validate_bf16(ggml_half f, size_t i) { if (!validate_fp16(q[i].d[j], i)) { \ return false; \ } \ -======= + } \ + } + #define VALIDATE_ROW_DATA_D_B16_IMPL(type, data, nb) \ const type * q = (const type *) (data); \ for (size_t i = 0; i < (nb); ++i) { \ if (!validate_bf16((q[i].d), i)) { \ return false; \ ->>>>>>> ed837022 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) } \ } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b8f5075bb1c4a..506fc4ed5bfbe 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1041,7 +1041,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q4_0_b16, .from_float = quantize_row_q4_0_b16, - .from_float_reference = (ggml_from_float_t) quantize_row_q4_0_b16_ref, + .from_float_ref = (ggml_from_float_t) quantize_row_q4_0_b16_ref, .vec_dot = ggml_vec_dot_q4_0_b16_q8_0_b16, .vec_dot_type = GGML_TYPE_Q8_0_B16, .nrows = 1, @@ -1053,7 +1053,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = { .is_quantized = true, .to_float = (ggml_to_float_t) dequantize_row_q8_0_b16, .from_float = 
quantize_row_q8_0_b16, - .from_float_reference = (ggml_from_float_t) quantize_row_q8_0_b16_ref, + .from_float_ref = (ggml_from_float_t) quantize_row_q8_0_b16_ref, .vec_dot = ggml_vec_dot_q8_0_b16_q8_0_b16, .vec_dot_type = GGML_TYPE_Q8_0_B16, .nrows = 1, @@ -3334,15 +3334,16 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { enum ggml_type wtype = GGML_TYPE_COUNT; switch (ftype) { -<<<<<<< HEAD case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; + case GGML_FTYPE_MOSTLY_Q4_0_B16: wtype = GGML_TYPE_Q4_0_B16; break; case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; + case GGML_FTYPE_MOSTLY_Q8_0_B16: wtype = GGML_TYPE_Q8_0_B16; break; case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; @@ -3362,34 +3363,6 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { case GGML_FTYPE_MOSTLY_Q4_0_8_8: wtype = GGML_TYPE_Q4_0_8_8; break; case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; -======= - case GGML_FTYPE_ALL_F32: wtype = GGML_TYPE_F32; break; - case GGML_FTYPE_MOSTLY_F16: wtype = GGML_TYPE_F16; break; - case GGML_FTYPE_MOSTLY_BF16: wtype = GGML_TYPE_BF16; break; - case GGML_FTYPE_MOSTLY_Q4_0: wtype = GGML_TYPE_Q4_0; break; - case GGML_FTYPE_MOSTLY_Q4_0_B16: wtype = GGML_TYPE_Q4_0_B16; break; - case GGML_FTYPE_MOSTLY_Q4_1: wtype = GGML_TYPE_Q4_1; break; - case GGML_FTYPE_MOSTLY_Q5_0: wtype = GGML_TYPE_Q5_0; break; - case GGML_FTYPE_MOSTLY_Q5_1: wtype = GGML_TYPE_Q5_1; break; - case GGML_FTYPE_MOSTLY_Q8_0: wtype = GGML_TYPE_Q8_0; break; - case GGML_FTYPE_MOSTLY_Q8_0_B16: wtype = GGML_TYPE_Q8_0_B16; break; - case GGML_FTYPE_MOSTLY_Q2_K: wtype = GGML_TYPE_Q2_K; break; - case GGML_FTYPE_MOSTLY_Q3_K: wtype = GGML_TYPE_Q3_K; break; - case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break; - case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break; - case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break; - case GGML_FTYPE_MOSTLY_IQ2_XXS: wtype = GGML_TYPE_IQ2_XXS; break; - case GGML_FTYPE_MOSTLY_IQ2_XS: wtype = GGML_TYPE_IQ2_XS; break; - case GGML_FTYPE_MOSTLY_IQ3_XXS: wtype = GGML_TYPE_IQ3_XXS; break; - case GGML_FTYPE_MOSTLY_IQ1_S: wtype = GGML_TYPE_IQ1_S; break; - case GGML_FTYPE_MOSTLY_IQ1_M: wtype = GGML_TYPE_IQ1_M; break; - case GGML_FTYPE_MOSTLY_IQ4_NL: wtype = GGML_TYPE_IQ4_NL; break; - case GGML_FTYPE_MOSTLY_IQ4_XS: wtype = GGML_TYPE_IQ4_XS; break; - case GGML_FTYPE_MOSTLY_IQ3_S: wtype = GGML_TYPE_IQ3_S; break; - case GGML_FTYPE_MOSTLY_IQ2_S: wtype = GGML_TYPE_IQ2_S; break; - case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break; - case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break; ->>>>>>> ed837022 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) } GGML_ASSERT(wtype != GGML_TYPE_COUNT); From ef693e99d73dcc4c6ddb1de72d9b80f7cec01a42 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Mon, 12 Aug 2024 05:58:11 -0700 Subject: [PATCH 08/12] Add the new data types across files --- examples/quantize/quantize.cpp | 2 ++ src/llama.cpp | 2 ++ 2 files changed, 4 
insertions(+) diff --git a/examples/quantize/quantize.cpp b/examples/quantize/quantize.cpp index 7312309aeef98..ff477e4f4f18f 100644 --- a/examples/quantize/quantize.cpp +++ b/examples/quantize/quantize.cpp @@ -17,6 +17,7 @@ struct quant_option { static const std::vector QUANT_OPTIONS = { { "Q4_0", LLAMA_FTYPE_MOSTLY_Q4_0, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, + { "Q4_0_B16", LLAMA_FTYPE_MOSTLY_Q4_0_B16, " 3.56G, 5.9624 +/- 0.03348 ppl @ LLaMA-v2-7B", }, { "Q4_1", LLAMA_FTYPE_MOSTLY_Q4_1, " 4.78G, +0.4511 ppl @ Llama-3-8B", }, { "Q5_0", LLAMA_FTYPE_MOSTLY_Q5_0, " 5.21G, +0.1316 ppl @ Llama-3-8B", }, { "Q5_1", LLAMA_FTYPE_MOSTLY_Q5_1, " 5.65G, +0.1062 ppl @ Llama-3-8B", }, @@ -46,6 +47,7 @@ static const std::vector QUANT_OPTIONS = { { "Q5_K_M", LLAMA_FTYPE_MOSTLY_Q5_K_M, " 5.33G, +0.0569 ppl @ Llama-3-8B", }, { "Q6_K", LLAMA_FTYPE_MOSTLY_Q6_K, " 6.14G, +0.0217 ppl @ Llama-3-8B", }, { "Q8_0", LLAMA_FTYPE_MOSTLY_Q8_0, " 7.96G, +0.0026 ppl @ Llama-3-8B", }, + { "Q8_0_B16", LLAMA_FTYPE_MOSTLY_Q8_0_B16, " 6.70G, 5.8011 +/- 0.03239 ppl @ LLaMA-v1-7B", }, { "Q4_0_4_4", LLAMA_FTYPE_MOSTLY_Q4_0_4_4, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_4_8", LLAMA_FTYPE_MOSTLY_Q4_0_4_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, { "Q4_0_8_8", LLAMA_FTYPE_MOSTLY_Q4_0_8_8, " 4.34G, +0.4685 ppl @ Llama-3-8B", }, diff --git a/src/llama.cpp b/src/llama.cpp index 5d1b877008b24..07523ed97f281 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4460,10 +4460,12 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_F16: return "F16"; case LLAMA_FTYPE_MOSTLY_BF16: return "BF16"; case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0"; + case LLAMA_FTYPE_MOSTLY_Q4_0_B16: return "Q4_0_B16"; case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1"; case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; + case LLAMA_FTYPE_MOSTLY_Q8_0_B16: return "Q4_0_B16"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; From 4a2f703fbb22ed2d6ce1ba3b7c3850062e4c2fbd Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Thu, 1 Aug 2024 23:57:12 -0700 Subject: [PATCH 09/12] Add changes to use union data type for better conversion to strong type - Based on 5f2e011e2eed2f685521c707b3e74280fcb81dd3 from llamafile --- ggml/include/ggml.h | 1 - ggml/src/ggml-impl.h | 9 +++++++++ ggml/src/ggml.c | 6 ------ 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b55fd1b8b6f4f..1701b15388cad 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -346,7 +346,6 @@ extern "C" { // google brain half-precision bfloat16 typedef struct { uint16_t bits; } ggml_bf16_t; - GGML_API ggml_bf16_t ggml_make_bf16(uint16_t val); GGML_API ggml_bf16_t ggml_fp32_to_bf16(float); GGML_API float ggml_bf16_to_fp32(ggml_bf16_t); // consider just doing << 16 GGML_API void ggml_bf16_to_fp32_row(const ggml_bf16_t *, float *, int64_t); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index 3f7963f660281..c2ead27168d85 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -102,6 +102,15 @@ static inline ggml_bf16_t ggml_compute_fp32_to_bf16(float s) { return h; } +static inline ggml_bf16_t ggml_make_bf16(uint16_t h) { + union { + ggml_bf16_t f; + uint16_t i; + } u; + u.i = h; + return u.f; +} + #define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x) #define GGML_BF16_TO_FP32(x) 
ggml_compute_bf16_to_fp32(x) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 506fc4ed5bfbe..642ca73338f59 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -428,12 +428,6 @@ float ggml_bf16_to_fp32(ggml_bf16_t x) { return GGML_BF16_TO_FP32(x); // it just left shifts } -ggml_bf16_t ggml_make_bf16(uint16_t x) { - ggml_bf16_t bf16_value; - bf16_value.bits = x; - return bf16_value; -} - ggml_bf16_t ggml_fp32_to_bf16(float x) { #define ggml_fp32_to_bf16 do_not_use__ggml_fp32_to_bf16__in_ggml return GGML_FP32_TO_BF16(x); From 7927655e42525a7bcce9dc18388b1ece3f56f587 Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Tue, 6 Aug 2024 01:36:54 -0700 Subject: [PATCH 10/12] Update the type name in llama.cpp --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index 07523ed97f281..8e03f38d893b0 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -4465,7 +4465,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) { case LLAMA_FTYPE_MOSTLY_Q5_0: return "Q5_0"; case LLAMA_FTYPE_MOSTLY_Q5_1: return "Q5_1"; case LLAMA_FTYPE_MOSTLY_Q8_0: return "Q8_0"; - case LLAMA_FTYPE_MOSTLY_Q8_0_B16: return "Q4_0_B16"; + case LLAMA_FTYPE_MOSTLY_Q8_0_B16: return "Q8_0_B16"; case LLAMA_FTYPE_MOSTLY_Q2_K: return "Q2_K - Medium"; case LLAMA_FTYPE_MOSTLY_Q2_K_S: return "Q2_K - Small"; case LLAMA_FTYPE_MOSTLY_Q3_K_S: return "Q3_K - Small"; From 43c7be57c10a04a18e6122312fcd77070239980e Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Mon, 12 Aug 2024 06:02:16 -0700 Subject: [PATCH 11/12] Add the BF16 delta data types in constants.py --- gguf-py/gguf/constants.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 741a31e455667..ba449f461edcc 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1166,6 +1166,8 @@ class GGMLQuantizationType(IntEnum): Q4_0_4_4 = 31 Q4_0_4_8 = 32 Q4_0_8_8 = 33 + Q4_0_B16 = 34 + Q8_0_B16 = 35 # TODO: add GGMLFileType from ggml_ftype in ggml.h From 806c5a4e5b586b1b086bd86bb301385017a1ca4b Mon Sep 17 00:00:00 2001 From: Srihari-mcw Date: Tue, 20 Aug 2024 04:02:17 -0700 Subject: [PATCH 12/12] Remove additional snippets for CI/CD issues with python constants.py script --- gguf-py/gguf/constants.py | 36 ++---------------------------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index ba449f461edcc..de8e4921148c0 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -1256,7 +1256,6 @@ def get_type(val: Any) -> GGUFValueType: # Items here are (block size, type size) QK_K = 256 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = { -<<<<<<< HEAD GGMLQuantizationType.F32: (1, 4), GGMLQuantizationType.F16: (1, 2), GGMLQuantizationType.Q4_0: (32, 2 + 16), @@ -1289,39 +1288,8 @@ def get_type(val: Any) -> GGUFValueType: GGMLQuantizationType.Q4_0_4_4:(32, 2 + 16), GGMLQuantizationType.Q4_0_4_8:(32, 2 + 16), GGMLQuantizationType.Q4_0_8_8:(32, 2 + 16), -======= - GGMLQuantizationType.F32: (1, 4), - GGMLQuantizationType.F16: (1, 2), - GGMLQuantizationType.Q4_0: (32, 2 + 16), - GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16), - GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16), - GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16), - GGMLQuantizationType.Q8_0: (32, 2 + 32), - GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32), - GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4), - GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12), - 
GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12), - GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12), - GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8), - GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4), - GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32), - GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8), - GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16), - GGMLQuantizationType.IQ4_NL: (32, 2 + 16), - GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4), - GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16), - GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64), - GGMLQuantizationType.I8: (1, 1), - GGMLQuantizationType.I16: (1, 2), - GGMLQuantizationType.I32: (1, 4), - GGMLQuantizationType.I64: (1, 8), - GGMLQuantizationType.F64: (1, 8), - GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32), - GGMLQuantizationType.BF16: (1, 2), - GGMLQuantizationType.Q4_0_B16: (32, 2 + 16), - GGMLQuantizationType.Q8_0_B16: (32, 2 + 32), ->>>>>>> 2f13a1e6 (Introduce Q4_0 and Q8_0 quantizations with BF16 delta values) + GGMLQuantizationType.Q4_0_B16:(32, 2 + 16), + GGMLQuantizationType.Q8_0_B16:(32, 2 + 32), }
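
The common idea behind the Q4_0_B16 and Q8_0_B16 paths in the patches above is that the per-block delta d is stored as the raw 16 bits of a bfloat16 and recovered by shifting those bits back into the high half of an IEEE-754 float. The standalone C sketch below illustrates that round trip; it is illustrative only and not the ggml implementation: the helper names fp32_to_bf16_bits and bf16_bits_to_fp32 are invented for this example, and the encoder rounds to nearest-even but omits the NaN special-casing performed by ggml's GGML_FP32_TO_BF16.

#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Encode an fp32 scale as raw bf16 bits (round-to-nearest-even, no NaN handling).
static uint16_t fp32_to_bf16_bits(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof(u));
    u += 0x7fff + ((u >> 16) & 1); // round to nearest even before truncating the low 16 bits
    return (uint16_t)(u >> 16);
}

// Decode raw bf16 bits back to fp32: a bf16 is just the high half of an fp32.
static float bf16_bits_to_fp32(uint16_t bits) {
    uint32_t u = (uint32_t)bits << 16;
    float f;
    memcpy(&f, &u, sizeof(f));
    return f;
}

int main(void) {
    const float amax = 3.14159f;                          // block-wise absolute maximum
    const float d    = amax / 127.0f;                     // Q8_0_B16-style per-block delta
    const uint16_t stored   = fp32_to_bf16_bits(d);       // what y[i].d would hold
    const float    restored = bf16_bits_to_fp32(stored);  // dequantization side
    printf("d = %.8f, stored = 0x%04x, restored = %.8f\n", d, stored, restored);
    return 0;
}

The restored scale differs from d only by the bf16 rounding error (bf16 keeps 8 mantissa bits), which is the precision trade-off the *_B16 types accept in exchange for a dequantization that is a plain 16-bit shift.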