From 1f0fea70fb761d10e2264cbdcf4852ed32706c89 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 Aug 2024 10:43:42 -0400 Subject: [PATCH 01/92] llama : initial Mamba-2 support --- convert_hf_to_gguf.py | 67 ++++++++ ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 193 ++++++++++++++-------- gguf-py/gguf/constants.py | 19 +++ gguf-py/gguf/gguf_writer.py | 3 + gguf-py/gguf/tensor_mapping.py | 6 +- src/llama.cpp | 291 +++++++++++++++++++++++++++++++-- 7 files changed, 495 insertions(+), 87 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 108c822cff5d2..0ac64574a3043 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2788,6 +2788,73 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@Model.register("Mamba2ForCausalLM") +class Mamba2Model(Model): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + elif (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index b8a21a2ccc3f0..59e0022dd4286 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1787,7 +1787,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * D); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index d63c917a5705a..6668209081b6c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7270,32 +7270,48 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * D) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; + + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); + GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(D->ne[0] == n_head); + GGML_ASSERT(ggml_is_vector(D)); + + if (ggml_is_vector(A)) { + // Mamba-2 + GGML_ASSERT(A->ne[0] == n_head); + } else { + // Mamba-1 + GGML_ASSERT(A->ne[0] == d_state); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + } } bool is_node = false; @@ -7316,6 +7332,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = D; return result; } @@ -15840,20 +15857,25 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s - const struct ggml_tensor * src1 = dst->src[1]; // x - const struct ggml_tensor * src2 = dst->src[2]; // dt - const struct ggml_tensor * src3 = dst->src[3]; // A - const struct ggml_tensor * src4 = dst->src[4]; // B - const struct ggml_tensor * src5 = dst->src[5]; // C + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src0->ne[3]; // number of sequences in the batch + + const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); @@ -15862,51 +15884,86 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); - - // rows per thread - const int dr = (nr + nth - 1)/nth; - - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; - - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = (float *) ((char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = (float *) ((char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations + GGML_ASSERT(src6->nb[0] == sizeof(float)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); + + // heads per thread + const int dh = (nh + nth - 1)/nth; + + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); + + for (int i3 = 0; i3 < ns; ++i3) { + for (int i2 = 0; i2 < nt; ++i2) { + const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + const float * D = (const float *) ((const char *) src6->data); // {nh} + float * y = (float *) ((char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} + float * s = (float *) ((char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + + // use the output as the source when it's not the first token-wise iteration if (i2 > 0) { s0 = s; } - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + if (ggml_is_vector(src3)) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); + + // TODO: SIMD implementation + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int i = i1 + h*nr; + const float x_dt = x[i] * dt_soft_plus; + float sumf = 0.0f; + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int ii = i0 + i*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[ii] = state; + } + y[i] = sumf + x[i] * D[h]; + } } - y[i1] = sumf; } } } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index b55effa9907b1..32a2fb20f84b9 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -130,6 +130,7 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class Tokenizer: @@ -208,6 +209,7 @@ class MODEL_ARCH(IntEnum): GEMMA2 = auto() STARCODER2 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() DBRX = auto() @@ -269,6 +271,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() ATTN_Q_A = auto() ATTN_Q_B = auto() @@ -338,6 +341,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.GEMMA2: "gemma2", MODEL_ARCH.STARCODER2: "starcoder2", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.DBRX: "dbrx", @@ -399,6 +403,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.ATTN_Q_A: "blk.{bid}.attn_q_a", MODEL_TENSOR.ATTN_Q_B: "blk.{bid}.attn_q_b", @@ -869,6 +874,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1373,6 +1391,7 @@ def get_type(val: Any) -> GGUFValueType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index af3b98c679b0b..ea788918dbf2c 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -730,6 +730,9 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index a4f185c0658a3..8593a80a5ab8f 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -396,7 +396,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( @@ -429,6 +429,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/src/llama.cpp b/src/llama.cpp index bd7f1508b2644..5be0ef7a2ac7a 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -198,6 +198,7 @@ enum llm_arch { LLM_ARCH_GEMMA2, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_DBRX, @@ -245,6 +246,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA2, "gemma2" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_DBRX, "dbrx" }, @@ -328,6 +330,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_TOKENIZER_MODEL, @@ -427,7 +430,8 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -517,6 +521,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_ATTN_Q_A, LLM_TENSOR_ATTN_Q_B, @@ -1068,6 +1073,22 @@ static const std::map> LLM_TENSOR_NA { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -2239,6 +2260,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -2289,6 +2311,7 @@ struct llama_hparams { if (this->ssm_d_inner != other.ssm_d_inner) return true; if (this->ssm_d_state != other.ssm_d_state) return true; if (this->ssm_dt_rank != other.ssm_dt_rank) return true; + if (this->ssm_n_group != other.ssm_n_group) return true; if (this->ssm_dt_b_c_rms != other.ssm_dt_b_c_rms) return true; if (this->dec_start_token_id != other.dec_start_token_id) return true; @@ -2357,7 +2380,7 @@ struct llama_hparams { // corresponds to Mamba's conv_states size // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t n_embd_v_s() const { // dimension of the recurrent state embeddings @@ -2419,6 +2442,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm; struct ggml_tensor * attn_norm_cross; struct ggml_tensor * attn_norm_enc; + struct ggml_tensor * ssm_norm; // attention struct ggml_tensor * wq; @@ -5573,6 +5597,38 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: model.type = e_model::MODEL_SMALL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: model.type = e_model::MODEL_MEDIUM; break; + case 1536: model.type = e_model::MODEL_LARGE; break; + case 2048: model.type = e_model::MODEL_XL; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -6404,6 +6460,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } @@ -7639,7 +7696,7 @@ static bool llm_load_tensors( layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2*d_inner}); - layer.ssm_conv1d = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner}); layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner}); layer.ssm_x = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_X, "weight", i), {d_inner, dt_rank + 2*d_state}); @@ -7648,9 +7705,61 @@ static bool llm_load_tensors( layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {d_inner}); // no "weight" suffix for these - layer.ssm_a = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}); layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {d_inner}); + // out_proj + layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = n_embd / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + model.tok_embd = ml.create_tensor(ctx_input, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}); + + // output + { + model.output_norm = ml.create_tensor(ctx_output, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}); + + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (model.output == NULL) { + model.output = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + ggml_context * ctx_layer = ctx_for_layer(i); + ggml_context * ctx_split = ctx_for_layer_split(i); + + auto & layer = model.layers[i]; + + // norm + layer.attn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}); + + layer.ssm_in = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}); + + layer.ssm_conv1d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}); + layer.ssm_conv1d_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}); + + layer.ssm_dt_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}); + + // no "weight" suffix for these + layer.ssm_a = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_A, i), {n_head}); + layer.ssm_d = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_D, i), {n_head}); + + layer.ssm_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}); + // out_proj layer.ssm_out = ml.create_tensor(ctx_split, tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}); } @@ -9041,6 +9150,8 @@ static struct ggml_tensor * llm_build_mamba( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = batch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -9064,7 +9175,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, graph, ssm_states_all, state_copy, state_mask, hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_3d(ctx, ssm, d_state, d_inner, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9113,8 +9224,8 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * x_db = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_x, x); // split struct ggml_tensor * dt = ggml_view_3d(ctx, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - struct ggml_tensor * B = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - struct ggml_tensor * C = ggml_view_3d(ctx, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + struct ggml_tensor * B = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + struct ggml_tensor * C = ggml_view_4d(ctx, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -9127,23 +9238,23 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); // store last states ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} - y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -9157,6 +9268,136 @@ static struct ggml_tensor * llm_build_mamba( return cur; } +static struct ggml_tensor * llm_build_mamba2( + struct ggml_context * ctx, + struct llama_context & lctx, + const llama_ubatch & batch, + struct ggml_cgraph * graph, + struct ggml_tensor * cur, + struct ggml_tensor * state_copy, + struct ggml_tensor * state_mask, + int32_t kv_head, + int32_t n_kv, + const llm_build_cb & cb, + int il) { + const llama_model & model = lctx.model; + const llama_hparams & hparams = model.hparams; + const llama_kv_cache & kv = lctx.kv_self; + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = batch.n_seqs; + + const int64_t n_seq_tokens = batch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(batch.equal_seqs); + GGML_ASSERT(batch.n_tokens == n_seq_tokens * n_seqs); + + struct ggml_tensor * conv_states_all = kv.k_l[il]; + struct ggml_tensor * ssm_states_all = kv.v_l[il]; + + // (ab)using the KV cache to store the states + struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, + graph, conv_states_all, state_copy, state_mask, + hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); + conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, + graph, ssm_states_all, state_copy, state_mask, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); + + // split the above in three + struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + struct ggml_tensor * conv_x = ggml_concat(ctx, conv, ggml_transpose(ctx, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + struct ggml_tensor * last_conv = ggml_view_3d(ctx, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(graph, + ggml_cpy(ctx, last_conv, + ggml_view_1d(ctx, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + struct ggml_tensor * x = ggml_view_4d(ctx, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + struct ggml_tensor * B = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + struct ggml_tensor * C = ggml_view_4d(ctx, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + + // store last states + ggml_build_forward_expand(graph, + ggml_cpy(ctx, + ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); + + // grouped RMS norm + y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = llm_build_norm(ctx, y, hparams, + model.layers[il].ssm_norm, NULL, + LLM_NORM_RMS, cb, il); + y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; +} + struct llm_build_context { const llama_model & model; llama_context & lctx; @@ -12788,7 +13029,7 @@ struct llm_build_context { return gf; } - struct ggml_cgraph * build_mamba() { + struct ggml_cgraph * build_mamba(int32_t version = 1) { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, llama_model_max_nodes(model), false); struct ggml_tensor * cur; @@ -12807,9 +13048,19 @@ struct llm_build_context { LLM_NORM_RMS, cb, il); cb(cur, "attn_norm", il); - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + switch (version) { + case 2: + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + case 1: + default: + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, + state_copy, state_mask, + kv_head, n_kv, cb, il); + break; + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -14858,7 +15109,11 @@ static struct ggml_cgraph * llama_build_graph( } break; case LLM_ARCH_MAMBA: { - result = llm.build_mamba(); + result = llm.build_mamba(/* version */ 1); + } break; + case LLM_ARCH_MAMBA2: + { + result = llm.build_mamba(/* version */ 2); } break; case LLM_ARCH_XVERSE: { @@ -17954,6 +18209,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -18125,6 +18381,7 @@ llama_token llama_model_decoder_start_token(const struct llama_model * model) { bool llama_model_is_recurrent(const struct llama_model * model) { switch (model->arch) { + case LLM_ARCH_MAMBA2: case LLM_ARCH_MAMBA: return true; default: return false; } From dceff23faec99945d3161d24ea209a0c433546db Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 21:49:39 -0400 Subject: [PATCH 02/92] ggml : SIMD ggml_ssm_scan for Mamba-2 * ggml : improve ggml_mul speed when masking recurrent states --- ggml/src/ggml.c | 95 +++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 80 insertions(+), 15 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 6668209081b6c..f8e708088b357 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10207,7 +10207,37 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (nb10 == sizeof(float)) { + if (ne00 > 1 && ne10 == 1) { + // fast broadcast path + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + const float scale = src1_ptr[0]; + + if (scale == 0.0f) { + // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, + // but it is useful when resetting the state of recurrent models. + memset((char *)dst->data + ir*nb1, 0, nb1); + } else { + if (dst->data != src0->data) { + // src0 is same shape as dst => same indices + memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + } + if (scale != 1.0f) { + ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); + } + } + } + } else if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); @@ -15919,23 +15949,56 @@ static void ggml_compute_forward_ssm_scan_f32( const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; const float dA = expf(dt_soft_plus * A[h]); - // TODO: SIMD implementation // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; +#if defined(GGML_SIMD) + const int np = (nc & ~(GGML_F32_STEP - 1)); + + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; + + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); +#else + const int np = 0; +#endif // d_state - for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * dA) + (B[ig] * x_dt); + const float state = (s0[i] * dA) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } else { @@ -15948,20 +16011,22 @@ static void ggml_compute_forward_ssm_scan_f32( // dim for (int i1 = 0; i1 < nr; ++i1) { - const int i = i1 + h*nr; - const float x_dt = x[i] * dt_soft_plus; + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. // d_state for (int i0 = 0; i0 < nc; ++i0) { - const int ii = i0 + i*nc; + const int i = i0 + ii*nc; const int ig = i0 + (h & (ng - 1))*nc; // state = prev_state * dA + dB * x - const float state = (s0[ii] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); // y = rowwise_dotprod(state, C) sumf += state * C[ig]; - s[ii] = state; + s[i] = state; } - y[i] = sumf + x[i] * D[h]; + y[ii] = sumf + x[ii] * D[h]; } } } From 2bfe9de6d3a3598d4b778f9b144bb8ac33c2797b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sun, 18 Aug 2024 22:43:39 -0400 Subject: [PATCH 03/92] llama : support running Mamba-Codestral-7B-v0.1 --- convert_hf_to_gguf.py | 4 ++++ src/llama.cpp | 2 +- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 0ac64574a3043..a5bdd5def2029 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2843,6 +2843,10 @@ def set_gguf_parameters(self): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: del bid # unused + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + if name.endswith(".dt_bias"): name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" diff --git a/src/llama.cpp b/src/llama.cpp index 5be0ef7a2ac7a..fd80361bd7605 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9383,7 +9383,7 @@ static struct ggml_tensor * llm_build_mamba2( // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); y = llm_build_norm(ctx, y, hparams, - model.layers[il].ssm_norm, NULL, + ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); From aff96920f972d8e042dfdef6dc08644cd8df0234 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 16:28:07 -0400 Subject: [PATCH 04/92] llama : fix Mamba-2 conv state saving * ggml : make the ggml_mul fast broadcast path more consistently formatted --- ggml/src/ggml.c | 4 ++-- src/llama.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f8e708088b357..415fa6901304a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10226,11 +10226,11 @@ static void ggml_compute_forward_mul_f32( if (scale == 0.0f) { // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, // but it is useful when resetting the state of recurrent models. - memset((char *)dst->data + ir*nb1, 0, nb1); + memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); } else { if (dst->data != src0->data) { // src0 is same shape as dst => same indices - memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float)); + memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); } if (scale != 1.0f) { ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); diff --git a/src/llama.cpp b/src/llama.cpp index fd80361bd7605..03f93164a89e8 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2( ggml_cpy(ctx, last_conv, ggml_view_1d(ctx, conv_states_all, (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), - kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all)))); + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); // 1D convolution // The equivalent is to make a self-overlapping view of conv_x From e04910dc48966f1cbc7309d12b8e1b55bdd33df2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 21 Aug 2024 23:06:22 -0400 Subject: [PATCH 05/92] llama : remove unused variable --- src/llama.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index 03f93164a89e8..dda3d51b017d6 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7718,7 +7718,6 @@ static bool llm_load_tensors( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = n_embd / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; @@ -9287,7 +9286,7 @@ static struct ggml_tensor * llm_build_mamba2( const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t n_head = hparams.ssm_dt_rank; - const int64_t head_dim = d_inner / n_head; // FIXME + const int64_t head_dim = d_inner / n_head; const int64_t n_group = hparams.ssm_n_group; const int64_t n_seqs = batch.n_seqs; From fa358e707132ace9012cb90880abe86fd32464a6 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 01:13:43 -0400 Subject: [PATCH 06/92] llama : add missing break --- src/llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama.cpp b/src/llama.cpp index dda3d51b017d6..5b6b6707a1c95 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -5628,7 +5628,7 @@ static void llm_load_hparams( } break; default: model.type = e_model::MODEL_UNKNOWN; } - } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); From 38913dc8ddd1e119df0e0cfcacfb260b9b1f5c02 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 22 Aug 2024 14:31:12 -0400 Subject: [PATCH 07/92] convert_hf : prefer SentencePiece tokenizer for Mamba-2 when present The tokenzier.json of Mamba-Codestral-7B-v0.1 otherwise requires workarounds to work correctly. --- convert_hf_to_gguf.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index a5bdd5def2029..4851926b7b98f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -2801,13 +2801,13 @@ def set_vocab(self): vocab_size = -(vocab_size // -pad_vocab) * pad_vocab self.hparams["vocab_size"] = vocab_size - if (self.dir_model / "tokenizer.json").is_file(): - self._set_vocab_gpt2() - elif (self.dir_model / "tokenizer.model").is_file(): + if (self.dir_model / "tokenizer.model").is_file(): self._set_vocab_sentencepiece() elif (self.dir_model / "tokenizer.model.v3").is_file(): # mamba-codestral raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() else: # Use the GPT-NeoX tokenizer when no tokenizer files are present self._set_vocab_builtin("gpt-neox", vocab_size) From 273e7a495ad8c93bb9ba8123c1a3de3c68f93cf9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 30 Sep 2024 15:52:42 -0400 Subject: [PATCH 08/92] llama : avoid redundant state copy for Mamba 1 and 2 --- ggml/include/ggml.h | 3 +- ggml/src/ggml.c | 50 ++++++------ src/llama.cpp | 154 +++++++++++++++++-------------------- tests/test-backend-ops.cpp | 54 ++++++++++--- 4 files changed, 142 insertions(+), 119 deletions(-) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fec6798ff6d06..1fc53bebebf30 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1833,7 +1833,8 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D); + struct ggml_tensor * D, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 12e4f26942f86..1c4c393e55d06 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7598,7 +7598,8 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D) { + struct ggml_tensor * D, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); @@ -7609,6 +7610,7 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; @@ -7623,21 +7625,19 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(ggml_is_3d(dt)); GGML_ASSERT(s->ne[1] == head_dim); GGML_ASSERT(s->ne[2] == n_head); - GGML_ASSERT(s->ne[3] == n_seqs); GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); GGML_ASSERT(D->ne[0] == n_head); GGML_ASSERT(ggml_is_vector(D)); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); - if (ggml_is_vector(A)) { - // Mamba-2 - GGML_ASSERT(A->ne[0] == n_head); - } else { - // Mamba-1 + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == n_head); - GGML_ASSERT(ggml_is_matrix(A)); } } @@ -7649,7 +7649,7 @@ struct ggml_tensor * ggml_ssm_scan( } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; @@ -7660,6 +7660,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[4] = B; result->src[5] = C; result->src[6] = D; + result->src[7] = ids; return result; } @@ -16635,13 +16636,14 @@ static void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs} + const struct ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} const struct ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} const struct ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} - const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {n_head} + const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} + const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16651,11 +16653,12 @@ static void ggml_compute_forward_ssm_scan_f32( const int64_t nh = src1->ne[1]; // n_head const int64_t ng = src4->ne[1]; const int64_t nt = src1->ne[2]; // number of tokens per sequence - const int64_t ns = src0->ne[3]; // number of sequences in the batch + const int64_t ns = src1->ne[3]; // number of sequences in the batch - const int64_t s_off = ggml_element_size(src1) * ggml_nelements(src1); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); @@ -16663,6 +16666,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); GGML_ASSERT(src6->nb[0] == sizeof(float)); + GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16673,22 +16677,22 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); + const int32_t * ids = (const int32_t *) src7->data; + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + for (int i2 = 0; i2 < nt; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + i3*(src0->nb[3])); // {d_state, dim, nh, ns} const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} - const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {nh} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} - float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} - - // use the output as the source when it's not the first token-wise iteration - if (i2 > 0) { s0 = s; } - if (ggml_is_vector(src3)) { + if (src3->ne[0] == 1) { // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop // n_head @@ -16778,6 +16782,8 @@ static void ggml_compute_forward_ssm_scan_f32( } } } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } } } diff --git a/src/llama.cpp b/src/llama.cpp index c11472112f8fb..3e1f8755ffb85 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -2801,6 +2801,10 @@ struct llama_kv_cache { // computed before each graph build uint32_t n = 0; + // first zero-ed state + // NOTE: only used by recurrent models + int32_t rs_z = -1; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; @@ -3381,8 +3385,6 @@ struct llama_context { struct ggml_tensor * inp_mean; // F32 [n_batch, n_batch] struct ggml_tensor * inp_cls; // I32 [n_batch] struct ggml_tensor * inp_s_copy; // I32 [kv_size] - struct ggml_tensor * inp_s_mask; // F32 [1, n_kv] - struct ggml_tensor * inp_s_seq; // I32 [n_kv, n_batch] struct ggml_tensor * inp_pos_bucket; // I32 [n_batch|n_kv, n_batch] struct ggml_tensor * inp_embd_enc; // F32 [n_embd, n_outputs_enc] struct ggml_tensor * inp_KQ_mask_cross; // F32 [n_outputs_enc, n_batch] @@ -3813,6 +3815,15 @@ static bool llama_kv_cache_find_slot( } } + // Find first to-be-cleared cell + cache.rs_z = -1; + for (int i = min; i <= max; ++i) { + if (cache.cells[i].src == -1) { + cache.rs_z = i; + break; + } + } + // allow getting the range of used cells, from head to head + n cache.head = min; cache.n = max - min + 1; @@ -9569,36 +9580,42 @@ static struct ggml_tensor * llm_build_kv( return cur; } -static struct ggml_tensor * llm_build_copy_mask_state( +static struct ggml_tensor * llm_build_rs( struct ggml_context * ctx, struct ggml_cgraph * graph, struct ggml_tensor * s, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t n_state, int32_t kv_size, int32_t kv_head, int32_t n_kv, - int32_t n_seqs) { + int32_t n_seqs, + bool avoid_copies = false) { struct ggml_tensor * states = ggml_reshape_2d(ctx, s, n_state, kv_size); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx, states, state_copy); - - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx, states, state_mask); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + struct ggml_tensor * state_zero = ggml_view_1d(ctx, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(graph, ggml_scale_inplace(ctx, state_zero, 0)); // copy states which won't be changed further (between n_seqs and n_kv) + struct ggml_tensor * states_extra = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(graph, ggml_cpy(ctx, - ggml_view_1d(ctx, states, n_state*(n_kv - n_seqs), n_seqs*n_state*ggml_element_size(states)), + states_extra, ggml_view_1d(ctx, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); - // the part of the states that will be used and modified - return ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx, states, ggml_view_1d(ctx, state_copy, n_seqs, 0)); + // the part of the states that will be used and modified + states = ggml_view_2d(ctx, states, n_state, n_seqs, states->nb[1], 0); + } + + return states; } // TODO: split @@ -9609,7 +9626,7 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9639,14 +9656,14 @@ static struct ggml_tensor * llm_build_mamba( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9711,10 +9728,11 @@ static struct ggml_tensor * llm_build_mamba( x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -9746,7 +9764,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_cgraph * graph, struct ggml_tensor * cur, struct ggml_tensor * state_copy, - struct ggml_tensor * state_mask, + int32_t rs_zero, int32_t kv_head, int32_t n_kv, const llm_build_cb & cb, @@ -9772,14 +9790,14 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * ssm_states_all = kv.v_l[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * conv = llm_build_copy_mask_state(ctx, - graph, conv_states_all, state_copy, state_mask, + struct ggml_tensor * conv = llm_build_rs(ctx, + graph, conv_states_all, state_copy, rs_zero, hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs); conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - struct ggml_tensor * ssm = llm_build_copy_mask_state(ctx, - graph, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs); - ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, n_seqs); + struct ggml_tensor * ssm = llm_build_rs(ctx, + graph, ssm_states_all, state_copy, rs_zero, + hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true); + ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9835,9 +9853,12 @@ static struct ggml_tensor * llm_build_mamba2( // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); + // Use the same shape semantics for A as Mamba-1 + struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10069,6 +10090,7 @@ struct llm_build_context { const int32_t n_outputs; const int32_t n_outputs_enc; const int32_t kv_head; // index of where we store new KV data in the cache + const int32_t rs_zero; // the first zero-ed recurrent state const int32_t n_ctx_orig; const bool flash_attn; @@ -10119,6 +10141,7 @@ struct llm_build_context { n_outputs (worst_case ? n_tokens : lctx.n_outputs), n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd), kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head), + rs_zero (kv_self.rs_z), n_ctx_orig (cparams.n_ctx_orig_yarn), flash_attn (cparams.flash_attn), pooling_type (cparams.pooling_type), @@ -10147,8 +10170,6 @@ struct llm_build_context { lctx.inp_mean = nullptr; lctx.inp_cls = nullptr; lctx.inp_s_copy = nullptr; - lctx.inp_s_mask = nullptr; - lctx.inp_s_seq = nullptr; lctx.inp_pos_bucket = nullptr; lctx.inp_embd_enc = nullptr; lctx.inp_KQ_mask_cross = nullptr; @@ -10332,13 +10353,6 @@ struct llm_build_context { return lctx.inp_s_copy; } - struct ggml_tensor * build_inp_s_mask() { - lctx.inp_s_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - cb(lctx.inp_s_mask, "inp_s_mask", -1); - ggml_set_input(lctx.inp_s_mask); - return lctx.inp_s_mask; - } - struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { // find result_norm tensor for input struct ggml_tensor * inp = nullptr; @@ -13901,7 +13915,6 @@ struct llm_build_context { inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -13912,15 +13925,13 @@ struct llm_build_context { switch (version) { case 2: - cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba2(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; case 1: default: - cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, - state_copy, state_mask, - kv_head, n_kv, cb, il); + cur = llm_build_mamba(ctx0, lctx, batch, gf, cur, state_copy, + rs_zero, kv_head, n_kv, cb, il); break; } @@ -15946,7 +15957,6 @@ struct llm_build_context { struct ggml_tensor * cur; struct ggml_tensor * inpL; struct ggml_tensor * state_copy = build_inp_s_copy(); - struct ggml_tensor * state_mask = build_inp_s_mask(); inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1); @@ -15955,11 +15965,11 @@ struct llm_build_context { const llama_layer * layer = &model.layers[il]; // (ab)using the KV cache to store the states - struct ggml_tensor * token_shift = llm_build_copy_mask_state(ctx0, - gf, kv_self.k_l[il], state_copy, state_mask, + struct ggml_tensor * token_shift = llm_build_rs(ctx0, + gf, kv_self.k_l[il], state_copy, rs_zero, hparams.n_embd_k_s(), kv_self.size, kv_head, n_kv, n_seqs); - struct ggml_tensor * wkv_states = llm_build_copy_mask_state(ctx0, - gf, kv_self.v_l[il], state_copy, state_mask, + struct ggml_tensor * wkv_states = llm_build_rs(ctx0, + gf, kv_self.v_l[il], state_copy, rs_zero, hparams.n_embd_v_s(), kv_self.size, kv_head, n_kv, n_seqs); cur = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); @@ -16329,18 +16339,6 @@ static void llama_set_k_shift(llama_context & lctx) { } } -static void llama_set_s_copy(llama_context & lctx) { - const int64_t kv_size = lctx.kv_self.size; - - assert(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); - - int32_t * data = (int32_t *) lctx.inp_s_copy->data; - - for (int i = 0; i < kv_size; ++i) { - data[i] = lctx.kv_self.cells[i].src; - } -} - static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; @@ -16656,24 +16654,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; - if (lctx.inp_s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_mask->buffer)); - float * data = (float *) lctx.inp_s_mask->data; - - // clear unused states - for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self.head; - llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } - } - } - if (lctx.inp_s_copy) { GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer)); int32_t * data = (int32_t *) lctx.inp_s_copy->data; @@ -16683,8 +16663,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) { const uint32_t cell_id = i + kv_self.head; llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id]; - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self.size) { + if (kv_cell.src < 0) { + GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source + kv_cell.src = kv_self.rs_z; + } + if ((uint32_t) kv_cell.src >= kv_self.size) { + // ignore out-of-bound sources kv_cell.src = cell_id; } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aa7896defdad0..092639eed42e1 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1530,27 +1530,58 @@ struct test_ssm_scan : public test_case { const int64_t d_state; const int64_t d_inner; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_MUL_MAT @@ -3255,7 +3286,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 #if 1 for (ggml_type type_a : base_types) { From 2c77d799f9387f5971289139aaca23b4ce37c435 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:36:22 -0400 Subject: [PATCH 09/92] metal : attempt to adapt SSM_SCAN for Mamba-2 --- ggml/src/ggml-metal.m | 107 ++++++++++++++++++++-------- ggml/src/ggml-metal.metal | 146 ++++++++++++++++++++++++++++++++------ 2 files changed, 202 insertions(+), 51 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 9da08fe2e9771..5d5b98307d264 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -95,6 +95,7 @@ GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, @@ -591,6 +592,7 @@ static void ggml_metal_log(enum ggml_log_level level, const char * format, ...){ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction); @@ -1629,47 +1631,74 @@ static void ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; + struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); + GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; + size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; + id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); const uint64_t nb30 = src3->nb[0]; const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); const uint64_t nb40 = src4->nb[0]; const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); const uint64_t nb50 = src5->nb[0]; const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; + + const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); + + const uint64_t nb70 = src7->nb[0]; const int64_t d_state = ne00; const int64_t d_inner = ne01; + const int64_t n_head = ne02; + const int64_t n_group = ne41; const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_seqs = ne13; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -1678,33 +1707,49 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:7]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:8]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:9]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:10]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:11]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:12]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:18]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:19]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:20]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; + [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; + [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; + [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; + [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; + [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; + [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; + [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2b200032394b1..c75fa25c34e7d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -795,7 +795,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part // TODO: optimize kernel void kernel_ssm_scan_f32( device const void * src0, @@ -804,14 +804,19 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, + device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, constant int64_t & n_seq_tokens, constant int64_t & n_seqs, constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, + constant uint64_t & nb03, constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, @@ -824,47 +829,148 @@ kernel void kernel_ssm_scan_f32( constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, + constant uint64_t & nb43, constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq const int64_t nc = d_state; const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; const int64_t n_s = n_seqs; + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*nb10 + i2*nb11 + i3*nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*nb41 + i3*nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*nb51 + i3*nb52); - device float * y = (device float *) ((device char *) dst + ir*nb10 + i2*nb11 + i3*nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb02 + nb13); - - if (i2 > 0) { - s0 = s; + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; } - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device const void * src7, + device float * dst, + constant int64_t & d_state, + constant int64_t & d_inner, + constant int64_t & n_head, + constant int64_t & n_group, + constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant uint64_t & nb20, + constant uint64_t & nb21, + constant uint64_t & nb22, + constant uint64_t & nb30, + constant uint64_t & nb31, + constant uint64_t & nb40, + constant uint64_t & nb41, + constant uint64_t & nb42, + constant uint64_t & nb43, + constant uint64_t & nb50, + constant uint64_t & nb51, + constant uint64_t & nb52, + constant uint64_t & nb53, + constant uint64_t & nb60, + constant uint64_t & nb70, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const int64_t nc = d_state; + const int64_t nr = d_inner; + const int64_t nh = n_head; + const int64_t ng = n_group; + const int64_t n_t = n_seq_tokens; + const int64_t n_s = n_seqs; + + const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src7; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} + device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = expf(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } - y[0] = sumf; + y[0] = sumf + x[0] * D[0]; + + // recurse + s0 = s; } } From 87b97d08f43652c7a2e73929e34432ae5f9e8713 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:41:10 -0400 Subject: [PATCH 10/92] metal : fix SSM_SCAN pipeline scope --- ggml/src/ggml-metal.m | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5d5b98307d264..477f720a0e32f 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1693,11 +1693,13 @@ static void ggml_metal_encode_node( const int64_t n_seq_tokens = ne11; const int64_t n_seqs = ne13; + id pipeline = nil; + if (ne30 == 1) { // Mamba-2 - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; } else { - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; } [encoder setComputePipelineState:pipeline]; From 03d0e6eabe6172a56a7d470bfd844012f2c2b291 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 10:58:41 -0400 Subject: [PATCH 11/92] metal : use log and exp instead of log1pf and expf in SSM_SCAN --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c75fa25c34e7d..cee9980a75619 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -866,13 +866,13 @@ kernel void kernel_ssm_scan_f32( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { const int64_t i = i0 + i1*nc; - const float state = (s0[i] * expf(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } @@ -955,9 +955,9 @@ kernel void kernel_ssm_scan_f32_group( device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} - const float dt_soft_plus = dt[0] <= 20.0f ? log1pf(expf(dt[0])) : dt[0]; + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - const float dA = expf(dt_soft_plus * A[0]); + const float dA = exp(dt_soft_plus * A[0]); float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { From 7a351abc28e36aeb73d1fd8ce172db56fbb3ebcb Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:28:16 -0400 Subject: [PATCH 12/92] metal : remove unused arguments for SSM_SCAN The max index is 31, so trimming the arguments is necessary. --- ggml/src/ggml-metal.m | 53 ++++++++++++++++----------------------- ggml/src/ggml-metal.metal | 34 +++++++++---------------- 2 files changed, 34 insertions(+), 53 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 477f720a0e32f..5127b34f8edaa 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1655,7 +1655,7 @@ static void ggml_metal_encode_node( const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); @@ -1663,7 +1663,7 @@ static void ggml_metal_encode_node( const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; const uint64_t nb43 = src4->nb[3]; @@ -1673,18 +1673,18 @@ static void ggml_metal_encode_node( const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; const uint64_t nb53 = src5->nb[3]; const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); - const uint64_t nb60 = src6->nb[0]; + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - const uint64_t nb70 = src7->nb[0]; + const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); const int64_t d_state = ne00; const int64_t d_inner = ne01; @@ -1718,32 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:15]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:16]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:17]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:18]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:19]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:20]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:21]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:22]; - [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:23]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:24]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:25]; - [encoder setBytes:&nb23 length:sizeof(nb23) atIndex:26]; - [encoder setBytes:&nb30 length:sizeof(nb30) atIndex:27]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:28]; - [encoder setBytes:&nb40 length:sizeof(nb40) atIndex:29]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:30]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:31]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:32]; - [encoder setBytes:&nb50 length:sizeof(nb50) atIndex:33]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:34]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:35]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:36]; - [encoder setBytes:&nb60 length:sizeof(nb60) atIndex:37]; - [encoder setBytes:&nb70 length:sizeof(nb70) atIndex:38]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + // NOTE: max index is 31 if (ne30 == 1) { // Mamba-2 diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index cee9980a75619..3745f2f225512 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,30 +812,21 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -843,12 +834,16 @@ kernel void kernel_ssm_scan_f32( const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -864,7 +859,7 @@ kernel void kernel_ssm_scan_f32( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; @@ -901,30 +896,21 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, - constant int64_t & n_seqs, - constant uint64_t & nb00, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, - constant uint64_t & nb10, constant uint64_t & nb11, constant uint64_t & nb12, constant uint64_t & nb13, - constant uint64_t & nb20, constant uint64_t & nb21, constant uint64_t & nb22, - constant uint64_t & nb30, constant uint64_t & nb31, - constant uint64_t & nb40, constant uint64_t & nb41, constant uint64_t & nb42, constant uint64_t & nb43, - constant uint64_t & nb50, constant uint64_t & nb51, constant uint64_t & nb52, constant uint64_t & nb53, - constant uint64_t & nb60, - constant uint64_t & nb70, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { @@ -932,12 +918,16 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + const uint64_t nb60 = sizeof(float); + const int64_t nc = d_state; const int64_t nr = d_inner; const int64_t nh = n_head; const int64_t ng = n_group; const int64_t n_t = n_seq_tokens; - const int64_t n_s = n_seqs; const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); @@ -953,7 +943,7 @@ kernel void kernel_ssm_scan_f32_group( device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*sizeof(float)); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 8b15bc6fa0fbb7a0d831b90955430c0a9e281ac2 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 11:47:56 -0400 Subject: [PATCH 13/92] metal : add back n_seqs to SSM_SCAN args Whoops, this is needed for the offset in the concatenated output. --- ggml/src/ggml-metal.m | 33 +++++++++++++++++---------------- ggml/src/ggml-metal.metal | 2 ++ 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 5127b34f8edaa..3f7183060d83d 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1718,22 +1718,23 @@ static void ggml_metal_encode_node( [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3745f2f225512..c36eedb010de1 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -812,6 +812,7 @@ kernel void kernel_ssm_scan_f32( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, @@ -896,6 +897,7 @@ kernel void kernel_ssm_scan_f32_group( constant int64_t & n_head, constant int64_t & n_group, constant int64_t & n_seq_tokens, + constant int64_t & n_seqs, constant uint64_t & nb01, constant uint64_t & nb02, constant uint64_t & nb03, From 5b8ec2b978b84dfdb05e6fca4def928f72b1090c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 12:11:45 -0400 Subject: [PATCH 14/92] metal : fix SSM_SCAN state head offset --- ggml/src/ggml-metal.metal | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index c36eedb010de1..9e1d14ff5d8b5 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -850,8 +850,8 @@ kernel void kernel_ssm_scan_f32( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} @@ -935,8 +935,8 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src7; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb01 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); + device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} From 62b09b343c6c4e35486368f1a7b653c9ae58574a Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 2 Oct 2024 21:35:50 -0400 Subject: [PATCH 15/92] metal : fix wrong number of tokens per sequence in SSM_SCAN --- ggml/src/ggml-metal.m | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 3f7183060d83d..a39770bd4ed1b 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1690,7 +1690,7 @@ static void ggml_metal_encode_node( const int64_t d_inner = ne01; const int64_t n_head = ne02; const int64_t n_group = ne41; - const int64_t n_seq_tokens = ne11; + const int64_t n_seq_tokens = ne12; const int64_t n_seqs = ne13; id pipeline = nil; From 805512a73b9876853f0e7d0cd612259806fa5d93 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Sat, 12 Oct 2024 16:20:26 -0400 Subject: [PATCH 16/92] ggml : remove unused fast broadcast path in GGML_MUL This was initially added because states were masked with ggml_mul, but this is no longer done and so this "optimisation" is no longer necessary, or at least not worth the additional code complexity. --- ggml/src/ggml.c | 32 +------------------------------- 1 file changed, 1 insertion(+), 31 deletions(-) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e8a5e3d153548..8fd335270dd5a 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -10173,37 +10173,7 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - if (ne00 > 1 && ne10 == 1) { - // fast broadcast path - for (int64_t ir = ith; ir < nr; ir += nth) { - // src0 and dst are same shape => same indices - const int64_t i03 = ir/(ne02*ne01); - const int64_t i02 = (ir - i03*ne02*ne01)/ne01; - const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; - - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - - const float scale = src1_ptr[0]; - - if (scale == 0.0f) { - // NOTE: this also sets NANs to zero, which is not compliant with IEEE754, - // but it is useful when resetting the state of recurrent models. - memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float)); - } else { - if (dst->data != src0->data) { - // src0 is same shape as dst => same indices - memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float)); - } - if (scale != 1.0f) { - ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale); - } - } - } - } else if (nb10 == sizeof(float)) { + if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { // src0 and dst are same shape => same indices const int64_t i03 = ir/(ne02*ne01); From 3bc7103d2ef1c41cd380a1ad8d918cf9c26694d8 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 11:36:37 -0500 Subject: [PATCH 17/92] ggml : avoid multiply by D in GGML_OP_SSM_SCAN This makes the weight buft detection in src/llama.cpp simpler. * convert : transpose Mamba-2 A, D and reshape SSM_NORM This breaks existing conversions of Mamba-2 models to avoid some reshapes. Not sure if it's a good idea, but it makes the graph slightly cleaner. * llama : more appropriate SSM_SCAN and SSM_CONV buft support checks --- convert_hf_to_gguf.py | 26 ++++++++++++++++- ggml/include/ggml.h | 1 - ggml/src/ggml-metal.m | 57 ++++++++++++++++---------------------- ggml/src/ggml-metal.metal | 14 +++------- ggml/src/ggml.c | 20 ++++--------- src/llama.cpp | 54 +++++++++++++++++++----------------- tests/test-backend-ops.cpp | 25 ++++++++--------- 7 files changed, 100 insertions(+), 97 deletions(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f307b1ac69202..f0a63d921d65f 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -264,6 +264,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(self.map_tensor_name(name), data_torch)] + # TODO: merge into modify_tensors? (need to check tensor shapes for all arches before doing that) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + del new_name, bid # unused + + return data_torch.squeeze() + def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -295,7 +301,7 @@ def prepare_tensors(self): break for new_name, data_torch in (self.modify_tensors(data_torch, name, bid)): - data = data_torch.squeeze().numpy() + data = self.reshape_tensors(data_torch, new_name, bid).numpy() # if data ends up empty, it means data_torch was a scalar tensor -> restore if len(data.shape) == 0: @@ -3063,6 +3069,24 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) + def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> Tensor: + if any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + return data_torch.reshape((*data_torch.shape, 1)) + + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + return data_torch.reshape((n_group, d_inner // n_group)) + + return data_torch.squeeze() + + @Model.register("CohereForCausalLM") class CommandR2Model(Model): diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 0d2e5cb011a3b..735f56b005a28 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1828,7 +1828,6 @@ extern "C" { struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 73e2fedc36544..902728d8e6b55 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1649,25 +1649,21 @@ static void ggml_metal_encode_node( struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; struct ggml_tensor * src6 = node->src[6]; - struct ggml_tensor * src7 = node->src[7]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); GGML_ASSERT(src6); - GGML_ASSERT(src7); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; size_t offs_src6 = 0; - size_t offs_src7 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - id id_src7 = src7 ? ggml_metal_get_buffer(src7, &offs_src7) : nil; const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); @@ -1699,10 +1695,6 @@ static void ggml_metal_encode_node( const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); - const int64_t ne70 = src7->ne[0]; GGML_UNUSED(ne70); - - const uint64_t nb70 = src7->nb[0]; GGML_UNUSED(nb70); - const int64_t d_state = ne00; const int64_t d_inner = ne01; const int64_t n_head = ne02; @@ -1727,31 +1719,30 @@ static void ggml_metal_encode_node( [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; - [encoder setBuffer:id_src7 offset:offs_src7 atIndex:7]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:8]; - - [encoder setBytes:&d_state length:sizeof(d_state) atIndex:9]; - [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:10]; - [encoder setBytes:&n_head length:sizeof(n_head) atIndex:11]; - [encoder setBytes:&n_group length:sizeof(n_group) atIndex:12]; - [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:13]; - [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:14]; - - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:15]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:16]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:20]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:21]; - [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:22]; - [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:23]; - [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:24]; - [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:25]; - [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:26]; - [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:27]; - [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:28]; - [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:29]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + + [encoder setBytes:&d_state length:sizeof(d_state) atIndex:8]; + [encoder setBytes:&d_inner length:sizeof(d_inner) atIndex:9]; + [encoder setBytes:&n_head length:sizeof(n_head) atIndex:10]; + [encoder setBytes:&n_group length:sizeof(n_group) atIndex:11]; + [encoder setBytes:&n_seq_tokens length:sizeof(n_seq_tokens) atIndex:12]; + [encoder setBytes:&n_seqs length:sizeof(n_seqs) atIndex:13]; + + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:14]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:15]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:16]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:17]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:18]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:19]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:20]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(nb31) atIndex:22]; + [encoder setBytes:&nb41 length:sizeof(nb41) atIndex:23]; + [encoder setBytes:&nb42 length:sizeof(nb42) atIndex:24]; + [encoder setBytes:&nb43 length:sizeof(nb43) atIndex:25]; + [encoder setBytes:&nb51 length:sizeof(nb51) atIndex:26]; + [encoder setBytes:&nb52 length:sizeof(nb52) atIndex:27]; + [encoder setBytes:&nb53 length:sizeof(nb53) atIndex:28]; // NOTE: max index is 31 if (ne30 == 1) { diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 2f5a4d12eeec3..05d04e8f3fdbf 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -805,7 +805,6 @@ kernel void kernel_ssm_scan_f32( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -838,7 +837,6 @@ kernel void kernel_ssm_scan_f32( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -848,7 +846,7 @@ kernel void kernel_ssm_scan_f32( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -859,7 +857,6 @@ kernel void kernel_ssm_scan_f32( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {d_state, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; @@ -873,7 +870,7 @@ kernel void kernel_ssm_scan_f32( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; @@ -890,7 +887,6 @@ kernel void kernel_ssm_scan_f32_group( device const void * src4, device const void * src5, device const void * src6, - device const void * src7, device float * dst, constant int64_t & d_state, constant int64_t & d_inner, @@ -923,7 +919,6 @@ kernel void kernel_ssm_scan_f32_group( const uint64_t nb00 = sizeof(float); const uint64_t nb10 = sizeof(float); const uint64_t nb20 = sizeof(float); - const uint64_t nb60 = sizeof(float); const int64_t nc = d_state; const int64_t nr = d_inner; @@ -933,7 +928,7 @@ kernel void kernel_ssm_scan_f32_group( const int64_t s_off = d_inner * n_head * n_seq_tokens * n_seqs * sizeof(float); - device const int32_t * ids = (device const int32_t *) src7; + device const int32_t * ids = (device const int32_t *) src6; device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); @@ -944,7 +939,6 @@ kernel void kernel_ssm_scan_f32_group( device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device const float * D = (device const float *) ((device const char *) src6 + ir*nb60); // {nh} device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; @@ -959,7 +953,7 @@ kernel void kernel_ssm_scan_f32_group( s[i] = state; } - y[0] = sumf + x[0] * D[0]; + y[0] = sumf; // recurse s0 = s; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 91b256a4c25f0..9036fc0be9858 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -7181,7 +7181,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -7205,7 +7204,6 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * A, struct ggml_tensor * B, struct ggml_tensor * C, - struct ggml_tensor * D, struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); GGML_ASSERT(ggml_is_contiguous(dt)); @@ -7235,8 +7233,6 @@ struct ggml_tensor * ggml_ssm_scan( GGML_ASSERT(B->ne[0] == d_state); GGML_ASSERT(B->ne[2] == n_seq_tokens); GGML_ASSERT(B->ne[3] == n_seqs); - GGML_ASSERT(D->ne[0] == n_head); - GGML_ASSERT(ggml_is_vector(D)); GGML_ASSERT(ids->ne[0] == n_seqs); GGML_ASSERT(ggml_is_vector(ids)); GGML_ASSERT(A->ne[1] == n_head); @@ -7258,8 +7254,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; - result->src[6] = D; - result->src[7] = ids; + result->src[6] = ids; return result; } @@ -16217,8 +16212,7 @@ static void ggml_compute_forward_ssm_scan_f32( const struct ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} const struct ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} const struct ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} - const struct ggml_tensor * src6 = dst->src[6]; // D {n_head} - const struct ggml_tensor * src7 = dst->src[7]; // ids {n_seqs} + const struct ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; @@ -16240,8 +16234,7 @@ static void ggml_compute_forward_ssm_scan_f32( GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - GGML_ASSERT(src6->nb[0] == sizeof(float)); - GGML_ASSERT(src7->nb[0] == sizeof(int32_t)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); // allows optimizing the modulo since n_group should be a power of 2 GGML_ASSERT((ng & -ng) == ng); @@ -16252,7 +16245,7 @@ static void ggml_compute_forward_ssm_scan_f32( const int ih0 = dh*ith; const int ih1 = MIN(ih0 + dh, nh); - const int32_t * ids = (const int32_t *) src7->data; + const int32_t * ids = (const int32_t *) src6->data; for (int i3 = 0; i3 < ns; ++i3) { const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} @@ -16264,7 +16257,6 @@ static void ggml_compute_forward_ssm_scan_f32( const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} - const float * D = (const float *) ((const char *) src6->data); // {nh} float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} if (src3->ne[0] == 1) { @@ -16325,7 +16317,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } else { @@ -16353,7 +16345,7 @@ static void ggml_compute_forward_ssm_scan_f32( sumf += state * C[ig]; s[i] = state; } - y[ii] = sumf + x[ii] * D[h]; + y[ii] = sumf; } } } diff --git a/src/llama.cpp b/src/llama.cpp index e84510ce8ffd1..52052caf250b1 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -7120,6 +7120,7 @@ static const std::map llm_tensor_info_mapping = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -7227,23 +7228,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV: { @@ -8572,10 +8577,10 @@ static bool llm_load_tensors( layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); // no "weight" suffix for these - layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {n_head}, 0); - layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {n_head}, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); - layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner}, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); @@ -9994,7 +9999,7 @@ static struct ggml_tensor * llm_build_rs( return states; } -// TODO: split +// TODO: split conv and ssm static struct ggml_tensor * llm_build_mamba( struct ggml_context * ctx, struct llama_context & lctx, @@ -10102,13 +10107,14 @@ static struct ggml_tensor * llm_build_mamba( dt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_dt, dt); dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); + cur = x; x = ggml_reshape_4d(ctx, x, head_dim, n_head, n_seq_tokens, n_seqs); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10120,6 +10126,7 @@ static struct ggml_tensor * llm_build_mamba( // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, cur, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} @@ -10184,7 +10191,7 @@ static struct ggml_tensor * llm_build_mamba2( struct ggml_tensor * zxBCdt = llm_build_lora_mm(lctx, ctx, model.layers[il].ssm_in, cur); // split the above in three - struct ggml_tensor * z = ggml_view_3d(ctx, zxBCdt, d_inner, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], 0); + struct ggml_tensor * z = ggml_view_4d(ctx, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); struct ggml_tensor * xBC = ggml_view_3d(ctx, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); struct ggml_tensor * dt = ggml_view_3d(ctx, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); @@ -10230,11 +10237,9 @@ static struct ggml_tensor * llm_build_mamba2( dt = ggml_add(ctx, dt, model.layers[il].ssm_dt_b); struct ggml_tensor * ssm_ids = ggml_view_1d(ctx, state_copy, n_seqs, 0); - // Use the same shape semantics for A as Mamba-1 - struct ggml_tensor * A = ggml_reshape_2d(ctx, model.layers[il].ssm_a, 1, n_head); // TODO: use semistructured matrices to implement state-space duality // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, model.layers[il].ssm_d, ssm_ids); + struct ggml_tensor * y_ssm = ggml_ssm_scan(ctx, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(graph, @@ -10242,17 +10247,16 @@ static struct ggml_tensor * llm_build_mamba2( ggml_view_1d(ctx, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), ggml_view_1d(ctx, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - struct ggml_tensor * y = ggml_view_3d(ctx, y_ssm, d_inner, n_seq_tokens, n_seqs, n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + struct ggml_tensor * y = ggml_view_4d(ctx, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // TODO: skip computing output earlier for unused tokens + y = ggml_add(ctx, y, ggml_mul(ctx, x, model.layers[il].ssm_d)); y = ggml_mul(ctx, y, ggml_silu(ctx, ggml_cont(ctx, z))); // grouped RMS norm y = ggml_reshape_4d(ctx, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); - y = llm_build_norm(ctx, y, hparams, - ggml_reshape_2d(ctx, model.layers[il].ssm_norm, d_inner / n_group, n_group), NULL, - LLM_NORM_RMS, cb, il); + y = llm_build_norm(ctx, y, hparams, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, cb, il); y = ggml_reshape_3d(ctx, y, d_inner, n_seq_tokens, n_seqs); // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6ca254a45f23f..95f8abbd80968 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1589,35 +1589,34 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; const int64_t n_head; const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR7(type, d_state, d_inner, n_head, n_group, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, int64_t d_state = 32, - int64_t d_inner = 1, // non-zero for Mamba-2 + int64_t head_dim = 1, // non-zero for Mamba-2 int64_t n_head = 32, int64_t n_group = 1, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, d_inner, n_head, n_seqs); - ggml_tensor * x = ggml_new_tensor_4d(ctx, type, d_inner, n_head, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); - ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (d_inner > 1) ? 1 : d_state, n_head); - ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); - ggml_tensor * D = ggml_new_tensor_1d(ctx, type, n_head); - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, D, ids); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } From b4e9c5998dea2d657cfd22bc2e6fa0630fba2fa9 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 4 Nov 2024 15:26:15 -0500 Subject: [PATCH 18/92] convert : fix flake8 lint --- convert_hf_to_gguf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index f0efe5d5b0c7c..019e7b7ef93b6 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -3088,7 +3088,6 @@ def reshape_tensors(self, data_torch: Tensor, new_name: str, bid: int | None) -> return data_torch.squeeze() - @Model.register("CohereForCausalLM") class CommandR2Model(Model): model_arch = gguf.MODEL_ARCH.COMMAND_R From cf4f0a4123d94b3f09e9f9343b76f48bd6043756 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 18:55:34 -0400 Subject: [PATCH 19/92] metal : fix confusion between ; and , --- ggml/src/ggml-metal/ggml-metal.m | 42 ++++++++++++++++---------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index 370d0ad7744fa..7d6377c790903 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2652,27 +2652,27 @@ static bool ggml_metal_encode_node( } ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state; - /*.d_inner =*/ d_inner; - /*.n_head =*/ n_head; - /*.n_group =*/ n_group; - /*.n_seq_tokens =*/ n_seq_tokens; - /*.n_seqs =*/ n_seqs; - /*.nb01 =*/ nb01; - /*.nb02 =*/ nb02; - /*.nb03 =*/ nb03; - /*.nb11 =*/ nb11; - /*.nb12 =*/ nb12; - /*.nb13 =*/ nb13; - /*.nb21 =*/ nb21; - /*.nb22 =*/ nb22; - /*.nb31 =*/ nb31; - /*.nb41 =*/ nb41; - /*.nb42 =*/ nb42; - /*.nb43 =*/ nb43; - /*.nb51 =*/ nb51; - /*.nb52 =*/ nb52; - /*.nb53 =*/ nb53; + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, + /*.n_seq_tokens =*/ n_seq_tokens, + /*.n_seqs =*/ n_seqs, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, }; [encoder setComputePipelineState:pipeline]; From 6def5cd729fdde64b2addeaa5cce016c72485e06 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 19:10:20 -0400 Subject: [PATCH 20/92] metal : add missing args for nb references in ssm_scan_f32_group --- ggml/src/ggml-metal/ggml-metal.metal | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 71ab693721298..4b5e4f8457210 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1350,16 +1350,16 @@ kernel void kernel_ssm_scan_f32_group( device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb02 + ids[i3]*nb03); - device float * s = (device float *) ((device char *) dst + ir*nb02 + i3*nb03 + s_off); + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*nb11 + i2*nb12 + i3*nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*nb21 + i3*nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*nb31); // {1, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*nb41 + i2*nb42 + i3*nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*nb51 + i2*nb52 + i3*nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; From 791998b42d6cd6edb31e4d5824e29c100cecd40b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 21:27:12 -0400 Subject: [PATCH 21/92] metal : single-user mamba2 inference works --- ggml/src/ggml-metal/ggml-metal.metal | 14 +++++++------- src/llama-model.cpp | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4b5e4f8457210..4e50efdee41ca 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1284,7 +1284,7 @@ kernel void kernel_ssm_scan_f32( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * nt * args.n_seqs * sizeof(float); + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); device const int32_t * ids = (device const int32_t *) src6; @@ -1292,12 +1292,12 @@ kernel void kernel_ssm_scan_f32( device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; @@ -1354,12 +1354,12 @@ kernel void kernel_ssm_scan_f32_group( device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f295c684099e7..cffdbc6845363 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9009,7 +9009,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); // {n_head, n_seq_tokens, n_seqs} - dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); // TODO: use semistructured matrices to implement state-space duality From 94c3d5304352eef27c33b08a858facdffbb28438 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Thu, 1 May 2025 22:18:57 -0400 Subject: [PATCH 22/92] kv-cache : remove const_cast when setting inputs for s_copy And also fix multi-user inference for recurrent models by using cell_id instead of i as the kv cell index when populating s_copy. --- src/llama-graph.cpp | 22 ++++++++-------------- src/llama-kv-cache.cpp | 7 +++++-- src/llama-kv-cache.h | 1 + 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0f77f98b24f64..8d2fceb17def5 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -286,27 +286,21 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { for (uint32_t i = 0; i < n_kv; ++i) { const uint32_t cell_id = i + kv_self->head; - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + const llama_kv_cell & kv_cell = kv_self->cells[cell_id]; + + int32_t src = kv_cell.src0; // prevent out-of-bound sources - if (kv_cell.src < 0) { + if (src < 0) { GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source - kv_cell.src = kv_self->rs_z; + src = kv_self->rs_z; } - if ((uint32_t) kv_cell.src >= kv_self->size) { + if ((uint32_t) src >= kv_self->size) { // ignore out-of-bound sources - kv_cell.src = cell_id; + src = cell_id; } - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } + data[i] = src; } } } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 108c07731b1ab..743b30badcf67 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -665,10 +665,13 @@ bool llama_kv_cache_unified::find_slot( // Find first to-be-cleared cell rs_z = -1; for (int i = min; i <= max; ++i) { - if (cells[i].src == -1) { + if (rs_z < 0 && cells[i].src == -1) { rs_z = i; - break; } + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + cells[i].src = i; } // allow getting the range of used cells, from head to head + n diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 7939bc6b8dd2d..6b115e8f7d134 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -47,6 +47,7 @@ struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; int32_t src = -1; // used by recurrent state models to copy states + int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once) int32_t tail = -1; std::set seq_id; From d55b0d06210cdc10b6cf872b9009d82bb6372b01 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 2 May 2025 18:24:55 -0400 Subject: [PATCH 23/92] convert : avoid AutoConfig for Mamba and Mamba2 hparams --- convert_hf_to_gguf.py | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 532cc879de324..2debb6e63fef9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4127,6 +4127,14 @@ def set_gguf_parameters(self): class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 8 @@ -4205,6 +4213,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter class Mamba2Model(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA2 + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1 + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 16 @@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) - arch = hparams["architectures"][0] + arch = None + if (arches := hparams.get("architectures")) is not None and len(arches) > 0: + arch = arches[0] + elif "ssm_cfg" in hparams: + # For non-hf Mamba and Mamba2 models + arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM" + # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] elif model_type == ModelType.VISION and vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] + if arch is None: + raise ValueError("Failed to detect model architecture") return arch From e94f3932f2dbcb2120580a9f42878e058a18cf5b Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Fri, 2 May 2025 19:29:23 -0400 Subject: [PATCH 24/92] kv-cache : allow context shift for recurrent models --- src/llama-kv-cache.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 87ce7ce03d503..99dd20b68fd73 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1938,7 +1938,8 @@ llama_pos llama_kv_cache_recurrent::get_pos_max() const { } bool llama_kv_cache_recurrent::get_can_shift() const { - return false; + // shifting is trivial, the recurrent states don't care about the absolute position + return true; } uint32_t llama_kv_cache_recurrent::cell_max() const { From 2fa5f2ceb8b49bbd2835878ad5429ea74383566c Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 10 Jun 2025 20:00:41 -0400 Subject: [PATCH 25/92] graph : fix recurrent state copies when avoiding copies Works, but using lambda functions might not be that clean. --- src/llama-graph.cpp | 19 +++++++------------ src/llama-graph.h | 3 ++- src/llama-model.cpp | 44 +++++++++++++++++++++++++++++--------------- 3 files changed, 38 insertions(+), 28 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e74c9ff53b05a..1abe3b8febb4a 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1429,7 +1429,8 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies) const { + const std::function & get_state_rows) const { + const auto * kv_state = static_cast(mstate); const auto n_kv = kv_state->get_n_kv(); @@ -1445,17 +1446,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state( ggml_tensor * output_states; - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) - output_states = states; - } + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); diff --git a/src/llama-graph.h b/src/llama-graph.h index 88fb77f1ddc9a..1fcf1cde45a41 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -599,7 +599,8 @@ struct llm_graph_context { ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, - bool avoid_copies = false) const; + const std::function + & get_state_rows = ggml_get_rows) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 2c0f7d4084344..2999483ad71ed 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9024,11 +9024,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); - ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size()); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9094,11 +9091,21 @@ struct llm_build_mamba : public llm_graph_context { cur = x; x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); - ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); - // Custom operator to optimize the parallel associative scan - // as described in the Annex D of the Mamba paper. - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -9151,11 +9158,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * conv_states_all = kv_state->get_k_l(il); ggml_tensor * ssm_states_all = kv_state->get_v_l(il); - // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); - ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); - ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_state->get_size()); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -9211,10 +9215,20 @@ struct llm_build_mamba : public llm_graph_context { // {n_head, n_seq_tokens, n_seqs} dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); - ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); - // TODO: use semistructured matrices to implement state-space duality - // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + ggml_tensor * A = model.layers[il].ssm_a; + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, From 757aa6239de5cc41afdea32561ab227b7b447424 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 11 Jun 2025 12:33:05 -0400 Subject: [PATCH 26/92] ggml : fix mamba2 ssm scan when compiled with SVE --- ggml/src/ggml-cpu/ops.cpp | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2a6be25852e4e..11d4819c868f3 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7664,6 +7664,37 @@ static void ggml_compute_forward_ssm_scan_f32( const float x_dt = x[ii] * dt_soft_plus; float sumf = 0.0f; #if defined(GGML_SIMD) + #if defined(__ARM_FEATURE_SVE) + const int ggml_f32_epr = svcntw(); + const int ggml_f32_step = 1 * ggml_f32_epr; + + const int np = (nc & ~(ggml_f32_step - 1)); + + GGML_F32_VEC sum = GGML_F32_VEC_ZERO; + + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); + + for (int i = 0; i < np; i += ggml_f32_step) { + // TODO: maybe unroll more? + for (int j = 0; j < 1; j++) { + GGML_F32_VEC t0 = GGML_F32_VEC_LOAD(s0 + i + j*ggml_f32_epr + ii*nc); + GGML_F32_VEC t1 = GGML_F32_VEC_LOAD(B + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + GGML_F32_VEC t2 = GGML_F32_VEC_LOAD(C + i + j*ggml_f32_epr + (h & (ng - 1))*nc); + + t0 = GGML_F32_VEC_MUL(t0, adA); + t1 = GGML_F32_VEC_MUL(t1, axdt); + + t0 = GGML_F32_VEC_ADD(t0, t1); + + sum = GGML_F32_VEC_FMA(sum, t0, t2); + + GGML_F32_VEC_STORE(s + i + j*ggml_f32_epr + ii*nc, t0); + } + } + + sumf = GGML_F32xt_REDUCE_ONE(sum); + #else const int np = (nc & ~(GGML_F32_STEP - 1)); GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; @@ -7694,6 +7725,7 @@ static void ggml_compute_forward_ssm_scan_f32( // reduce sum0..sum3 to sum0 GGML_F32_VEC_REDUCE(sumf, sum); + #endif #else const int np = 0; #endif @@ -7722,7 +7754,7 @@ static void ggml_compute_forward_ssm_scan_f32( for (int i1 = 0; i1 < nr; ++i1) { const int ii = i1 + h*nr; const float x_dt = x[ii] * dt_soft_plus; -#ifdef __ARM_FEATURE_SVE +#if defined(__ARM_FEATURE_SVE) svfloat32_t vx_dt = GGML_F32_VEC_SET1(x_dt); svfloat32_t vdt_soft_plus = GGML_F32_VEC_SET1(dt_soft_plus); svfloat32_t r1_vector = GGML_F32_VEC_ZERO; From 0b6f6becb4e916a24fcaf2966647381a21d1f084 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Wed, 11 Jun 2025 15:29:58 -0400 Subject: [PATCH 27/92] ggml-cpu : reorder SVE FMA for consistency with other SIMD arches --- ggml/src/ggml-cpu/ops.cpp | 2 +- ggml/src/ggml-cpu/simd-mappings.h | 2 +- ggml/src/ggml-cpu/vec.cpp | 18 +++++++++--------- ggml/src/ggml-cpu/vec.h | 18 +++++++++--------- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 11d4819c868f3..711d8abcc5fdd 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7771,7 +7771,7 @@ static void ggml_compute_forward_ssm_scan_f32( t1 = exp_ps_sve(svptrue_b32(), t1); svfloat32_t t2 = GGML_F32_VEC_MUL(vx_dt, vB); - vs0 = GGML_F32_VEC_FMA(vs0, t1, t2); + vs0 = GGML_F32_VEC_FMA(t2, vs0, t1); r1_vector = GGML_F32_VEC_ADD(GGML_F32_VEC_MUL(vs0, vC), r1_vector); GGML_F32_VEC_STORE(&s[ii*nc + k], vs0); diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 2e3669c0186c9..91bb867bf57b8 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -32,7 +32,7 @@ #define GGML_F32xt_LOAD(...) GGML_F32xt_LOAD_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_STORE_IMPL(pg,a,b) svst1_f32(pg, a, b) #define GGML_F32xt_STORE(...) GGML_F32xt_STORE_IMPL(DEFAULT_PG, __VA_ARGS__) -#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, a, b, c) +#define GGML_F32xt_FMA_IMPL(pg, a, b, c) svmad_f32_m(pg, b, c, a) #define GGML_F32xt_FMA(...) GGML_F32xt_FMA_IMPL(DEFAULT_PG, __VA_ARGS__) #define GGML_F32xt_ADD_IMPL(pg, a, b) svadd_f32_m(pg, a, b) #define GGML_F32xt_ADD(...) GGML_F32xt_ADD_IMPL(DEFAULT_PG, __VA_ARGS__) diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index f7614568ea388..7e61b5bf965a3 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -37,35 +37,35 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = 0; i < np; i += ggml_f32_step) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - sum2 = GGML_F32_VEC_FMA(ax2, ay2, sum2); + sum2 = GGML_F32_VEC_FMA(sum2, ax2, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - sum3 = GGML_F32_VEC_FMA(ax3, ay3, sum3); + sum3 = GGML_F32_VEC_FMA(sum3, ax3, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - sum4 = GGML_F32_VEC_FMA(ax4, ay4, sum4); + sum4 = GGML_F32_VEC_FMA(sum4, ax4, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - sum5 = GGML_F32_VEC_FMA(ax5, ay5, sum5); + sum5 = GGML_F32_VEC_FMA(sum5, ax5, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - sum6 = GGML_F32_VEC_FMA(ax6, ay6, sum6); + sum6 = GGML_F32_VEC_FMA(sum6, ax6, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - sum7 = GGML_F32_VEC_FMA(ax7, ay7, sum7); + sum7 = GGML_F32_VEC_FMA(sum7, ax7, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - sum8 = GGML_F32_VEC_FMA(ax8, ay8, sum8); + sum8 = GGML_F32_VEC_FMA(sum8, ax8, ay8); } // leftovers // Since 8 unrolls are done in above loop, leftovers lie in range [0, ggml_f32_step] which is handled in below loop @@ -73,7 +73,7 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - sum1 = GGML_F32_VEC_FMA(ax1, ay1, sum1); + sum1 = GGML_F32_VEC_FMA(sum1, ax1, ay1); } // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only if (np2 < n) { diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index 09dbade2179fb..a144259800477 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -163,49 +163,49 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); ax2 = GGML_F32_VEC_LOAD(x + i + 1*ggml_f32_epr); ay2 = GGML_F32_VEC_LOAD(y + i + 1*ggml_f32_epr); - ay2 = GGML_F32_VEC_FMA(ax2, vx, ay2); + ay2 = GGML_F32_VEC_FMA(ay2, ax2, vx); GGML_F32_VEC_STORE(y + i + 1*ggml_f32_epr, ay2); ax3 = GGML_F32_VEC_LOAD(x + i + 2*ggml_f32_epr); ay3 = GGML_F32_VEC_LOAD(y + i + 2*ggml_f32_epr); - ay3 = GGML_F32_VEC_FMA(ax3, vx, ay3); + ay3 = GGML_F32_VEC_FMA(ay3, ax3, vx); GGML_F32_VEC_STORE(y + i + 2*ggml_f32_epr, ay3); ax4 = GGML_F32_VEC_LOAD(x + i + 3*ggml_f32_epr); ay4 = GGML_F32_VEC_LOAD(y + i + 3*ggml_f32_epr); - ay4 = GGML_F32_VEC_FMA(ax4, vx, ay4); + ay4 = GGML_F32_VEC_FMA(ay4, ax4, vx); GGML_F32_VEC_STORE(y + i + 3*ggml_f32_epr, ay4); ax5 = GGML_F32_VEC_LOAD(x + i + 4*ggml_f32_epr); ay5 = GGML_F32_VEC_LOAD(y + i + 4*ggml_f32_epr); - ay5 = GGML_F32_VEC_FMA(ax5, vx, ay5); + ay5 = GGML_F32_VEC_FMA(ay5, ax5, vx); GGML_F32_VEC_STORE(y + i + 4*ggml_f32_epr, ay5); ax6 = GGML_F32_VEC_LOAD(x + i + 5*ggml_f32_epr); ay6 = GGML_F32_VEC_LOAD(y + i + 5*ggml_f32_epr); - ay6 = GGML_F32_VEC_FMA(ax6, vx, ay6); + ay6 = GGML_F32_VEC_FMA(ay6, ax6, vx); GGML_F32_VEC_STORE(y + i + 5*ggml_f32_epr, ay6); ax7 = GGML_F32_VEC_LOAD(x + i + 6*ggml_f32_epr); ay7 = GGML_F32_VEC_LOAD(y + i + 6*ggml_f32_epr); - ay7 = GGML_F32_VEC_FMA(ax7, vx, ay7); + ay7 = GGML_F32_VEC_FMA(ay7, ax7, vx); GGML_F32_VEC_STORE(y + i + 6*ggml_f32_epr, ay7); ax8 = GGML_F32_VEC_LOAD(x + i + 7*ggml_f32_epr); ay8 = GGML_F32_VEC_LOAD(y + i + 7*ggml_f32_epr); - ay8 = GGML_F32_VEC_FMA(ax8, vx, ay8); + ay8 = GGML_F32_VEC_FMA(ay8, ax8, vx); GGML_F32_VEC_STORE(y + i + 7*ggml_f32_epr, ay8); } @@ -215,7 +215,7 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const for (int i = np; i < np2; i += ggml_f32_epr) { ax1 = GGML_F32_VEC_LOAD(x + i); ay1 = GGML_F32_VEC_LOAD(y + i); - ay1 = GGML_F32_VEC_FMA(ax1, vx, ay1); + ay1 = GGML_F32_VEC_FMA(ay1, ax1, vx); GGML_F32_VEC_STORE(y + i, ay1); } From 4dff845a22471a23c4478e3da8892e8118d3ff73 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:21:29 -0600 Subject: [PATCH 28/92] feat: Add llama_model_is_hybrid API call Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart --- include/llama.h | 3 +++ src/llama-arch.cpp | 22 ++++++++++++++++++++++ src/llama-arch.h | 3 +++ src/llama-model.cpp | 13 +++++-------- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/include/llama.h b/include/llama.h index 015a57898e22d..64f5d2e637b1c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -569,6 +569,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 43fa60a8070b7..391c34e8b676e 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1752,3 +1752,25 @@ llm_arch llm_arch_from_string(const std::string & name) { const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid(const llm_arch & arch) { + // TODO: There are currently no hybrid models! Once there are, this will be + // the place to identify them + switch (arch) { + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index f3825528aefdb..3b7ed7cb10c3b 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -436,3 +436,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); + +bool llm_arch_is_recurrent(const llm_arch& arch); +bool llm_arch_is_hybrid(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c64bf9de939f4..e5cfa2c62f86d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13821,14 +13821,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { } bool llama_model_is_recurrent(const llama_model * model) { - switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; - } + return llm_arch_is_recurrent(model->arch); +} + +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From 18224b7de379ea8ce4e1571ac493e3610329e241 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:08:33 -0600 Subject: [PATCH 29/92] feat: Add c++ side constants for attention layer indices hparam Branch: GraniteFour --- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 391c34e8b676e..f908da8c5d728 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -144,6 +144,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 3b7ed7cb10c3b..5516b091cdd76 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -148,6 +148,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, From 4782ac0c2c316a42b3c7aca51095f61ef94951fd Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:04:36 -0600 Subject: [PATCH 30/92] feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-hparams.cpp | 14 ++++++++++++-- src/llama-hparams.h | 10 ++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 1499eb08a5dd9..70a7114f39715 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s() const { +uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const { return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s() const { +uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; @@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const { return ssm_d_state * ssm_d_inner; } +bool llama_hparams::recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; +} + bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index b2bcb8b01a18b..3614596464318 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -115,6 +115,9 @@ struct llama_hparams { uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + // for hybrid state space models + std::array recurrent_layer_arr; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -181,10 +184,13 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s() const; + uint32_t n_embd_k_s(uint32_t il = 0) const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s() const; + uint32_t n_embd_v_s(uint32_t il = 0) const; + + // whether or not the given layer is recurrent (for hybrid models) + bool recurrent_layer(uint32_t il) const; bool is_swa(uint32_t il) const; }; From 6a397e5d18d2f889f9628fbbe71b4c965b8e970b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:22:18 -0600 Subject: [PATCH 31/92] feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e5cfa2c62f86d..5f833d2a6b93a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -469,6 +469,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill( + hparams.recurrent_layer_arr.begin(), + hparams.recurrent_layer_arr.end(), + llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); From 452170859e66eba441712bdd33f2f7c6cb1be999 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 06:48:53 -0600 Subject: [PATCH 32/92] refactor: rename *_is_hybrid -> *_is_hybrid_recurrent The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent." Branch: HybridCache Signed-off-by: Gabe Goodhart --- include/llama.h | 2 +- src/llama-arch.cpp | 2 +- src/llama-arch.h | 2 +- src/llama-model.cpp | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/llama.h b/include/llama.h index 64f5d2e637b1c..5a3284339fb96 100644 --- a/include/llama.h +++ b/include/llama.h @@ -570,7 +570,7 @@ extern "C" { LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) - LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model); // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f908da8c5d728..24008fe1be1f6 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1767,7 +1767,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { } } -bool llm_arch_is_hybrid(const llm_arch & arch) { +bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) { // TODO: There are currently no hybrid models! Once there are, this will be // the place to identify them switch (arch) { diff --git a/src/llama-arch.h b/src/llama-arch.h index 5516b091cdd76..ea124be434cda 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -439,4 +439,4 @@ llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); bool llm_arch_is_recurrent(const llm_arch& arch); -bool llm_arch_is_hybrid(const llm_arch& arch); +bool llm_arch_is_hybrid_recurrent(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5f833d2a6b93a..560250ca097b3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13828,8 +13828,8 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } -bool llama_model_is_hybrid(const llama_model * model) { - return llm_arch_is_hybrid(model->arch); +bool llama_model_is_hybrid_recurrent(const llama_model * model) { + return llm_arch_is_hybrid_recurrent(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From 263d913159487f87722f7b6a150a53f40ce2a39c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 13:43:16 -0600 Subject: [PATCH 33/92] feat: Add layer filter to recurrent cache Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-recurrent.cpp | 22 ++++++++++++++-------- src/llama-kv-cache-recurrent.h | 17 +++++++++++------ src/llama-model.cpp | 1 + 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index de23b4ad23bce..7e71719251142 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -16,12 +16,13 @@ // llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -63,6 +64,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { + if (filter && !filter(i)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i); + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -88,8 +94,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); + k_l[i] = k; + v_l[i] = v; } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index d7c02ea872160..b4e285274b761 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -15,13 +15,18 @@ // see the implementation of llama_kv_cache_unified_state_i for an example how to do it class llama_kv_cache_recurrent : public llama_memory_i { public: + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 560250ca097b3..b9e9d76b9ac5d 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13218,6 +13218,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = new llama_kv_cache_recurrent( *this, + nullptr, GGML_TYPE_F32, GGML_TYPE_F32, cparams.offload_kqv, From a7bd78272cb39f638825b6b59c9f06d50b778eb6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 09:16:06 -0600 Subject: [PATCH 34/92] fix: Use per-layer sizing everywhere in kv caches Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-recurrent.cpp | 16 ++++++++-------- src/llama-kv-cache-unified.cpp | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 7e71719251142..345de647b17fb 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); const char * dev_name = "CPU"; @@ -756,7 +756,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; @@ -776,7 +776,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -797,7 +797,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -944,7 +944,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -972,7 +972,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -1000,7 +1000,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 89606c598fc4f..10e3279d70dad 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); const char * dev_name = "CPU"; @@ -1429,7 +1429,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)layer.k->type; @@ -1451,7 +1451,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1475,7 +1475,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1620,7 +1620,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -1650,7 +1650,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -1680,7 +1680,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; From af916a0816da957fb44a142521a40bf40dfd9ece Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 30 May 2025 09:35:26 -0600 Subject: [PATCH 35/92] feat: First pass at llama_kv_cache_hybrid_recurrent This follows the pattern in iswa where the two child caches are held explicitly to support the case where a model requires a single attention cache and a single recurrent cache where each layer uses exactly one of the caches. This is a rewrite of the more generic approach in the original hybrid cache PR: https://github.com/ggml-org/llama.cpp/pull/13276 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/CMakeLists.txt | 1 + src/llama-kv-cache-hybrid-recurrent.cpp | 241 ++++++++++++++++++++++++ src/llama-kv-cache-hybrid-recurrent.h | 140 ++++++++++++++ 3 files changed, 382 insertions(+) create mode 100644 src/llama-kv-cache-hybrid-recurrent.cpp create mode 100644 src/llama-kv-cache-hybrid-recurrent.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 70be604e4b0d3..e70c8d9a183df 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(llama llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp llama-kv-cache-recurrent.cpp + llama-kv-cache-hybrid-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp new file mode 100644 index 0000000000000..bd2323762f37f --- /dev/null +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -0,0 +1,241 @@ +#include "llama-kv-cache-hybrid-recurrent.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_kv_cache_hybrid_recurrent +// + +llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( + const llama_model & model, + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload) : + hparams(model.hparams), + kv_attn(new llama_kv_cache_unified( + model, + [&](int32_t il) { return !model.hparams.recurrent_layer(il); }, + attn_type_k, + attn_type_v, + attn_v_trans, + offload, + attn_kv_size, + n_seq_max, + attn_n_pad, + attn_n_swa, + attn_swa_type + )), + kv_recurrent(new llama_kv_cache_recurrent( + model, + [&](int32_t il) { return model.hparams.recurrent_layer(il); }, + recurrent_type_k, + recurrent_type_v, + offload, + recurrent_kv_size, + n_seq_max + )) {} + +void llama_kv_cache_hybrid_recurrent::clear() { + kv_attn ->clear(); + kv_recurrent->clear(); +} + +bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { + return false; + } + return kv_attn->seq_rm(seq_id, p0, p1); +} + +void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { + kv_attn ->seq_keep(seq_id); + kv_recurrent->seq_keep(seq_id); +} + +void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_attn->seq_add(seq_id, p0, p1, shift); + kv_recurrent->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_attn ->seq_div(seq_id, p0, p1, d); + kv_recurrent->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); +} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + + // since this includes a recurrent cache, we cannot use split_simple + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + // prepare the recurrent batches first + if (!kv_recurrent->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined state at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache + auto heads_attn = kv_attn->prepare(ubatches); + if (heads_attn.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sbatch), std::move(heads_attn), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() { + return std::make_unique(this); +} + +bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) { + bool res = false; + + res = res | kv_attn ->update(lctx); + res = res | kv_recurrent->update(lctx); + + return res; +} + +void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) { + kv_attn ->defrag_sched(thold); + kv_recurrent->defrag_sched(thold); +} + +bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { + // TODO: Should this return true if the attention cache can shift? + return false; +} + +void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_attn ->state_write(io, seq_id); + kv_recurrent->state_write(io, seq_id); +} + +void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_attn ->state_read(io, seq_id); + kv_recurrent->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const { + return kv_attn.get(); +} + +llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const { + return kv_recurrent.get(); +} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) + : status(status), state_attn(status), state_recurrent(status) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + kv(kv), + state_attn(status, kv->get_kv_attn()), + state_recurrent(status, kv->get_kv_recurrent()) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_sbatch sbatch, + std::vector heads_attn, + std::vector ubatches) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + kv(kv), + sbatch(std::move(sbatch)), + heads_attn(std::move(heads_attn)), + ubatches(std::move(ubatches)), + // NOTE: these child states are only used as wrapper APIs for the + // const methods, so we use the "init full" signature since the + // actual state is not used. + state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()), + state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {} + + +bool llama_kv_cache_hybrid_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_hybrid_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->get_kv_attn() ->apply_ubatch(heads_attn[i_next], ubatches[i_next]); + kv->get_kv_recurrent()->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_hybrid_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const { + return &state_attn; +} + +const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { + return &state_recurrent; +} diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h new file mode 100644 index 0000000000000..692079b650027 --- /dev/null +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache.h" +#include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-unified.h" + +#include +#include + +// +// llama_kv_cache_hybrid_recurrent +// + +// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to +// support models where each layer may be either attention-based or recurrent + +class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { +public: + llama_kv_cache_hybrid_recurrent( + const llama_model & model, + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload); + + ~llama_kv_cache_hybrid_recurrent() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_hybrid_recurrent specific API + // + + llama_kv_cache_unified * get_kv_attn () const; + llama_kv_cache_recurrent * get_kv_recurrent() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr kv_attn; + const std::unique_ptr kv_recurrent; +}; + +class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { +public: + // init failure + explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status); + + // init full + explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv); + + // init success + llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_sbatch sbatch, + std::vector heads_attn, + std::vector ubatches); + + ~llama_kv_cache_hybrid_recurrent_state() = default; + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_hybrid_recurrent_state_i + // + + const llama_kv_cache_unified_state * get_state_attn () const; + const llama_kv_cache_recurrent_state * get_state_recurrent() const; + +private: + const llama_memory_status status; + + llama_kv_cache_hybrid_recurrent * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads_attn; + std::vector ubatches; + + const llama_kv_cache_unified_state state_attn; + const llama_kv_cache_recurrent_state state_recurrent; +}; From 1ee27a4ce69d2d26cb2e703f756f900accc3a519 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 08:57:18 -0600 Subject: [PATCH 36/92] feat: Construct hybrid recurrent cache for hybrid recurrent models This includes a refactor of the create_memory logic to avoid needing to use the arch enum explicitly unless a model needs explicit cache instantiation logic beyond the standard logic for recurrent, hybrid, unified, and iswa. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 110 +++++++++++++++++++++++++------------------- 1 file changed, 63 insertions(+), 47 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b9e9d76b9ac5d..01d708852fd77 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9,6 +9,7 @@ #include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-hybrid-recurrent.h" #include "ggml-cpp.h" @@ -13202,6 +13203,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_i * res; switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: @@ -13210,58 +13213,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - res = new llama_kv_cache_recurrent( - *this, - nullptr, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max); - } break; + // Models that need standard caching should rely on recurrent/hybrid + // checks default: { - const auto padding = llama_kv_cache_unified::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); - - res = new llama_kv_cache_unified_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.n_ctx, - cparams.n_seq_max, - cparams.n_ubatch, - padding); - } else { - GGML_ASSERT(!hparams.is_swa_any()); - - res = new llama_kv_cache_unified( + if (llm_arch_is_recurrent(arch)) { + res = new llama_kv_cache_recurrent( *this, nullptr, - params.type_k, - params.type_v, - !cparams.flash_attn, + GGML_TYPE_F32, + GGML_TYPE_F32, cparams.offload_kqv, - cparams.n_ctx, - cparams.n_seq_max, - padding, - hparams.n_swa, - hparams.swa_type); + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max); + } else if (llm_arch_is_hybrid_recurrent(arch)) { + res = new llama_kv_cache_hybrid_recurrent( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ llama_kv_cache_unified::get_padding(cparams), + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv); + } else { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.n_ctx, + cparams.n_seq_max, + cparams.n_ubatch, + padding); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } } } } From 781d5aacc03988f1e31119061833359cff576fa9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 11:02:54 -0600 Subject: [PATCH 37/92] fix: Fix wrong bool condition for split equal in hybrid cache Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index bd2323762f37f..beadcee7ba3d1 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -96,7 +96,7 @@ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) cons llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { // since this includes a recurrent cache, we cannot use split_simple - auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); // follow the recurrent pattern for creating the ubatch splits std::vector ubatches; From a3f58d9ed23b242277b7bf3aae0a8277552ccd25 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 3 Jun 2025 16:29:40 -0600 Subject: [PATCH 38/92] fix: Fix shift logic to defer to unified cache Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index beadcee7ba3d1..a6468482dae5d 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -150,8 +150,8 @@ void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) { } bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { - // TODO: Should this return true if the attention cache can shift? - return false; + // Shifting is trivially supported for recurrent + return kv_attn->get_can_shift(); } void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { From c8c9aa2c125e5a8003531067ff1eae3fc4c033ab Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 4 Jun 2025 08:47:55 -0600 Subject: [PATCH 39/92] feat: Support hybrid recurrent in llama-graph NOTE: I intentionally did not add support for s_mask since it will be going away soon Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 52 +++++++++++++++++++++++++++++++++++++++++++-- src/llama-graph.h | 30 ++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e74c9ff53b05a..07d0264da3174 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-hybrid-recurrent.h" #include #include @@ -396,6 +397,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { +} + // // llm_graph_context // @@ -954,8 +962,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const auto * kv_state = static_cast(mstate); +ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const { + if (kv_state == nullptr) { + kv_state = static_cast(mstate); + } auto inp = std::make_unique(kv_state); @@ -1284,6 +1294,44 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(hparams, cparams, kv_state); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); + + const auto n_kv = kv_state->get_state_attn()->get_n_kv(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + return build_attn( + static_cast(inp), + gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il + ); +} + llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { const auto * kv_state = static_cast(mstate); diff --git a/src/llama-graph.h b/src/llama-graph.h index 88fb77f1ddc9a..7aca98d8a8249 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -22,6 +22,7 @@ struct llama_memory_state_i; class llama_kv_cache_unified_state; class llama_kv_cache_unified_iswa_state; class llama_kv_cache_recurrent_state; +class llama_kv_cache_hybrid_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -242,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { cparams(cparams), kv_state(kv_state) { } - ~llm_graph_input_attn_kv_unified() = default; + virtual ~llm_graph_input_attn_kv_unified() = default; void set_input(const llama_ubatch * ubatch) override; @@ -285,6 +286,16 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_kv_cache_unified_iswa_state * kv_state; }; +class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { +public: + llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state); + + virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; +}; + class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -508,7 +519,7 @@ struct llm_graph_context { ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; - ggml_tensor * build_inp_s_copy() const; + ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -574,6 +585,21 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( From e77503f65acc7d9952350f9dc477c43f6463ec44 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 4 Jun 2025 15:02:14 -0600 Subject: [PATCH 40/92] fix: Fix logic for initializing inputs and attn layers for hybrid caches Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 53 +++++++++++++++------------------------------ src/llama-graph.h | 38 +++++++------------------------- 2 files changed, 25 insertions(+), 66 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 07d0264da3174..fc530f7e08683 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -397,13 +397,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } -llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_kv_cache_hybrid_recurrent_state * kv_state) : - llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { -} - // // llm_graph_context // @@ -1262,7 +1255,9 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state = static_cast(mstate); + // NOTE: For hybrid caches, this may be a child of mstate, so we use the one + // encapsulated in inp + const auto * kv_state = inp->kv_state; // store to KV cache { @@ -1294,10 +1289,10 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_state); + auto inp = std::make_unique(hparams, cparams, kv_state->get_state_attn()); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); @@ -1311,25 +1306,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } - return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); -} - -ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_kv_hybrid_recurrent * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - ggml_tensor * v_mla, - float kq_scale, - int il) const { - return build_attn( - static_cast(inp), - gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il - ); + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { @@ -1472,13 +1449,17 @@ ggml_tensor * llm_graph_context::build_attn( } ggml_tensor * llm_graph_context::build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies) const { - const auto * kv_state = static_cast(mstate); + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies, + const llama_kv_cache_recurrent_state * kv_state) const { + + if (kv_state == nullptr) { + kv_state = static_cast(mstate); + } const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); diff --git a/src/llama-graph.h b/src/llama-graph.h index 7aca98d8a8249..d0810f12acb1c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -286,16 +286,6 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_kv_cache_unified_iswa_state * kv_state; }; -class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { -public: - llm_graph_input_attn_kv_hybrid_recurrent( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_kv_cache_hybrid_recurrent_state * kv_state); - - virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; -}; - class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -585,20 +575,7 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; - - ggml_tensor * build_attn( - llm_graph_input_attn_kv_hybrid_recurrent * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] - ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il) const; + llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const; llm_graph_input_attn_cross * build_attn_inp_cross() const; @@ -620,12 +597,13 @@ struct llm_graph_context { // ggml_tensor * build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies = false) const; + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false, + const llama_kv_cache_recurrent_state * kv_state = nullptr) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, From d327af7a1b50185e21c31f6189fe18a0cd9d7f17 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Jun 2025 14:07:07 -0600 Subject: [PATCH 41/92] fix: Update recurrent cache for changes to remove intermediate kv_cache interface Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 129 ++++++++++++------------ src/llama-kv-cache-hybrid-recurrent.h | 50 ++++----- 2 files changed, 93 insertions(+), 86 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index a6468482dae5d..ea228a834397f 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -49,50 +49,6 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( n_seq_max )) {} -void llama_kv_cache_hybrid_recurrent::clear() { - kv_attn ->clear(); - kv_recurrent->clear(); -} - -bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - // Try removing from the recurrent cache first since it may fail. If it does - // fail, the cache will not have been mutated. - if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { - return false; - } - return kv_attn->seq_rm(seq_id, p0, p1); -} - -void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); - kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { - kv_attn ->seq_keep(seq_id); - kv_recurrent->seq_keep(seq_id); -} - -void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - kv_attn->seq_add(seq_id, p0, p1, shift); - kv_recurrent->seq_add(seq_id, p0, p1, shift); -} - -void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - kv_attn ->seq_div(seq_id, p0, p1, d); - kv_recurrent->seq_div(seq_id, p0, p1, d); -} - -llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { - // the min of the total cache is the max of the two caches' min values - return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); -} - -llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { - // the max of the total cache is the min of the two caches' max values - return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); -} - llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { // since this includes a recurrent cache, we cannot use split_simple @@ -135,23 +91,59 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() { return std::make_unique(this); } -bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) { - bool res = false; +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) { + return std::make_unique( + this, + static_cast( kv_attn ->init_update(lctx, optimize).release()), + static_cast(kv_recurrent->init_update(lctx, optimize).release())); +} + +bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { + // Shifting is trivially supported for recurrent + return kv_attn->get_can_shift(); +} +void llama_kv_cache_hybrid_recurrent::clear() { + kv_attn ->clear(); + kv_recurrent->clear(); +} - res = res | kv_attn ->update(lctx); - res = res | kv_recurrent->update(lctx); +bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { + return false; + } + return kv_attn->seq_rm(seq_id, p0, p1); +} - return res; +void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) { - kv_attn ->defrag_sched(thold); - kv_recurrent->defrag_sched(thold); +void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { + kv_attn ->seq_keep(seq_id); + kv_recurrent->seq_keep(seq_id); } -bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { - // Shifting is trivially supported for recurrent - return kv_attn->get_can_shift(); +void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_attn->seq_add(seq_id, p0, p1, shift); + kv_recurrent->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_attn ->seq_div(seq_id, p0, p1, d); + kv_recurrent->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); } void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { @@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c } llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) - : status(status), state_attn(status), state_recurrent(status) {} + : status(status), + state_attn(new llama_kv_cache_unified_state(status)), + state_recurrent(new llama_kv_cache_recurrent_state(status)) {} llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), - state_attn(status, kv->get_kv_attn()), - state_recurrent(status, kv->get_kv_recurrent()) {} + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), + state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_kv_cache_unified_state * state_unified, + llama_kv_cache_recurrent_state * state_recurrent) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + kv(kv), + state_attn(state_unified), + state_recurrent(state_recurrent) {} llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( llama_kv_cache_hybrid_recurrent * kv, @@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( // NOTE: these child states are only used as wrapper APIs for the // const methods, so we use the "init full" signature since the // actual state is not used. - state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()), - state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {} + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), + state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {} bool llama_kv_cache_hybrid_recurrent_state::next() { @@ -232,10 +235,10 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { return ubatches[i_next]; } -const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const { - return &state_attn; +const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const { + return state_attn.get(); } const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { - return &state_recurrent; + return state_recurrent.get(); } diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 692079b650027..e504631e47ae4 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -2,9 +2,10 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache.h" #include "llama-kv-cache-recurrent.h" #include "llama-kv-cache-unified.h" +#include "llama-kv-cells.h" +#include "llama-memory.h" #include #include @@ -16,7 +17,7 @@ // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to // support models where each layer may be either attention-based or recurrent -class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { +class llama_kv_cache_hybrid_recurrent : public llama_memory_i { public: llama_kv_cache_hybrid_recurrent( const llama_model & model, @@ -42,21 +43,6 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { // llama_memory_i // - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, @@ -65,12 +51,21 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { llama_memory_state_ptr init_full() override; - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; @@ -92,12 +87,21 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { public: + using llama_kv_cache_unified_state_ptr = std::unique_ptr; + using llama_kv_cache_recurrent_state_ptr = std::unique_ptr; + // init failure explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status); // init full explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv); + // init update + explicit llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_kv_cache_unified_state * state_unified, + llama_kv_cache_recurrent_state * state_recurrent); + // init success llama_kv_cache_hybrid_recurrent_state( llama_kv_cache_hybrid_recurrent * kv, @@ -116,7 +120,7 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_hybrid_recurrent_state_i + // llama_kv_cache_hybrid_recurrent_state // const llama_kv_cache_unified_state * get_state_attn () const; @@ -135,6 +139,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { std::vector heads_attn; std::vector ubatches; - const llama_kv_cache_unified_state state_attn; - const llama_kv_cache_recurrent_state state_recurrent; + const llama_kv_cache_unified_state_ptr state_attn; + const llama_kv_cache_recurrent_state_ptr state_recurrent; }; From ee9b31c24bcd09c330eae93fd0038ac4686a2267 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Jun 2025 14:41:08 -0600 Subject: [PATCH 42/92] fix: Fix status for init_update sig for recurrent cache state Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index ea228a834397f..d269a6b5057af 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -179,7 +179,7 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( llama_kv_cache_hybrid_recurrent * kv, llama_kv_cache_unified_state * state_unified, llama_kv_cache_recurrent_state * state_recurrent) - : status(LLAMA_MEMORY_STATUS_SUCCESS), + : status(LLAMA_MEMORY_STATUS_NO_UPDATE), kv(kv), state_attn(state_unified), state_recurrent(state_recurrent) {} From c3c0ee61d80879df1790f8a3963f3af87c9eef3c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Jun 2025 15:54:50 -0600 Subject: [PATCH 43/92] fix: Add missing padding to n_ctx for hybrid cache construction Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 01d708852fd77..89cabfe42edb7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13227,13 +13227,17 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max); } else if (llm_arch_is_hybrid_recurrent(arch)) { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + res = new llama_kv_cache_hybrid_recurrent( /* model */ *this, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ llama_kv_cache_unified::get_padding(cparams), + /* attn_n_pad */ padding, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, From df695b38d84d78b9c06751b44748d0e1c13f99ec Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 6 Jun 2025 09:38:10 -0600 Subject: [PATCH 44/92] fix: Update clear signature for data argument after rebase Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 7 ++++--- src/llama-kv-cache-hybrid-recurrent.h | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index d269a6b5057af..8871dbf631164 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -102,9 +102,10 @@ bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { // Shifting is trivially supported for recurrent return kv_attn->get_can_shift(); } -void llama_kv_cache_hybrid_recurrent::clear() { - kv_attn ->clear(); - kv_recurrent->clear(); + +void llama_kv_cache_hybrid_recurrent::clear(bool data) { + kv_attn ->clear(data); + kv_recurrent->clear(data); } bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index e504631e47ae4..8728fd733cc28 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -55,7 +55,7 @@ class llama_kv_cache_hybrid_recurrent : public llama_memory_i { bool get_can_shift() const override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; From 9520af22e41acf2a45c574f3d5c4bbf545758cdf Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Jun 2025 16:26:31 -0600 Subject: [PATCH 45/92] fix: Remove errant virtual destructor leftover from previous impl attempt Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-graph.h b/src/llama-graph.h index d0810f12acb1c..1896f9b417526 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -243,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { cparams(cparams), kv_state(kv_state) { } - virtual ~llm_graph_input_attn_kv_unified() = default; + ~llm_graph_input_attn_kv_unified() = default; void set_input(const llama_ubatch * ubatch) override; From 6361a732ca20ca6c6320a6179271602975d7f8bd Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Jun 2025 16:30:49 -0600 Subject: [PATCH 46/92] fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 89cabfe42edb7..27462dbf3b0a0 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8939,11 +8939,11 @@ struct llm_build_mamba : public llm_graph_context { // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state( gf, conv_states_all, state_copy, - hparams.n_embd_k_s(), n_seqs); + hparams.n_embd_k_s(il), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); ggml_tensor * ssm = build_recurrent_state( gf, ssm_states_all, state_copy, - hparams.n_embd_v_s(), n_seqs); + hparams.n_embd_v_s(il), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} From b11505274e952ba9d1b296d013f2ce7548267608 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 12:20:04 -0600 Subject: [PATCH 47/92] refactor: Remove n_embd_k/v_s from unified cache No longer needed now that unified isn't also supporting recurrent https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140761069 Branch: HybridRecurrentCache --- src/llama-kv-cache-unified.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 10e3279d70dad..6a0cd9bbce6a0 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); const char * dev_name = "CPU"; @@ -1429,7 +1429,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Write key type const int32_t k_type_i = (int32_t)layer.k->type; @@ -1451,7 +1451,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1475,7 +1475,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1620,7 +1620,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Read type of key int32_t k_type_i_ref; @@ -1650,7 +1650,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -1680,7 +1680,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; From 869049531122a7a04a920ea20445025c2583e27c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 12:20:47 -0600 Subject: [PATCH 48/92] refactor: Remove layer index from n_embd_k/v_s Now that it's not used at all in the unified cache, we don't need to use the layer index to zero it out for attention layers. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-hparams.cpp | 10 ++-------- src/llama-hparams.h | 4 ++-- src/llama-kv-cache-recurrent.cpp | 16 ++++++++-------- src/llama-model.cpp | 4 ++-- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 70a7114f39715..0ec3a8a501f76 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,10 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { - if (!recurrent_layer(il)) { - return 0; - } +uint32_t llama_hparams::n_embd_k_s() const { if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -79,10 +76,7 @@ uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { - if (!recurrent_layer(il)) { - return 0; - } +uint32_t llama_hparams::n_embd_v_s() const { if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 3614596464318..84234494c5611 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -184,10 +184,10 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s(uint32_t il = 0) const; + uint32_t n_embd_k_s() const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s(uint32_t il = 0) const; + uint32_t n_embd_v_s() const; // whether or not the given layer is recurrent (for hybrid models) bool recurrent_layer(uint32_t il) const; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 345de647b17fb..7e71719251142 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); const char * dev_name = "CPU"; @@ -756,7 +756,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; @@ -776,7 +776,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -797,7 +797,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -944,7 +944,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Read type of key int32_t k_type_i_ref; @@ -972,7 +972,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -1000,7 +1000,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 27462dbf3b0a0..89cabfe42edb7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8939,11 +8939,11 @@ struct llm_build_mamba : public llm_graph_context { // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state( gf, conv_states_all, state_copy, - hparams.n_embd_k_s(il), n_seqs); + hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); ggml_tensor * ssm = build_recurrent_state( gf, ssm_states_all, state_copy, - hparams.n_embd_v_s(il), n_seqs); + hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} From fa6dbe8422dc772c715f88f71505a0d2f9cbc847 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 12:56:26 -0600 Subject: [PATCH 49/92] refactor: Remove n_embd_k/v_gqa from recurrent cache This is no longer needed now that there are separate implementations https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-recurrent.cpp | 39 +++++++++++++------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 7e71719251142..ed3b9031e36d0 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - const char * dev_name = "CPU"; ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); @@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); k_l[i] = k; @@ -756,14 +753,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out @@ -776,14 +772,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out @@ -797,7 +792,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_s = hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -808,10 +803,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + io.write(&n_embd_v_s, sizeof(n_embd_v_s)); // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + for (uint32_t j = 0; j < n_embd_v_s; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; @@ -944,7 +939,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Read type of key int32_t k_type_i_ref; @@ -958,7 +952,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -972,7 +966,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -986,7 +979,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -1000,7 +993,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_s = hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -1020,17 +1013,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return false; } - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + // Read state embedding size + uint32_t n_embd_v_s_ref; + io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref)); + if (n_embd_v_s != n_embd_v_s_ref) { + LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il); return false; } if (cell_count) { // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + for (uint32_t j = 0; j < n_embd_v_s; ++j) { const size_t dst_offset = (head + j * size) * v_size_el; ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } From 42459d0c6af75a6750de85b7c3cf01b3ffa929c1 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 13:41:52 -0600 Subject: [PATCH 50/92] feat: Allow custom layer filters for hybrid recurrent This should help support architectures like Falcon H1 where there is overlap between layers that need attention and recurrent caches. https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 41 +++++++++++++++---------- src/llama-kv-cache-hybrid-recurrent.h | 37 +++++++++++++--------- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index 8871dbf631164..889f43025a3e3 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -10,25 +10,30 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( const llama_model & model, - /* attn */ - ggml_type attn_type_k, - ggml_type attn_type_v, - bool attn_v_trans, - uint32_t attn_kv_size, - uint32_t attn_n_pad, - uint32_t attn_n_swa, - llama_swa_type attn_swa_type, - /* recurrent */ - ggml_type recurrent_type_k, - ggml_type recurrent_type_v, - uint32_t recurrent_kv_size, - /* common */ - uint32_t n_seq_max, - bool offload) : + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload, + /* layer filters */ + layer_filter_cb && attn_filter, + layer_filter_cb && recurrent_filter) : hparams(model.hparams), kv_attn(new llama_kv_cache_unified( model, - [&](int32_t il) { return !model.hparams.recurrent_layer(il); }, + attn_filter == nullptr ? + [&](int32_t il) { return !model.hparams.recurrent_layer(il); } + : attn_filter, attn_type_k, attn_type_v, attn_v_trans, @@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( )), kv_recurrent(new llama_kv_cache_recurrent( model, - [&](int32_t il) { return model.hparams.recurrent_layer(il); }, + recurrent_filter == nullptr ? + [&](int32_t il) { return model.hparams.recurrent_layer(il); } + : recurrent_filter, recurrent_type_k, recurrent_type_v, offload, diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 8728fd733cc28..444e87e101136 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -19,23 +19,30 @@ class llama_kv_cache_hybrid_recurrent : public llama_memory_i { public: + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + llama_kv_cache_hybrid_recurrent( const llama_model & model, - /* attn */ - ggml_type attn_type_k, - ggml_type attn_type_v, - bool attn_v_trans, - uint32_t attn_kv_size, - uint32_t attn_n_pad, - uint32_t attn_n_swa, - llama_swa_type attn_swa_type, - /* recurrent */ - ggml_type recurrent_type_k, - ggml_type recurrent_type_v, - uint32_t recurrent_kv_size, - /* common */ - uint32_t n_seq_max, - bool offload); + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload, + /* layer filters */ + layer_filter_cb && attn_filter = nullptr, + layer_filter_cb && recurrent_filter = nullptr); ~llama_kv_cache_hybrid_recurrent() = default; From 74ad4f8f324a7461baa7534cb36ff0661aadf4e7 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 14:00:53 -0600 Subject: [PATCH 51/92] fix: Remove logits_all after rebase Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 4 ++-- src/llama-kv-cache-hybrid-recurrent.h | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index 889f43025a3e3..49a7c35ab8cfa 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -56,10 +56,10 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( n_seq_max )) {} -llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) { // since this includes a recurrent cache, we cannot use split_simple - auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); // follow the recurrent pattern for creating the ubatch splits std::vector ubatches; diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 444e87e101136..d6678eb2164fa 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -53,8 +53,7 @@ class llama_kv_cache_hybrid_recurrent : public llama_memory_i { llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; + bool embd_pooled) override; llama_memory_state_ptr init_full() override; From b3c948a5e742592737ecb8575b93931e10107d83 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 14:01:28 -0600 Subject: [PATCH 52/92] fix: Remove llama_model_is_hybrid_Recurrent public API https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2141728423 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- include/llama.h | 3 --- src/llama-model.cpp | 4 ---- 2 files changed, 7 deletions(-) diff --git a/include/llama.h b/include/llama.h index 5a3284339fb96..015a57898e22d 100644 --- a/include/llama.h +++ b/include/llama.h @@ -569,9 +569,6 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); - // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) - LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model); - // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 89cabfe42edb7..3dc200ddac20c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13849,10 +13849,6 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } -bool llama_model_is_hybrid_recurrent(const llama_model * model) { - return llm_arch_is_hybrid_recurrent(model->arch); -} - const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } From f3e34bb378dcd126117c13ec7a3d3b992b6c3ea6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 14:30:21 -0600 Subject: [PATCH 53/92] refactor: Use llama_memory_state_ptr for child states in hybrid memory state Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 4 ++-- src/llama-kv-cache-hybrid-recurrent.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index 49a7c35ab8cfa..a2afda7647b00 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -244,9 +244,9 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { } const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const { - return state_attn.get(); + return static_cast(state_attn.get()); } const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { - return state_recurrent.get(); + return static_cast(state_recurrent.get()); } diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index d6678eb2164fa..93bf72ec34837 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -145,6 +145,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { std::vector heads_attn; std::vector ubatches; - const llama_kv_cache_unified_state_ptr state_attn; - const llama_kv_cache_recurrent_state_ptr state_recurrent; + const llama_memory_state_ptr state_attn; + const llama_memory_state_ptr state_recurrent; }; From 6253c7c87e1e1c3c9630171411bc0a4e470a7470 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 17:04:27 -0600 Subject: [PATCH 54/92] feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738 This is a big overhaul to bring consistency between how inputs and per- layer components are created for attention layers and recurrent layers. The main changes are: - Rename class llm_graph_input_s_copy -> llm_graph_input_rs - Add a corresponding llm_graph_input_rs_hybrid_recurrent - Rename build_inp_s_copy -> build_rs_inp_recurrent - Add a corresponding build_rs_inp_hybrid_recurrent - Rename build_recurrent_state -> build_rs to match build_attn w/ llm_graph_input_rs * as the first input - Add a corresponding overload of build_rs w/ llm_graph_input_rs_hybrid_recurrent * as the first input - Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified - Add a build_attn override that takes llm_graph_input_attn_kv_hybrid_recurrent * as the first input This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations where the only difference between implementations is how they cast the memory state. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 203 ++++++++++++++++++++++++++++++++++---------- src/llama-graph.h | 71 ++++++++++++---- src/llama-model.cpp | 66 +++++++------- 3 files changed, 240 insertions(+), 100 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fc530f7e08683..e86a602a94660 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -235,7 +235,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } -void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); const int64_t n_kv = kv_state->get_n_kv(); @@ -251,6 +251,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent( + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_rs(kv_state->get_state_recurrent()) { +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -354,6 +359,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { +} + void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); @@ -955,25 +967,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const { - if (kv_state == nullptr) { - kv_state = static_cast(mstate); - } - - auto inp = std::make_unique(kv_state); - - const auto n_kv = kv_state->get_n_kv(); - - auto & cur = inp->s_copy; - - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); - ggml_set_input(cur); - - res->add_input(std::move(inp)); - - return cur; -} - ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); @@ -1255,9 +1248,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - // NOTE: For hybrid caches, this may be a child of mstate, so we use the one - // encapsulated in inp - const auto * kv_state = inp->kv_state; + const auto * kv_state = static_cast(mstate); // store to KV cache { @@ -1289,15 +1280,14 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { - const auto * kv_state = static_cast(mstate); - - auto inp = std::make_unique(hparams, cparams, kv_state->get_state_attn()); +llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { + auto inp = std::make_unique( + hparams, cparams, static_cast(mstate)); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - const auto n_kv = kv_state->get_state_attn()->get_n_kv(); + const auto n_kv = inp->kv_state->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1306,7 +1296,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } - return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); + return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto * kv_state = static_cast(mstate)->get_state_attn(); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { @@ -1448,19 +1488,90 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies, - const llama_kv_cache_recurrent_state * kv_state) const { +llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(kv_state); + + const auto n_kv = kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + return (llm_graph_input_rs *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { + + const auto * kv_state = static_cast(mstate); + + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); + const auto rs_zero = kv_state->get_rs_z(); + + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); + + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); + + ggml_tensor * output_states; - if (kv_state == nullptr) { - kv_state = static_cast(mstate); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); + } else { + // FIXME: make the gathering operation happen before the copy below + // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) + output_states = states; } + // copy extra states which won't be changed further (between n_seqs and n_kv) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); + + return output_states; +} + +llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const { + auto inp = std::make_unique( + static_cast(mstate)); + + const auto n_kv = inp->kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { + + const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); const auto rs_zero = kv_state->get_rs_z(); @@ -1478,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state( // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); ggml_build_forward_expand(gf, output_states); } else { // FIXME: make the gathering operation happen before the copy below @@ -1487,7 +1598,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state( } // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, @@ -1497,9 +1608,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state( } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( - ggml_cgraph * gf, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -1509,8 +1620,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = kv_state->get_k_l(il); - ggml_tensor * token_shift = build_recurrent_state( - gf, token_shift_all, state_copy, + ggml_tensor * token_shift = build_rs( + inp, gf, token_shift_all, hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); diff --git a/src/llama-graph.h b/src/llama-graph.h index 1896f9b417526..924f5bac829df 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -189,10 +189,10 @@ class llm_graph_input_cls : public llm_graph_input_i { const llama_cparams & cparams; }; -class llm_graph_input_s_copy : public llm_graph_input_i { +class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} - virtual ~llm_graph_input_s_copy() = default; + llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} + virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -201,6 +201,12 @@ class llm_graph_input_s_copy : public llm_graph_input_i { const llama_kv_cache_recurrent_state * kv_state; }; +class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs { +public: + llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state); + virtual ~llm_graph_input_rs_hybrid_recurrent() = default; +}; + class llm_graph_input_cross_embd : public llm_graph_input_i { public: llm_graph_input_cross_embd( @@ -258,6 +264,15 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_kv_cache_unified_state * kv_state; }; +class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { +public: + llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state); + virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; +}; + class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_unified_iswa( @@ -509,7 +524,6 @@ struct llm_graph_context { ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; - ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -575,8 +589,6 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const; - llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( @@ -592,23 +604,48 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; // // recurrent // - ggml_tensor * build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies = false, - const llama_kv_cache_recurrent_state * kv_state = nullptr) const; + llm_graph_input_rs * build_rs_inp_recurrent() const; + + ggml_tensor * build_rs( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; + + llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const; + + ggml_tensor * build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; ggml_tensor * build_rwkv_token_shift_load( - ggml_cgraph * gf, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + const llama_ubatch & ubatch, int il) const; ggml_tensor * build_rwkv_token_shift_store( diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 3dc200ddac20c..858a4d05b47af 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8861,7 +8861,7 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); for (int il = 0; il < n_layer; ++il) { // norm @@ -8870,7 +8870,7 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); + cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -8908,11 +8908,11 @@ struct llm_build_mamba : public llm_graph_context { // TODO: split ggml_tensor * build_mamba_layer( - ggml_cgraph * gf, - ggml_tensor * cur, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, - int il) const { + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { const auto * kv_state = static_cast(mstate); const auto kv_head = kv_state->get_head(); @@ -8937,12 +8937,12 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states - ggml_tensor * conv = build_recurrent_state( - gf, conv_states_all, state_copy, + ggml_tensor * conv = build_rs( + inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_recurrent_state( - gf, ssm_states_all, state_copy, + ggml_tensor * ssm = build_rs( + inp, gf, ssm_states_all, hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); @@ -11654,10 +11654,10 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * build_rwkv6_time_mix( + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, - ggml_tensor * state_copy, const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -11781,8 +11781,8 @@ struct llm_build_rwkv6_base : public llm_graph_context { k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); } - ggml_tensor * wkv_state = build_recurrent_state( - gf, kv_state->get_v_l(il), state_copy, + ggml_tensor * wkv_state = build_rs( + inp, gf, kv_state->get_v_l(il), hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11837,7 +11837,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11847,9 +11847,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -11864,7 +11862,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -11934,7 +11932,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11944,9 +11942,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -11958,7 +11954,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -12046,10 +12042,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { } ggml_tensor * build_rwkv7_time_mix( + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, - ggml_tensor * state_copy, ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { @@ -12132,8 +12128,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); - ggml_tensor * wkv_state = build_recurrent_state( - gf, kv_state->get_v_l(il), state_copy, + ggml_tensor * wkv_state = build_rs( + inp, gf, kv_state->get_v_l(il), hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12190,7 +12186,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12200,9 +12196,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -12217,7 +12211,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12283,7 +12277,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12293,9 +12287,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -12307,7 +12299,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); From 0385fc02196511e671a976d3732587fb8445f28b Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Mon, 16 Jun 2025 13:18:00 +0400 Subject: [PATCH 55/92] some changes --- src/llama-arch.cpp | 62 +++- src/llama-arch.h | 25 ++ src/llama-graph.cpp | 40 +++ src/llama-graph.h | 9 + src/llama-hparams.h | 38 ++- src/llama-kv-cache-hybrid-recurrent.cpp | 2 +- src/llama-kv-cache-recurrent.cpp | 9 +- src/llama-model.cpp | 414 +++++++++++++++++++++++- src/llama-model.h | 8 + src/llama-vocab.cpp | 1 + 10 files changed, 589 insertions(+), 19 deletions(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 4150e69c46c4d..017b2f958b997 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -73,6 +73,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, + { LLM_ARCH_FALCON_H1, "falcon-h1" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -122,6 +123,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_ATTN_HEAD_DIM, "%s.attention.head_dim" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -162,12 +164,30 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SPLIT_COUNT, "split.count" }, { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, - { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, - { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, - { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, - { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, - { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, - { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, + { LLM_KV_SSM_HEAD_DIM, "%s.ssm.head_dim" }, + { LLM_KV_MAMBA_D_SSM, "%s.ssm.mamba_d_ssm" }, + + { LLM_KV_FALCON_H1_USE_MLP, "%s.mamba_use_mlp" }, + { LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, "%s.attention_in_multiplier" }, + { LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, "%s.attention_out_multiplier" }, + { LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, "%s.ssm_in_multiplier" }, + { LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, "%s.ssm_out_multiplier" }, + { LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, "%s.mlp_gate_multiplier" }, + { LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, "%s.mlp_down_multiplier" }, + { LLM_KV_FALCON_H1_SSM_HAS_MUP, "%s.ssm.has_mup" }, + { LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, "%s.mamba_norm_before_gate" }, + { LLM_KV_FALCON_H1_MAMBA_RMS_NORM, "%s.mamba_rms_norm" }, + { LLM_KV_FALCON_H1_ROPE_THETA, "%s.rope_theta" }, + { LLM_KV_FALCON_H1_KEY_MULTIPLIER, "%s.key_multiplier" }, + { LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, "%s.lm_head_multiplier" }, + { LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, "%s.embedding_multiplier" }, + { LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, "%s.ssm.mamba_chunk_size" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -334,6 +354,31 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_FALCON_H1, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_MUP_VEC, "blk.%d.ssm_mup_vec" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + { LLM_TENSOR_FFN_PRE_NORM, "blk.%d.ffn_pre_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_GROK, { @@ -1591,6 +1636,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, + {LLM_TENSOR_FINAL_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ROPE_FREQS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ROPE}}, @@ -1660,6 +1706,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_MUP_VEC, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1683,6 +1730,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_ATTN_OUT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_FFN_PRE_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_FFN_NORM_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_ATTN_Q_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, @@ -1791,6 +1839,8 @@ bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) { // TODO: There are currently no hybrid models! Once there are, this will be // the place to identify them switch (arch) { + case LLM_ARCH_FALCON_H1: + return true; default: return false; } diff --git a/src/llama-arch.h b/src/llama-arch.h index 9331e5bf5f49d..3ac64754e7299 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -13,6 +13,7 @@ enum llm_arch { LLM_ARCH_LLAMA4, LLM_ARCH_DECI, LLM_ARCH_FALCON, + LLM_ARCH_FALCON_H1, LLM_ARCH_BAICHUAN, LLM_ARCH_GROK, LLM_ARCH_GPT2, @@ -151,6 +152,27 @@ enum llm_kv { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, LLM_KV_ATTENTION_LAYER_INDICES, + // Falcon-H1 specific + LLM_KV_ATTN_HEAD_DIM, + LLM_KV_SSM_HEAD_DIM, + LLM_KV_MAMBA_D_SSM, + LLM_KV_N_LAYER, + LLM_KV_FALCON_H1_USE_MLP, + LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, + LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, + LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, + LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, + LLM_KV_FALCON_H1_SSM_HAS_MUP, + LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, + LLM_KV_FALCON_H1_MAMBA_RMS_NORM, + LLM_KV_FALCON_H1_ROPE_THETA, + LLM_KV_FALCON_H1_KEY_MULTIPLIER, + LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, + LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, + LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, + LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, LLM_KV_ROPE_FREQ_BASE, @@ -367,6 +389,9 @@ enum llm_tensor { LLM_TENSOR_POS_NET_ATTN_K, LLM_TENSOR_POS_NET_ATTN_V, LLM_TENSOR_POS_NET_ATTN_OUT, + LLM_TENSOR_SSM_MUP_VEC, + LLM_TENSOR_FFN_PRE_NORM, + LLM_TENSOR_FINAL_NORM, }; enum llm_tensor_layer { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 99a6b129c6f9e..7c422e4e93d42 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1597,6 +1597,46 @@ ggml_tensor * llm_graph_context::build_rs( return output_states; } +// ggml_tensor * llm_graph_context::build_rs( +// llm_graph_input_attn_kv_hybrid_recurrent * inp, +// ggml_cgraph * gf, +// ggml_tensor * s, +// int32_t state_size, +// int32_t n_seqs, +// const std::function +// & get_state_rows) const { + +// const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + +// const auto n_kv = kv_state->get_n_kv(); +// const auto kv_head = kv_state->get_head(); +// const auto rs_zero = kv_state->get_rs_z(); + +// ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); + +// // Clear a single state which will then be copied to the other cleared states. +// // Note that this is a no-op when the view is zero-sized. +// ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); +// ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); + +// ggml_tensor * output_states; + +// // copy states +// // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv +// // {state_size, kv_size} -> {state_size, n_seqs} +// output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); +// ggml_build_forward_expand(gf, output_states); + +// // copy extra states which won't be changed further (between n_seqs and n_kv) +// ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); +// ggml_build_forward_expand(gf, +// ggml_cpy(ctx0, +// states_extra, +// ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); + +// return output_states; +// } + ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( llm_graph_input_rs * inp, ggml_cgraph * gf, diff --git a/src/llama-graph.h b/src/llama-graph.h index 24080928db95b..4e1e3ca9b6491 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -644,6 +644,15 @@ struct llm_graph_context { const std::function & get_state_rows = ggml_get_rows) const; + // ggml_tensor * build_rs( + // llm_graph_input_attn_kv_hybrid_recurrent * inp, + // ggml_cgraph * gf, + // ggml_tensor * s, + // int32_t state_size, + // int32_t n_seqs, + // const std::function + // & get_state_rows = ggml_get_rows) const; + ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, ggml_cgraph * gf, diff --git a/src/llama-hparams.h b/src/llama-hparams.h index d192b66a5c056..ef758d5f603f9 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -110,16 +110,42 @@ struct llama_hparams { std::array swa_layers; // for State Space Models - uint32_t ssm_d_conv = 0; - uint32_t ssm_d_inner = 0; - uint32_t ssm_d_state = 0; - uint32_t ssm_dt_rank = 0; - uint32_t ssm_n_group = 0; + uint32_t ssm_d_conv = 0; + uint32_t ssm_d_inner = 0; + uint32_t ssm_d_state = 0; + uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; + bool ssm_dt_b_c_rms = false; + uint32_t ssm_head_dim = 0; + uint32_t ssm_mamba_d_ssm = 0; + + // Falcon-H1 specific parameters + uint32_t attn_head_dim = 0; + bool mamba_use_mlp = false; + bool mamba_norm_before_gate = false; + bool mamba_rms_norm = false; + float attention_in_multiplier = 1.0f; + float attention_out_multiplier = 1.0f; + float ssm_in_multiplier = 1.0f; + float ssm_out_multiplier = 1.0f; + float mlp_gate_multiplier = 1.0f; + float mlp_down_multiplier = 1.0f; + float key_multiplier = 1.0f; + float lm_head_multiplier = 1.0f; + float rope_theta = 10000.0f; + bool ssm_has_mup = false; + float embedding_multiplier = 1.0f; + uint32_t vocab_size = 0; + uint32_t intermediate_size = 0; + float mamba_expand = 0.0f; + bool ssm_rms_norm = false; + bool ssm_conv_bias = false; + bool ssm_proj_bias = false; + uint32_t chunk_size = 0; // for hybrid state space models std::array recurrent_layer_arr; - bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index a2afda7647b00..6855cd8f3ae88 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -32,7 +32,7 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( kv_attn(new llama_kv_cache_unified( model, attn_filter == nullptr ? - [&](int32_t il) { return !model.hparams.recurrent_layer(il); } + [&](int32_t il) { return model.hparams.recurrent_layer(il); } : attn_filter, attn_type_k, attn_type_v, diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index ed3b9031e36d0..d66be5f7f86fd 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -45,27 +45,26 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; - ggml_context * ctx = ggml_init(params); if (!ctx) { + std::printf("Failed to create ggml context for kv cache\n"); return nullptr; } - ctx_map[buft] = ctx; ctxs.emplace_back(ctx); return ctx; } - return it->second; }; - k_l.reserve(n_layer); - v_l.reserve(n_layer); + k_l.resize(n_layer); + v_l.resize(n_layer); for (int i = 0; i < n_layer; i++) { if (filter && !filter(i)) { LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i); + std::printf("Entered here\n"); continue; } diff --git a/src/llama-model.cpp b/src/llama-model.cpp index b1545e1b26b0f..a49bbcf3ba2f1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -474,10 +474,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + // std::fill( + // hparams.recurrent_layer_arr.begin(), + // hparams.recurrent_layer_arr.end(), + // llm_arch_is_recurrent(ml.get_arch())); std::fill( hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), - llm_arch_is_recurrent(ml.get_arch())); + true); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); @@ -1485,6 +1489,53 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_FALCON_H1: + { + // Common parameters + ml.get_key(LLM_KV_VOCAB_SIZE, hparams.vocab_size); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // SSM parameters + ml.get_key(LLM_KV_MAMBA_D_SSM, hparams.ssm_mamba_d_ssm); + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + ml.get_key(LLM_KV_SSM_HEAD_DIM, hparams.ssm_head_dim); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_CHUNK_SIZE, hparams.chunk_size); + + // Falcon-H1 parameters + ml.get_key(LLM_KV_ATTN_HEAD_DIM, hparams.attn_head_dim); + ml.get_key(LLM_KV_FALCON_H1_USE_MLP, hparams.mamba_use_mlp); + ml.get_key(LLM_KV_FALCON_H1_ATTENTION_IN_MULTIPLIER, hparams.attention_in_multiplier); + ml.get_key(LLM_KV_FALCON_H1_ATTENTION_OUT_MULTIPLIER, hparams.attention_out_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_IN_MULTIPLIER, hparams.ssm_in_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_OUT_MULTIPLIER, hparams.ssm_out_multiplier); + ml.get_key(LLM_KV_FALCON_H1_MLP_GATE_MULTIPLIER, hparams.mlp_gate_multiplier); + ml.get_key(LLM_KV_FALCON_H1_MLP_DOWN_MULTIPLIER, hparams.mlp_down_multiplier); + ml.get_key(LLM_KV_FALCON_H1_SSM_HAS_MUP, hparams.ssm_has_mup); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_NORM_BEFORE_GATE, hparams.mamba_norm_before_gate); + ml.get_key(LLM_KV_FALCON_H1_MAMBA_RMS_NORM, hparams.mamba_rms_norm); + ml.get_key(LLM_KV_FALCON_H1_ROPE_THETA, hparams.rope_theta); + ml.get_key(LLM_KV_FALCON_H1_KEY_MULTIPLIER, hparams.key_multiplier); + ml.get_key(LLM_KV_FALCON_H1_LM_HEAD_MULTIPLIER, hparams.lm_head_multiplier); + ml.get_key(LLM_KV_FALCON_H1_EMBEDDING_MULTIPLIER, hparams.embedding_multiplier); + + switch (hparams.n_layer) { + case 36: + type = LLM_TYPE_0_5B; break; + case 24: + type = LLM_TYPE_1_5B; break; + case 66: + type = LLM_TYPE_1B; break; + case 32: + type = LLM_TYPE_3B; break; + case 44: + type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -4212,6 +4263,94 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); } } break; + case LLM_ARCH_FALCON_H1: + { + // Common + const float layer_norm_epsilon = hparams.f_norm_rms_eps; // TODO layer_norm_epsilon + const int64_t hidden_size = hparams.n_embd; // hidden_size + const int64_t vocab_size = hparams.vocab_size; // vocab_size + + // mamba2 Mixer SSM params + const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size + const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups + const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size + const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand + const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads + const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; + const int64_t ssm_head_dim = hparams.ssm_head_dim; // ssm_head_dim + const bool ssm_rms_norm = hparams.mamba_rms_norm; + const int64_t ssm_chunk_size = hparams.chunk_size; + const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; + const int64_t ssm_groups_time_state_size = ssm_n_groups * ssm_state_size; // groups_time_state_size + + // attn params + const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head + const int64_t attn_num_key_value_head = hparams.n_head_kv(0); + const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head; + const int64_t attn_num_key_value_groups = attn_num_attention_head / attn_num_key_value_head; + + // ffn params + const int64_t ffn_intermediate_size = hparams.n_ff(0); + + // embeddings + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, vocab_size}, 0); + + // output + { + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, vocab_size}, TENSOR_NOT_REQUIRED); + final_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + /*SSM LAYERS*/ + // ssm in + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0); + layer.ssm_in_b = create_tensor(tn(LLM_TENSOR_SSM_IN, "bias", i), {n_embd, ssm_projection_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + // ssm 1d conv + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + // ssm_dt + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0); + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0); + if (hparams.ssm_has_mup == true) { + layer.ssm_mup_vec = create_tensor(tn(LLM_TENSOR_SSM_MUP_VEC, i), {2*ssm_intermediate_size + 2*ssm_n_groups*ssm_state_size + ssm_num_heads}, 0); + } + // ssm_norm + if (hparams.mamba_rms_norm == true) { + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, 0); + } + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0); + + /*ATTENTION LAYERS*/ + // attention layers (with optional bias) + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, attn_head_dim * attn_num_attention_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * attn_head_dim}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {attn_head_dim * attn_num_attention_head, hidden_size}, 0); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * attn_head_dim}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0); + + + // feed forward (w/ optional biases) + layer.ffn_pre_norm = create_tensor(tn(LLM_TENSOR_FFN_PRE_NORM, i), {hidden_size}, 0); + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, llama_model_loader::TENSOR_NOT_REQUIRED | (i != 0 ? llama_model_loader::TENSOR_DUPLICATED : 0)); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0); + + layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, llama_model_loader::TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -5248,6 +5387,275 @@ struct llm_build_baichuan : public llm_graph_context { } }; +struct llm_build_falcon_h1 : public llm_graph_context { + const llama_model & model; + + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + inpL = ggml_scale(ctx0, inpL, hparams.embedding_multiplier); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_hybrid = build_attn_inp_kv_hybrid_recurrent(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + //std::printf("Here %d\n", il); + ggml_tensor * attn_out = build_attn(inp_hybrid, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + //std::printf("Here %d - after\n", il); + + cur = ggml_scale(ctx0, cur, hparams.attention_out_multiplier); + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + // Mamba2 layer + // std::printf("Here 2\n"); + // ggml_tensor * ssm_out = build_mamba2_layer(inp_hybrid, gf, cur, ubatch, il); + // std::printf("Here\n"); + // // Aggregation + // cur = ggml_add(ctx0, attn_out, ssm_out); + } + + + + if (il == n_layer - 1) { + // skip computing output for unused tokens + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cur = ggml_scale(ctx0, cur, hparams.lm_head_multiplier); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + + // ggml_tensor * build_mamba2_layer( + // llm_graph_input_attn_kv_hybrid_recurrent * inp, + // ggml_cgraph * gf, + // ggml_tensor * cur, + // const llama_ubatch & ubatch, + // int il) const { + // const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + + // const auto kv_head = kv_state->get_head(); + + // const int64_t d_conv = hparams.ssm_d_conv; + // const int64_t d_inner = hparams.ssm_d_inner; + // const int64_t d_state = hparams.ssm_d_state; + // const int64_t n_head = hparams.ssm_dt_rank; + // const int64_t head_dim = d_inner / n_head; + // const int64_t n_group = hparams.ssm_n_group; + // const int64_t n_seqs = ubatch.n_seqs; + + + + // const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + // GGML_ASSERT(n_seqs != 0); + // GGML_ASSERT(ubatch.equal_seqs); + // GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + + + // ggml_tensor * conv_states_all = kv_state->get_k_l(il); + // ggml_tensor * ssm_states_all = kv_state->get_v_l(il); + + // ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs); + // conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + // cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + // ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + // // split the above in three + // ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + // ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + // ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // // conv + // { + // // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + // ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + // // copy last (d_conv - 1) columns back into the state cache + // ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + // ggml_build_forward_expand(gf, + // ggml_cpy(ctx0, last_conv, + // ggml_view_1d(ctx0, conv_states_all, + // (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + // kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // // 1D convolution + // // The equivalent is to make a self-overlapping view of conv_x + // // over d_conv columns at each stride in the 3rd dimension, + // // then element-wise multiply that with the conv1d weight, + // // then sum the elements of each row, + // // (the last two steps are a dot product over rows (also doable with mul_mat)) + // // then permute away the ne[0] dimension, + // // and then you're left with the resulting x tensor. + // // For simultaneous sequences, all sequences need to have the same length. + // xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // // bias + // xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + // xBC = ggml_silu(ctx0, xBC); + // } + + // // ssm + // { + // // These correspond to V K Q in SSM/attention duality + // ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + // ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + // ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // // {n_head, n_seq_tokens, n_seqs} + // dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + // ggml_tensor * A = model.layers[il].ssm_a; + + // // use the states and the indices provided by build_rs + // // (this is necessary in order to properly use the states before they are overwritten, + // // while avoiding to make unnecessary copies of the states) + // auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + // ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size()); + + // // TODO: use semistructured matrices to implement state-space duality + // // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + // return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + // }; + + // ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_v_s(), ubatch.n_seqs, get_ssm_rows); + + // // store last states + // ggml_build_forward_expand(gf, + // ggml_cpy(ctx0, + // ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + // ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + // ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); + + // // TODO: skip computing output earlier for unused tokens + + // y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + // y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // // grouped RMS norm + // y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + // y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + // y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + + // // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + // cur = build_lora_mm(model.layers[il].ssm_out, y); + // } + + // // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + // cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + // // cb(cur, "mamba_out", il); + + // return cur; + // } +}; + struct llm_build_xverse : public llm_graph_context { llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -13777,6 +14185,10 @@ llm_graph_result_ptr llama_model::build_graph( { llm = std::make_unique(*this, params, gf); } break; + case LLM_ARCH_FALCON_H1: + { + llm = std::make_unique(*this, params, gf); + } break; default: GGML_ABORT("fatal error"); } diff --git a/src/llama-model.h b/src/llama-model.h index b3897dbde1032..60f8ebc412ffa 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -169,6 +169,7 @@ struct llama_layer { struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; struct ggml_tensor * ssm_norm = nullptr; + struct ggml_tensor * final_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; @@ -211,6 +212,7 @@ struct llama_layer { struct ggml_tensor * layer_out_norm_b = nullptr; struct ggml_tensor * ffn_norm_exps = nullptr; struct ggml_tensor * ffn_norm_enc = nullptr; + struct ggml_tensor * ffn_pre_norm = nullptr; // ff struct ggml_tensor * ffn_gate = nullptr; // w1 @@ -254,6 +256,10 @@ struct llama_layer { struct ggml_tensor * ssm_conv1d_b = nullptr; struct ggml_tensor * ssm_dt_b = nullptr; + // falcon-h1 + struct ggml_tensor * ssm_in_b = nullptr; + struct ggml_tensor * ssm_mup_vec = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -344,6 +350,8 @@ struct llama_model { struct ggml_tensor * output = nullptr; struct ggml_tensor * output_b = nullptr; struct ggml_tensor * output_norm_enc = nullptr; + struct ggml_tensor * final_norm = nullptr; + // classifier struct ggml_tensor * cls = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index ba2e1864ec005..3676b774310c1 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1519,6 +1519,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "llama-v3" || tokenizer_pre == "llama-bpe"|| tokenizer_pre == "falcon3" || + tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; From c5f9f36d93b0511591c15cbc0c8f090f2161c459 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:21:29 -0600 Subject: [PATCH 56/92] feat: Add llama_model_is_hybrid API call Also, split llama_model_is_recurrent into llm_arch_is_recurrent in llama-arch with llama_model_is_recurrent delegating to llm_arch_is_recurrent. The same split is done for hybird. This is needed because there are places where the llama_model has not yet been initialized but we need to check if the model is recurrent (specifically for the per-layer recurrent check array in hparams). Branch: GraniteFour Signed-off-by: Gabe Goodhart --- include/llama.h | 3 +++ src/llama-arch.cpp | 22 ++++++++++++++++++++++ src/llama-arch.h | 3 +++ src/llama-model.cpp | 13 +++++-------- 4 files changed, 33 insertions(+), 8 deletions(-) diff --git a/include/llama.h b/include/llama.h index 635508b10f2ff..168059cdc4ea8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -572,6 +572,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index de8d289cf967e..d20f2bf268aba 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1816,3 +1816,25 @@ llm_arch llm_arch_from_string(const std::string & name) { const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor) { return LLM_TENSOR_INFOS.at(tensor); } + +bool llm_arch_is_recurrent(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_MAMBA: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; + } +} + +bool llm_arch_is_hybrid(const llm_arch & arch) { + // TODO: There are currently no hybrid models! Once there are, this will be + // the place to identify them + switch (arch) { + default: + return false; + } +} diff --git a/src/llama-arch.h b/src/llama-arch.h index 3e8a61da3c13e..0c248f72df86f 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -439,3 +439,6 @@ const char * llm_arch_name(llm_arch arch); llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); + +bool llm_arch_is_recurrent(const llm_arch& arch); +bool llm_arch_is_hybrid(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a5eb122f998d8..bac1a07c4ddad 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14377,14 +14377,11 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { } bool llama_model_is_recurrent(const llama_model * model) { - switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; - } + return llm_arch_is_recurrent(model->arch); +} + +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From 665a35769f535cbcac497b682a7d11b231ff4374 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:08:33 -0600 Subject: [PATCH 57/92] feat: Add c++ side constants for attention layer indices hparam Branch: GraniteFour --- src/llama-arch.cpp | 1 + src/llama-arch.h | 1 + 2 files changed, 2 insertions(+) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index d20f2bf268aba..0bc60565df12c 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -147,6 +147,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_LAYER_INDICES, "%s.attention.layer_indices" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, diff --git a/src/llama-arch.h b/src/llama-arch.h index 0c248f72df86f..82b57d2dfb694 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -151,6 +151,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_LAYER_INDICES, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, From 6497b801e6e3c497d29fc388c419e69417ed92cb Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:04:36 -0600 Subject: [PATCH 58/92] feat: Add support for distinguishing recurrent vs non-recurrent layers in hparams Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-hparams.cpp | 14 ++++++++++++-- src/llama-hparams.h | 10 ++++++++-- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 1499eb08a5dd9..70a7114f39715 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,7 +65,10 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s() const { +uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -76,7 +79,10 @@ uint32_t llama_hparams::n_embd_k_s() const { return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s() const { +uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { + if (!recurrent_layer(il)) { + return 0; + } if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; @@ -86,6 +92,10 @@ uint32_t llama_hparams::n_embd_v_s() const { return ssm_d_state * ssm_d_inner; } +bool llama_hparams::recurrent_layer(uint32_t il) const { + return recurrent_layer_arr[il]; +} + bool llama_hparams::is_swa(uint32_t il) const { if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index b2bcb8b01a18b..3614596464318 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -115,6 +115,9 @@ struct llama_hparams { uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + // for hybrid state space models + std::array recurrent_layer_arr; + bool ssm_dt_b_c_rms = false; float f_clamp_kqv = 0.0f; @@ -181,10 +184,13 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s() const; + uint32_t n_embd_k_s(uint32_t il = 0) const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s() const; + uint32_t n_embd_v_s(uint32_t il = 0) const; + + // whether or not the given layer is recurrent (for hybrid models) + bool recurrent_layer(uint32_t il) const; bool is_swa(uint32_t il) const; }; From acd55643535a441dcab63153c23a977024a35881 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 9 May 2025 15:22:18 -0600 Subject: [PATCH 59/92] feat: Auto-fill hparams.recurrent_layer_arr based on whether the model is recurrent Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index bac1a07c4ddad..39cf6631c9851 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -470,6 +470,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { std::fill(hparams.n_head_arr.begin(), hparams.n_head_arr.end(), 0); std::fill(hparams.n_head_kv_arr.begin(), hparams.n_head_kv_arr.end(), 0); std::fill(hparams.n_ff_arr.begin(), hparams.n_ff_arr.end(), 0); + std::fill( + hparams.recurrent_layer_arr.begin(), + hparams.recurrent_layer_arr.end(), + llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); From cc40ce989d725e2d39aa47c2d607c9ba48f0765f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 06:48:53 -0600 Subject: [PATCH 60/92] refactor: rename *_is_hybrid -> *_is_hybrid_recurrent The implementation of the hybrid cache intentionally does not specify the types of the child caches, so there was a naming mismatch with these predicate functions that used "hybrid" to imply "hybrid recurrent." Branch: HybridCache Signed-off-by: Gabe Goodhart --- include/llama.h | 2 +- src/llama-arch.cpp | 2 +- src/llama-arch.h | 2 +- src/llama-model.cpp | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/llama.h b/include/llama.h index 168059cdc4ea8..10f58b278d3ae 100644 --- a/include/llama.h +++ b/include/llama.h @@ -573,7 +573,7 @@ extern "C" { LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) - LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model); // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 0bc60565df12c..cc1e3bebaf1fc 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -1831,7 +1831,7 @@ bool llm_arch_is_recurrent(const llm_arch & arch) { } } -bool llm_arch_is_hybrid(const llm_arch & arch) { +bool llm_arch_is_hybrid_recurrent(const llm_arch & arch) { // TODO: There are currently no hybrid models! Once there are, this will be // the place to identify them switch (arch) { diff --git a/src/llama-arch.h b/src/llama-arch.h index 82b57d2dfb694..4c1c03d6ba8f8 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -442,4 +442,4 @@ llm_arch llm_arch_from_string(const std::string & name); const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); bool llm_arch_is_recurrent(const llm_arch& arch); -bool llm_arch_is_hybrid(const llm_arch& arch); +bool llm_arch_is_hybrid_recurrent(const llm_arch& arch); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 39cf6631c9851..a49a2c09fd1d1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14384,8 +14384,8 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } -bool llama_model_is_hybrid(const llama_model * model) { - return llm_arch_is_hybrid(model->arch); +bool llama_model_is_hybrid_recurrent(const llama_model * model) { + return llm_arch_is_hybrid_recurrent(model->arch); } const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { From f56561bbcdf19522aacffb084c91e42ab8cbc142 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 20 May 2025 13:43:16 -0600 Subject: [PATCH 61/92] feat: Add layer filter to recurrent cache Branch: HybridCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-recurrent.cpp | 22 ++++++++++++++-------- src/llama-kv-cache-recurrent.h | 17 +++++++++++------ src/llama-model.cpp | 1 + 3 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 8f6f120f682b7..917d2a60c9aac 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -16,12 +16,13 @@ // llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -63,6 +64,11 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( v_l.reserve(n_layer); for (int i = 0; i < n_layer; i++) { + if (filter && !filter(i)) { + LLAMA_LOG_DEBUG("%s: layer %3d: skipped\n", __func__, i); + continue; + } + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); @@ -88,8 +94,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); - k_l.push_back(k); - v_l.push_back(v); + k_l[i] = k; + v_l[i] = v; } // allocate tensors and initialize the buffers to avoid NaNs in the padding diff --git a/src/llama-kv-cache-recurrent.h b/src/llama-kv-cache-recurrent.h index f9b01a6513393..b89590b7be6a2 100644 --- a/src/llama-kv-cache-recurrent.h +++ b/src/llama-kv-cache-recurrent.h @@ -15,13 +15,18 @@ // see the implementation of llama_kv_cache_unified_state_i for an example how to do it class llama_kv_cache_recurrent : public llama_memory_i { public: + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + llama_kv_cache_recurrent( - const llama_model & model, - ggml_type type_k, - ggml_type type_v, - bool offload, - uint32_t kv_size, - uint32_t n_seq_max); + const llama_model & model, + layer_filter_cb && filter, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size, + uint32_t n_seq_max); ~llama_kv_cache_recurrent() = default; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a49a2c09fd1d1..c82cdc6278abe 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13759,6 +13759,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = new llama_kv_cache_recurrent( *this, + nullptr, GGML_TYPE_F32, GGML_TYPE_F32, cparams.offload_kqv, From 17abb2bccbd07fc524458d06274ad2e5e13f0055 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 14 May 2025 09:16:06 -0600 Subject: [PATCH 62/92] fix: Use per-layer sizing everywhere in kv caches Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-recurrent.cpp | 16 ++++++++-------- src/llama-kv-cache-unified.cpp | 16 ++++++++-------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 917d2a60c9aac..672f0197d07bf 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); const char * dev_name = "CPU"; @@ -754,7 +754,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; @@ -774,7 +774,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -795,7 +795,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -942,7 +942,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -970,7 +970,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -998,7 +998,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 3b37679859d39..91e093859dcc0 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); const char * dev_name = "CPU"; @@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Write key type const int32_t k_type_i = (int32_t)layer.k->type; @@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); // Read type of key int32_t k_type_i_ref; @@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; @@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); // Read type of value int32_t v_type_i_ref; From b63beacfdc5010837bff414d834b8e1f3459696c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 30 May 2025 09:35:26 -0600 Subject: [PATCH 63/92] feat: First pass at llama_kv_cache_hybrid_recurrent This follows the pattern in iswa where the two child caches are held explicitly to support the case where a model requires a single attention cache and a single recurrent cache where each layer uses exactly one of the caches. This is a rewrite of the more generic approach in the original hybrid cache PR: https://github.com/ggml-org/llama.cpp/pull/13276 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/CMakeLists.txt | 1 + src/llama-kv-cache-hybrid-recurrent.cpp | 241 ++++++++++++++++++++++++ src/llama-kv-cache-hybrid-recurrent.h | 140 ++++++++++++++ 3 files changed, 382 insertions(+) create mode 100644 src/llama-kv-cache-hybrid-recurrent.cpp create mode 100644 src/llama-kv-cache-hybrid-recurrent.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 70be604e4b0d3..e70c8d9a183df 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -23,6 +23,7 @@ add_library(llama llama-kv-cache-unified.cpp llama-kv-cache-unified-iswa.cpp llama-kv-cache-recurrent.cpp + llama-kv-cache-hybrid-recurrent.cpp llama-memory.cpp llama-mmap.cpp llama-model-loader.cpp diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp new file mode 100644 index 0000000000000..bd2323762f37f --- /dev/null +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -0,0 +1,241 @@ +#include "llama-kv-cache-hybrid-recurrent.h" + +#include "llama-impl.h" +#include "llama-model.h" +#include "llama-context.h" + +// +// llama_kv_cache_hybrid_recurrent +// + +llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( + const llama_model & model, + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload) : + hparams(model.hparams), + kv_attn(new llama_kv_cache_unified( + model, + [&](int32_t il) { return !model.hparams.recurrent_layer(il); }, + attn_type_k, + attn_type_v, + attn_v_trans, + offload, + attn_kv_size, + n_seq_max, + attn_n_pad, + attn_n_swa, + attn_swa_type + )), + kv_recurrent(new llama_kv_cache_recurrent( + model, + [&](int32_t il) { return model.hparams.recurrent_layer(il); }, + recurrent_type_k, + recurrent_type_v, + offload, + recurrent_kv_size, + n_seq_max + )) {} + +void llama_kv_cache_hybrid_recurrent::clear() { + kv_attn ->clear(); + kv_recurrent->clear(); +} + +bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { + return false; + } + return kv_attn->seq_rm(seq_id, p0, p1); +} + +void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); +} + +void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { + kv_attn ->seq_keep(seq_id); + kv_recurrent->seq_keep(seq_id); +} + +void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_attn->seq_add(seq_id, p0, p1, shift); + kv_recurrent->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_attn ->seq_div(seq_id, p0, p1, d); + kv_recurrent->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); +} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { + + // since this includes a recurrent cache, we cannot use split_simple + auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + + // follow the recurrent pattern for creating the ubatch splits + std::vector ubatches; + while (sbatch.n_tokens > 0) { + llama_ubatch ubatch; + + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + ubatch = sbatch.split_seq(n_ubatch); + } else { + ubatch = sbatch.split_equal(n_ubatch); + } + + ubatches.push_back(ubatch); + } + + // prepare the recurrent batches first + if (!kv_recurrent->prepare(ubatches)) { + // TODO: will the recurrent cache be in an undefined state at this point? + LLAMA_LOG_ERROR("%s: failed to prepare recurrent ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + // prepare the attention cache + auto heads_attn = kv_attn->prepare(ubatches); + if (heads_attn.empty()) { + LLAMA_LOG_ERROR("%s: failed to prepare attention ubatches\n", __func__); + return std::make_unique(LLAMA_MEMORY_STATUS_FAILED_PREPARE); + } + + return std::make_unique( + this, std::move(sbatch), std::move(heads_attn), std::move(ubatches)); +} + +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() { + return std::make_unique(this); +} + +bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) { + bool res = false; + + res = res | kv_attn ->update(lctx); + res = res | kv_recurrent->update(lctx); + + return res; +} + +void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) { + kv_attn ->defrag_sched(thold); + kv_recurrent->defrag_sched(thold); +} + +bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { + // TODO: Should this return true if the attention cache can shift? + return false; +} + +void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + kv_attn ->state_write(io, seq_id); + kv_recurrent->state_write(io, seq_id); +} + +void llama_kv_cache_hybrid_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + kv_attn ->state_read(io, seq_id); + kv_recurrent->state_read(io, seq_id); +} + +llama_kv_cache_unified * llama_kv_cache_hybrid_recurrent::get_kv_attn() const { + return kv_attn.get(); +} + +llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() const { + return kv_recurrent.get(); +} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) + : status(status), state_attn(status), state_recurrent(status) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + kv(kv), + state_attn(status, kv->get_kv_attn()), + state_recurrent(status, kv->get_kv_recurrent()) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_sbatch sbatch, + std::vector heads_attn, + std::vector ubatches) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + kv(kv), + sbatch(std::move(sbatch)), + heads_attn(std::move(heads_attn)), + ubatches(std::move(ubatches)), + // NOTE: these child states are only used as wrapper APIs for the + // const methods, so we use the "init full" signature since the + // actual state is not used. + state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()), + state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {} + + +bool llama_kv_cache_hybrid_recurrent_state::next() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + if (++i_next >= ubatches.size()) { + return false; + } + + return true; +} + +bool llama_kv_cache_hybrid_recurrent_state::apply() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + kv->get_kv_attn() ->apply_ubatch(heads_attn[i_next], ubatches[i_next]); + kv->get_kv_recurrent()->find_slot(ubatches[i_next]); + + return true; +} + +std::vector & llama_kv_cache_hybrid_recurrent_state::out_ids() { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + + return sbatch.out_ids; +} + +llama_memory_status llama_kv_cache_hybrid_recurrent_state::get_status() const { + return status; +} + +const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { + assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + return ubatches[i_next]; +} + +const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const { + return &state_attn; +} + +const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { + return &state_recurrent; +} diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h new file mode 100644 index 0000000000000..692079b650027 --- /dev/null +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -0,0 +1,140 @@ +#pragma once + +#include "llama-batch.h" +#include "llama-graph.h" +#include "llama-kv-cache.h" +#include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-unified.h" + +#include +#include + +// +// llama_kv_cache_hybrid_recurrent +// + +// utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to +// support models where each layer may be either attention-based or recurrent + +class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { +public: + llama_kv_cache_hybrid_recurrent( + const llama_model & model, + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload); + + ~llama_kv_cache_hybrid_recurrent() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + llama_memory_state_ptr init_batch( + const llama_batch & batch, + uint32_t n_ubatch, + bool embd_pooled, + bool logits_all) override; + + llama_memory_state_ptr init_full() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + bool get_can_shift() const override; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // + // llama_kv_cache_hybrid_recurrent specific API + // + + llama_kv_cache_unified * get_kv_attn () const; + llama_kv_cache_recurrent * get_kv_recurrent() const; + +private: + const llama_hparams & hparams; + + const std::unique_ptr kv_attn; + const std::unique_ptr kv_recurrent; +}; + +class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { +public: + // init failure + explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status); + + // init full + explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv); + + // init success + llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_sbatch sbatch, + std::vector heads_attn, + std::vector ubatches); + + ~llama_kv_cache_hybrid_recurrent_state() = default; + + bool next() override; + bool apply() override; + + std::vector & out_ids() override; + + llama_memory_status get_status() const override; + const llama_ubatch & get_ubatch() const override; + + // + // llama_kv_cache_hybrid_recurrent_state_i + // + + const llama_kv_cache_unified_state * get_state_attn () const; + const llama_kv_cache_recurrent_state * get_state_recurrent() const; + +private: + const llama_memory_status status; + + llama_kv_cache_hybrid_recurrent * kv; + + llama_sbatch sbatch; + + // the index of the next ubatch to process + size_t i_next = 0; + + std::vector heads_attn; + std::vector ubatches; + + const llama_kv_cache_unified_state state_attn; + const llama_kv_cache_recurrent_state state_recurrent; +}; From 4c3734317652f5851fd201a50093eba28faeba1f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 08:57:18 -0600 Subject: [PATCH 64/92] feat: Construct hybrid recurrent cache for hybrid recurrent models This includes a refactor of the create_memory logic to avoid needing to use the arch enum explicitly unless a model needs explicit cache instantiation logic beyond the standard logic for recurrent, hybrid, unified, and iswa. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 110 +++++++++++++++++++++++++------------------- 1 file changed, 63 insertions(+), 47 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c82cdc6278abe..439d542b057e6 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9,6 +9,7 @@ #include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-hybrid-recurrent.h" #include "ggml-cpp.h" @@ -13742,6 +13743,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_i * res; switch (arch) { + // Models that need specific instantiation should be handled in the + // switch statement case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: @@ -13751,58 +13754,71 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, { res = nullptr; } break; - case LLM_ARCH_MAMBA: - case LLM_ARCH_RWKV6: - case LLM_ARCH_RWKV6QWEN2: - case LLM_ARCH_RWKV7: - case LLM_ARCH_ARWKV7: - { - res = new llama_kv_cache_recurrent( - *this, - nullptr, - GGML_TYPE_F32, - GGML_TYPE_F32, - cparams.offload_kqv, - std::max((uint32_t) 1, cparams.n_seq_max), - cparams.n_seq_max); - } break; + // Models that need standard caching should rely on recurrent/hybrid + // checks default: { - const auto padding = llama_kv_cache_unified::get_padding(cparams); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - GGML_ASSERT(hparams.is_swa_any()); - - res = new llama_kv_cache_unified_iswa( - *this, - params.type_k, - params.type_v, - !cparams.flash_attn, - cparams.offload_kqv, - params.swa_full, - cparams.n_ctx, - cparams.n_seq_max, - cparams.n_ubatch, - padding); - } else { - GGML_ASSERT(!hparams.is_swa_any()); - - res = new llama_kv_cache_unified( + if (llm_arch_is_recurrent(arch)) { + res = new llama_kv_cache_recurrent( *this, nullptr, - params.type_k, - params.type_v, - !cparams.flash_attn, + GGML_TYPE_F32, + GGML_TYPE_F32, cparams.offload_kqv, - cparams.n_ctx, - cparams.n_seq_max, - padding, - hparams.n_swa, - hparams.swa_type); + std::max((uint32_t) 1, cparams.n_seq_max), + cparams.n_seq_max); + } else if (llm_arch_is_hybrid_recurrent(arch)) { + res = new llama_kv_cache_hybrid_recurrent( + /* model */ *this, + /* attn_type_k */ params.type_k, + /* attn_type_v */ params.type_v, + /* attn_v_trans */ !cparams.flash_attn, + /* attn_kv_size */ cparams.n_ctx, + /* attn_n_pad */ llama_kv_cache_unified::get_padding(cparams), + /* attn_n_swa */ hparams.n_swa, + /* attn_swa_type */ hparams.swa_type, + /* recurrent_type_k */ GGML_TYPE_F32, + /* recurrent_type_v */ GGML_TYPE_F32, + /* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max), + /* n_seq_max */ cparams.n_seq_max, + /* offload */ cparams.offload_kqv); + } else { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + GGML_ASSERT(hparams.is_swa_any()); + + res = new llama_kv_cache_unified_iswa( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + params.swa_full, + cparams.n_ctx, + cparams.n_seq_max, + cparams.n_ubatch, + padding); + } else { + GGML_ASSERT(!hparams.is_swa_any()); + + res = new llama_kv_cache_unified( + *this, + nullptr, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + cparams.n_seq_max, + padding, + hparams.n_swa, + hparams.swa_type); + } } } } From 0ffbec5c9618c2cd5cb2858dec39a8509562796f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 28 May 2025 11:02:54 -0600 Subject: [PATCH 65/92] fix: Fix wrong bool condition for split equal in hybrid cache Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index bd2323762f37f..beadcee7ba3d1 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -96,7 +96,7 @@ llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) cons llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { // since this includes a recurrent cache, we cannot use split_simple - auto sbatch = llama_sbatch(batch, hparams.n_embd, true, logits_all); + auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); // follow the recurrent pattern for creating the ubatch splits std::vector ubatches; From 936da1a2a86f597407f9cba1d18be4d0bf14aec6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 3 Jun 2025 16:29:40 -0600 Subject: [PATCH 66/92] fix: Fix shift logic to defer to unified cache Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index beadcee7ba3d1..a6468482dae5d 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -150,8 +150,8 @@ void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) { } bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { - // TODO: Should this return true if the attention cache can shift? - return false; + // Shifting is trivially supported for recurrent + return kv_attn->get_can_shift(); } void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { From fb220cc3f5726bb5a29d06d6128aba300e2ac51b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 4 Jun 2025 08:47:55 -0600 Subject: [PATCH 67/92] feat: Support hybrid recurrent in llama-graph NOTE: I intentionally did not add support for s_mask since it will be going away soon Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 52 +++++++++++++++++++++++++++++++++++++++++++-- src/llama-graph.h | 30 ++++++++++++++++++++++++-- 2 files changed, 78 insertions(+), 4 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 337fb5cb0df36..d83e02811b662 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -7,6 +7,7 @@ #include "llama-kv-cache-unified.h" #include "llama-kv-cache-unified-iswa.h" #include "llama-kv-cache-recurrent.h" +#include "llama-kv-cache-hybrid-recurrent.h" #include #include @@ -403,6 +404,13 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { +} + // // llm_graph_context // @@ -961,8 +969,10 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const auto * kv_state = static_cast(mstate); +ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const { + if (kv_state == nullptr) { + kv_state = static_cast(mstate); + } auto inp = std::make_unique(kv_state); @@ -1291,6 +1301,44 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } +llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(hparams, cparams, kv_state); + + { + GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); + + const auto n_kv = kv_state->get_state_attn()->get_n_kv(); + + inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); + //cb(inp->self_kq_mask, "KQ_mask", -1); + ggml_set_input(inp->self_kq_mask); + + inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + } + + return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + return build_attn( + static_cast(inp), + gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il + ); +} + llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { const auto * kv_state = static_cast(mstate); diff --git a/src/llama-graph.h b/src/llama-graph.h index 87813119b1a3c..5abdfde24c87b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -22,6 +22,7 @@ struct llama_memory_state_i; class llama_kv_cache_unified_state; class llama_kv_cache_unified_iswa_state; class llama_kv_cache_recurrent_state; +class llama_kv_cache_hybrid_recurrent_state; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -242,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { cparams(cparams), kv_state(kv_state) { } - ~llm_graph_input_attn_kv_unified() = default; + virtual ~llm_graph_input_attn_kv_unified() = default; void set_input(const llama_ubatch * ubatch) override; @@ -285,6 +286,16 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_kv_cache_unified_iswa_state * kv_state; }; +class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { +public: + llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state); + + virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; +}; + class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -508,7 +519,7 @@ struct llm_graph_context { ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; - ggml_tensor * build_inp_s_copy() const; + ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -574,6 +585,21 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; + llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( From 40e7abda5581de9796583fb1993ca73984b55dc5 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 4 Jun 2025 15:02:14 -0600 Subject: [PATCH 68/92] fix: Fix logic for initializing inputs and attn layers for hybrid caches Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 53 +++++++++++++++------------------------------ src/llama-graph.h | 38 +++++++------------------------- 2 files changed, 25 insertions(+), 66 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index d83e02811b662..bb3e6c17c3395 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -404,13 +404,6 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { } } -llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_kv_cache_hybrid_recurrent_state * kv_state) : - llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { -} - // // llm_graph_context // @@ -1269,7 +1262,9 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - const auto * kv_state = static_cast(mstate); + // NOTE: For hybrid caches, this may be a child of mstate, so we use the one + // encapsulated in inp + const auto * kv_state = inp->kv_state; // store to KV cache { @@ -1301,10 +1296,10 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { +llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { const auto * kv_state = static_cast(mstate); - auto inp = std::make_unique(hparams, cparams, kv_state); + auto inp = std::make_unique(hparams, cparams, kv_state->get_state_attn()); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); @@ -1318,25 +1313,7 @@ llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_ inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } - return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); -} - -ggml_tensor * llm_graph_context::build_attn( - llm_graph_input_attn_kv_hybrid_recurrent * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, - ggml_tensor * k_cur, - ggml_tensor * v_cur, - ggml_tensor * kq_b, - ggml_tensor * v_mla, - float kq_scale, - int il) const { - return build_attn( - static_cast(inp), - gf, wo, wo_b, q_cur, k_cur, v_cur, kq_b, v_mla, kq_scale, il - ); + return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { @@ -1479,13 +1456,17 @@ ggml_tensor * llm_graph_context::build_attn( } ggml_tensor * llm_graph_context::build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies) const { - const auto * kv_state = static_cast(mstate); + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies, + const llama_kv_cache_recurrent_state * kv_state) const { + + if (kv_state == nullptr) { + kv_state = static_cast(mstate); + } const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5abdfde24c87b..5f5846ab7ab84 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -286,16 +286,6 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { const llama_kv_cache_unified_iswa_state * kv_state; }; -class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { -public: - llm_graph_input_attn_kv_hybrid_recurrent( - const llama_hparams & hparams, - const llama_cparams & cparams, - const llama_kv_cache_hybrid_recurrent_state * kv_state); - - virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; -}; - class llm_graph_input_attn_cross : public llm_graph_input_i { public: llm_graph_input_attn_cross(const llama_cross * cross) : cross(cross) {} @@ -585,20 +575,7 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; - - ggml_tensor * build_attn( - llm_graph_input_attn_kv_hybrid_recurrent * inp, - ggml_cgraph * gf, - ggml_tensor * wo, - ggml_tensor * wo_b, - ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] - ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] - ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] - ggml_tensor * kq_b, - ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] - float kq_scale, - int il) const; + llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const; llm_graph_input_attn_cross * build_attn_inp_cross() const; @@ -620,12 +597,13 @@ struct llm_graph_context { // ggml_tensor * build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies = false) const; + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false, + const llama_kv_cache_recurrent_state * kv_state = nullptr) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, From 309d75e2206d03eda8fa6ae5185b0c79f97175f4 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Jun 2025 14:07:07 -0600 Subject: [PATCH 69/92] fix: Update recurrent cache for changes to remove intermediate kv_cache interface Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 129 ++++++++++++------------ src/llama-kv-cache-hybrid-recurrent.h | 50 ++++----- 2 files changed, 93 insertions(+), 86 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index a6468482dae5d..ea228a834397f 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -49,50 +49,6 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( n_seq_max )) {} -void llama_kv_cache_hybrid_recurrent::clear() { - kv_attn ->clear(); - kv_recurrent->clear(); -} - -bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { - // Try removing from the recurrent cache first since it may fail. If it does - // fail, the cache will not have been mutated. - if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { - return false; - } - return kv_attn->seq_rm(seq_id, p0, p1); -} - -void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); - kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); -} - -void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { - kv_attn ->seq_keep(seq_id); - kv_recurrent->seq_keep(seq_id); -} - -void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { - kv_attn->seq_add(seq_id, p0, p1, shift); - kv_recurrent->seq_add(seq_id, p0, p1, shift); -} - -void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { - kv_attn ->seq_div(seq_id, p0, p1, d); - kv_recurrent->seq_div(seq_id, p0, p1, d); -} - -llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { - // the min of the total cache is the max of the two caches' min values - return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); -} - -llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { - // the max of the total cache is the min of the two caches' max values - return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); -} - llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { // since this includes a recurrent cache, we cannot use split_simple @@ -135,23 +91,59 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() { return std::make_unique(this); } -bool llama_kv_cache_hybrid_recurrent::update(llama_context & lctx) { - bool res = false; +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) { + return std::make_unique( + this, + static_cast( kv_attn ->init_update(lctx, optimize).release()), + static_cast(kv_recurrent->init_update(lctx, optimize).release())); +} + +bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { + // Shifting is trivially supported for recurrent + return kv_attn->get_can_shift(); +} +void llama_kv_cache_hybrid_recurrent::clear() { + kv_attn ->clear(); + kv_recurrent->clear(); +} - res = res | kv_attn ->update(lctx); - res = res | kv_recurrent->update(lctx); +bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + // Try removing from the recurrent cache first since it may fail. If it does + // fail, the cache will not have been mutated. + if (!kv_recurrent->seq_rm(seq_id, p0, p1)) { + return false; + } + return kv_attn->seq_rm(seq_id, p0, p1); +} - return res; +void llama_kv_cache_hybrid_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + kv_attn ->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv_recurrent->seq_cp(seq_id_src, seq_id_dst, p0, p1); } -void llama_kv_cache_hybrid_recurrent::defrag_sched(float thold) { - kv_attn ->defrag_sched(thold); - kv_recurrent->defrag_sched(thold); +void llama_kv_cache_hybrid_recurrent::seq_keep(llama_seq_id seq_id) { + kv_attn ->seq_keep(seq_id); + kv_recurrent->seq_keep(seq_id); } -bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { - // Shifting is trivially supported for recurrent - return kv_attn->get_can_shift(); +void llama_kv_cache_hybrid_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + kv_attn->seq_add(seq_id, p0, p1, shift); + kv_recurrent->seq_add(seq_id, p0, p1, shift); +} + +void llama_kv_cache_hybrid_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + kv_attn ->seq_div(seq_id, p0, p1, d); + kv_recurrent->seq_div(seq_id, p0, p1, d); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_min(llama_seq_id seq_id) const { + // the min of the total cache is the max of the two caches' min values + return std::max(kv_attn->seq_pos_min(seq_id), kv_recurrent->seq_pos_min(seq_id)); +} + +llama_pos llama_kv_cache_hybrid_recurrent::seq_pos_max(llama_seq_id seq_id) const { + // the max of the total cache is the min of the two caches' max values + return std::min(kv_attn->seq_pos_max(seq_id), kv_recurrent->seq_pos_max(seq_id)); } void llama_kv_cache_hybrid_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { @@ -173,13 +165,24 @@ llama_kv_cache_recurrent * llama_kv_cache_hybrid_recurrent::get_kv_recurrent() c } llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_memory_status status) - : status(status), state_attn(status), state_recurrent(status) {} + : status(status), + state_attn(new llama_kv_cache_unified_state(status)), + state_recurrent(new llama_kv_cache_recurrent_state(status)) {} llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), - state_attn(status, kv->get_kv_attn()), - state_recurrent(status, kv->get_kv_recurrent()) {} + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), + state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {} + +llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_kv_cache_unified_state * state_unified, + llama_kv_cache_recurrent_state * state_recurrent) + : status(LLAMA_MEMORY_STATUS_SUCCESS), + kv(kv), + state_attn(state_unified), + state_recurrent(state_recurrent) {} llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( llama_kv_cache_hybrid_recurrent * kv, @@ -194,8 +197,8 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( // NOTE: these child states are only used as wrapper APIs for the // const methods, so we use the "init full" signature since the // actual state is not used. - state_attn(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_attn()), - state_recurrent(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent()) {} + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), + state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {} bool llama_kv_cache_hybrid_recurrent_state::next() { @@ -232,10 +235,10 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { return ubatches[i_next]; } -const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn () const { - return &state_attn; +const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const { + return state_attn.get(); } const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { - return &state_recurrent; + return state_recurrent.get(); } diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 692079b650027..e504631e47ae4 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -2,9 +2,10 @@ #include "llama-batch.h" #include "llama-graph.h" -#include "llama-kv-cache.h" #include "llama-kv-cache-recurrent.h" #include "llama-kv-cache-unified.h" +#include "llama-kv-cells.h" +#include "llama-memory.h" #include #include @@ -16,7 +17,7 @@ // utilizes instances of llama_kv_cache_recurrent and llama_kv_cache_unified to // support models where each layer may be either attention-based or recurrent -class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { +class llama_kv_cache_hybrid_recurrent : public llama_memory_i { public: llama_kv_cache_hybrid_recurrent( const llama_model & model, @@ -42,21 +43,6 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { // llama_memory_i // - void clear() override; - - bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; - void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; - void seq_keep(llama_seq_id seq_id) override; - void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; - void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; - - llama_pos seq_pos_min(llama_seq_id seq_id) const override; - llama_pos seq_pos_max(llama_seq_id seq_id) const override; - - // - // llama_kv_cache - // - llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, @@ -65,12 +51,21 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { llama_memory_state_ptr init_full() override; - bool update(llama_context & lctx) override; - - void defrag_sched(float thold) override; + llama_memory_state_ptr init_update(llama_context * lctx, bool optimize) override; bool get_can_shift() const override; + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_min(llama_seq_id seq_id) const override; + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; @@ -92,12 +87,21 @@ class llama_kv_cache_hybrid_recurrent : public llama_kv_cache { class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { public: + using llama_kv_cache_unified_state_ptr = std::unique_ptr; + using llama_kv_cache_recurrent_state_ptr = std::unique_ptr; + // init failure explicit llama_kv_cache_hybrid_recurrent_state(llama_memory_status status); // init full explicit llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv); + // init update + explicit llama_kv_cache_hybrid_recurrent_state( + llama_kv_cache_hybrid_recurrent * kv, + llama_kv_cache_unified_state * state_unified, + llama_kv_cache_recurrent_state * state_recurrent); + // init success llama_kv_cache_hybrid_recurrent_state( llama_kv_cache_hybrid_recurrent * kv, @@ -116,7 +120,7 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { const llama_ubatch & get_ubatch() const override; // - // llama_kv_cache_hybrid_recurrent_state_i + // llama_kv_cache_hybrid_recurrent_state // const llama_kv_cache_unified_state * get_state_attn () const; @@ -135,6 +139,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { std::vector heads_attn; std::vector ubatches; - const llama_kv_cache_unified_state state_attn; - const llama_kv_cache_recurrent_state state_recurrent; + const llama_kv_cache_unified_state_ptr state_attn; + const llama_kv_cache_recurrent_state_ptr state_recurrent; }; From eaaf4a9c4616cc5a1c41015f73baa66766fc3c01 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Jun 2025 14:41:08 -0600 Subject: [PATCH 70/92] fix: Fix status for init_update sig for recurrent cache state Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index ea228a834397f..d269a6b5057af 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -179,7 +179,7 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( llama_kv_cache_hybrid_recurrent * kv, llama_kv_cache_unified_state * state_unified, llama_kv_cache_recurrent_state * state_recurrent) - : status(LLAMA_MEMORY_STATUS_SUCCESS), + : status(LLAMA_MEMORY_STATUS_NO_UPDATE), kv(kv), state_attn(state_unified), state_recurrent(state_recurrent) {} From 581113face12a4f5be27e5266a486d598a92c3a6 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 5 Jun 2025 15:54:50 -0600 Subject: [PATCH 71/92] fix: Add missing padding to n_ctx for hybrid cache construction Branch: GraniteFour Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 439d542b057e6..6e9dd532237b1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -13768,13 +13768,17 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, std::max((uint32_t) 1, cparams.n_seq_max), cparams.n_seq_max); } else if (llm_arch_is_hybrid_recurrent(arch)) { + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + res = new llama_kv_cache_hybrid_recurrent( /* model */ *this, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_kv_size */ cparams.n_ctx, - /* attn_n_pad */ llama_kv_cache_unified::get_padding(cparams), + /* attn_n_pad */ padding, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, /* recurrent_type_k */ GGML_TYPE_F32, From 84e300f3a6542755f665bd1b34cd6dc44097d6c9 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Fri, 6 Jun 2025 09:38:10 -0600 Subject: [PATCH 72/92] fix: Update clear signature for data argument after rebase Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 7 ++++--- src/llama-kv-cache-hybrid-recurrent.h | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index d269a6b5057af..8871dbf631164 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -102,9 +102,10 @@ bool llama_kv_cache_hybrid_recurrent::get_can_shift() const { // Shifting is trivially supported for recurrent return kv_attn->get_can_shift(); } -void llama_kv_cache_hybrid_recurrent::clear() { - kv_attn ->clear(); - kv_recurrent->clear(); + +void llama_kv_cache_hybrid_recurrent::clear(bool data) { + kv_attn ->clear(data); + kv_recurrent->clear(data); } bool llama_kv_cache_hybrid_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index e504631e47ae4..8728fd733cc28 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -55,7 +55,7 @@ class llama_kv_cache_hybrid_recurrent : public llama_memory_i { bool get_can_shift() const override; - void clear() override; + void clear(bool data) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; From ae7a02e659a1f1bcebf45acaefd4d4a7dd06382c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Jun 2025 16:26:31 -0600 Subject: [PATCH 73/92] fix: Remove errant virtual destructor leftover from previous impl attempt Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-graph.h b/src/llama-graph.h index 5f5846ab7ab84..329461a1d06d4 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -243,7 +243,7 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { cparams(cparams), kv_state(kv_state) { } - virtual ~llm_graph_input_attn_kv_unified() = default; + ~llm_graph_input_attn_kv_unified() = default; void set_input(const llama_ubatch * ubatch) override; From 1e49029e74a6d8427d82ad1dc31df3ecce6293f2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 10 Jun 2025 16:30:49 -0600 Subject: [PATCH 74/92] fix: Use per-layer n_embd_k/v_s calls for mamba (1) layers Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-model.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6e9dd532237b1..34643226e5a9e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9194,11 +9194,11 @@ struct llm_build_mamba : public llm_graph_context { // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state( gf, conv_states_all, state_copy, - hparams.n_embd_k_s(), n_seqs); + hparams.n_embd_k_s(il), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); ggml_tensor * ssm = build_recurrent_state( gf, ssm_states_all, state_copy, - hparams.n_embd_v_s(), n_seqs); + hparams.n_embd_v_s(il), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} From ddc85f0f524c3b5f3b73cfeddaec091e312dbf42 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 12:20:04 -0600 Subject: [PATCH 75/92] refactor: Remove n_embd_k/v_s from unified cache No longer needed now that unified isn't also supporting recurrent https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140761069 Branch: HybridRecurrentCache --- src/llama-kv-cache-unified.cpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/llama-kv-cache-unified.cpp b/src/llama-kv-cache-unified.cpp index 91e093859dcc0..d4412288925c3 100644 --- a/src/llama-kv-cache-unified.cpp +++ b/src/llama-kv-cache-unified.cpp @@ -68,8 +68,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); const char * dev_name = "CPU"; @@ -1430,7 +1430,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Write key type const int32_t k_type_i = (int32_t)layer.k->type; @@ -1452,7 +1452,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1476,7 +1476,7 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Write value type const int32_t v_type_i = (int32_t)layer.v->type; @@ -1621,7 +1621,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); // Read type of key int32_t k_type_i_ref; @@ -1651,7 +1651,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; @@ -1681,7 +1681,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell for (const auto & layer : layers) { const uint32_t il = layer.il; - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); // Read type of value int32_t v_type_i_ref; From 368c9f46e5244b5457006f3d9fe1b469ff72f0cc Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 12:20:47 -0600 Subject: [PATCH 76/92] refactor: Remove layer index from n_embd_k/v_s Now that it's not used at all in the unified cache, we don't need to use the layer index to zero it out for attention layers. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-hparams.cpp | 10 ++-------- src/llama-hparams.h | 4 ++-- src/llama-kv-cache-recurrent.cpp | 16 ++++++++-------- src/llama-model.cpp | 4 ++-- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 70a7114f39715..0ec3a8a501f76 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -65,10 +65,7 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } -uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { - if (!recurrent_layer(il)) { - return 0; - } +uint32_t llama_hparams::n_embd_k_s() const { if (wkv_head_size != 0) { // for RWKV models return token_shift_count * n_embd; @@ -79,10 +76,7 @@ uint32_t llama_hparams::n_embd_k_s(uint32_t il) const { return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; } -uint32_t llama_hparams::n_embd_v_s(uint32_t il) const { - if (!recurrent_layer(il)) { - return 0; - } +uint32_t llama_hparams::n_embd_v_s() const { if (wkv_head_size != 0) { // corresponds to RWKV's wkv_states size return n_embd * wkv_head_size; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 3614596464318..84234494c5611 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -184,10 +184,10 @@ struct llama_hparams { // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size - uint32_t n_embd_k_s(uint32_t il = 0) const; + uint32_t n_embd_k_s() const; // dimension of the recurrent state embeddings - uint32_t n_embd_v_s(uint32_t il = 0) const; + uint32_t n_embd_v_s() const; // whether or not the given layer is recurrent (for hybrid models) bool recurrent_layer(uint32_t il) const; diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 672f0197d07bf..917d2a60c9aac 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -69,8 +69,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(i); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(i); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); const char * dev_name = "CPU"; @@ -754,7 +754,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; @@ -774,7 +774,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -795,7 +795,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -942,7 +942,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(il); + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Read type of key int32_t k_type_i_ref; @@ -970,7 +970,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -998,7 +998,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(il); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 34643226e5a9e..6e9dd532237b1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9194,11 +9194,11 @@ struct llm_build_mamba : public llm_graph_context { // (ab)using the KV cache to store the states ggml_tensor * conv = build_recurrent_state( gf, conv_states_all, state_copy, - hparams.n_embd_k_s(il), n_seqs); + hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); ggml_tensor * ssm = build_recurrent_state( gf, ssm_states_all, state_copy, - hparams.n_embd_v_s(il), n_seqs); + hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} From b82a2255dbc06408aeb107242d390f61f786b259 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 12:56:26 -0600 Subject: [PATCH 77/92] refactor: Remove n_embd_k/v_gqa from recurrent cache This is no longer needed now that there are separate implementations https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140825128 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-recurrent.cpp | 39 +++++++++++++------------------- 1 file changed, 16 insertions(+), 23 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index 917d2a60c9aac..be19edd316542 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -69,9 +69,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - const char * dev_name = "CPU"; ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); @@ -90,8 +87,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, hparams.n_embd_k_s()*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, hparams.n_embd_v_s()*kv_size); ggml_format_name(k, "cache_k_l%d", i); ggml_format_name(v, "cache_v_l%d", i); k_l[i] = k; @@ -754,14 +751,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Write key type const int32_t k_type_i = (int32_t)k_l[il]->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out @@ -774,14 +770,13 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out @@ -795,7 +790,7 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t kv_size = size; for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_s = hparams.n_embd_v_s(); // Write value type const int32_t v_type_i = (int32_t)v_l[il]->type; @@ -806,10 +801,10 @@ void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size - io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + io.write(&n_embd_v_s, sizeof(n_embd_v_s)); // For each row, we get the element values of each cell - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + for (uint32_t j = 0; j < n_embd_v_s; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out for (const auto & range : cell_ranges) { const size_t range_size = range.second - range.first; @@ -942,7 +937,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); // Read type of key int32_t k_type_i_ref; @@ -956,7 +950,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(k_l[il]->type, hparams.n_embd_k_s()); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; @@ -970,7 +964,6 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce if (!v_trans) { for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -984,7 +977,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(v_l[il]->type, hparams.n_embd_v_s()); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -998,7 +991,7 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + const uint32_t n_embd_v_s = hparams.n_embd_v_s(); // Read type of value int32_t v_type_i_ref; @@ -1018,17 +1011,17 @@ bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t ce return false; } - // Read GQA embedding size - uint32_t n_embd_v_gqa_ref; - io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); - if (n_embd_v_gqa != n_embd_v_gqa_ref) { - LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + // Read state embedding size + uint32_t n_embd_v_s_ref; + io.read_to(&n_embd_v_s_ref, sizeof(n_embd_v_s_ref)); + if (n_embd_v_s != n_embd_v_s_ref) { + LLAMA_LOG_ERROR("%s: mismatched state embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_s, n_embd_v_s_ref, il); return false; } if (cell_count) { // For each row in the transposed matrix, read the values for the whole cell range - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + for (uint32_t j = 0; j < n_embd_v_s; ++j) { const size_t dst_offset = (head + j * size) * v_size_el; ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } From 8b7f687d1eecb98647db01b8560a60145dc4866d Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 11 Jun 2025 13:41:52 -0600 Subject: [PATCH 78/92] feat: Allow custom layer filters for hybrid recurrent This should help support architectures like Falcon H1 where there is overlap between layers that need attention and recurrent caches. https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2140748922 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 41 +++++++++++++++---------- src/llama-kv-cache-hybrid-recurrent.h | 37 +++++++++++++--------- 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index 8871dbf631164..889f43025a3e3 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -10,25 +10,30 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( const llama_model & model, - /* attn */ - ggml_type attn_type_k, - ggml_type attn_type_v, - bool attn_v_trans, - uint32_t attn_kv_size, - uint32_t attn_n_pad, - uint32_t attn_n_swa, - llama_swa_type attn_swa_type, - /* recurrent */ - ggml_type recurrent_type_k, - ggml_type recurrent_type_v, - uint32_t recurrent_kv_size, - /* common */ - uint32_t n_seq_max, - bool offload) : + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload, + /* layer filters */ + layer_filter_cb && attn_filter, + layer_filter_cb && recurrent_filter) : hparams(model.hparams), kv_attn(new llama_kv_cache_unified( model, - [&](int32_t il) { return !model.hparams.recurrent_layer(il); }, + attn_filter == nullptr ? + [&](int32_t il) { return !model.hparams.recurrent_layer(il); } + : attn_filter, attn_type_k, attn_type_v, attn_v_trans, @@ -41,7 +46,9 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( )), kv_recurrent(new llama_kv_cache_recurrent( model, - [&](int32_t il) { return model.hparams.recurrent_layer(il); }, + recurrent_filter == nullptr ? + [&](int32_t il) { return model.hparams.recurrent_layer(il); } + : recurrent_filter, recurrent_type_k, recurrent_type_v, offload, diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 8728fd733cc28..444e87e101136 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -19,23 +19,30 @@ class llama_kv_cache_hybrid_recurrent : public llama_memory_i { public: + + // this callback is used to filter out layers that should not be included in the cache + using layer_filter_cb = std::function; + llama_kv_cache_hybrid_recurrent( const llama_model & model, - /* attn */ - ggml_type attn_type_k, - ggml_type attn_type_v, - bool attn_v_trans, - uint32_t attn_kv_size, - uint32_t attn_n_pad, - uint32_t attn_n_swa, - llama_swa_type attn_swa_type, - /* recurrent */ - ggml_type recurrent_type_k, - ggml_type recurrent_type_v, - uint32_t recurrent_kv_size, - /* common */ - uint32_t n_seq_max, - bool offload); + /* attn */ + ggml_type attn_type_k, + ggml_type attn_type_v, + bool attn_v_trans, + uint32_t attn_kv_size, + uint32_t attn_n_pad, + uint32_t attn_n_swa, + llama_swa_type attn_swa_type, + /* recurrent */ + ggml_type recurrent_type_k, + ggml_type recurrent_type_v, + uint32_t recurrent_kv_size, + /* common */ + uint32_t n_seq_max, + bool offload, + /* layer filters */ + layer_filter_cb && attn_filter = nullptr, + layer_filter_cb && recurrent_filter = nullptr); ~llama_kv_cache_hybrid_recurrent() = default; From 3ee22224493725e36456be8c22d9d8eaf70fa29f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 14:00:53 -0600 Subject: [PATCH 79/92] fix: Remove logits_all after rebase Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 4 ++-- src/llama-kv-cache-hybrid-recurrent.h | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index 889f43025a3e3..49a7c35ab8cfa 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -56,10 +56,10 @@ llama_kv_cache_hybrid_recurrent::llama_kv_cache_hybrid_recurrent( n_seq_max )) {} -llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled, bool logits_all) { +llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_batch(const llama_batch & batch, uint32_t n_ubatch, bool embd_pooled) { // since this includes a recurrent cache, we cannot use split_simple - auto sbatch = llama_sbatch(batch, hparams.n_embd, false, logits_all); + auto sbatch = llama_sbatch(batch, hparams.n_embd, false); // follow the recurrent pattern for creating the ubatch splits std::vector ubatches; diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 444e87e101136..d6678eb2164fa 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -53,8 +53,7 @@ class llama_kv_cache_hybrid_recurrent : public llama_memory_i { llama_memory_state_ptr init_batch( const llama_batch & batch, uint32_t n_ubatch, - bool embd_pooled, - bool logits_all) override; + bool embd_pooled) override; llama_memory_state_ptr init_full() override; From bd37fc89905665c1e6a88355b48340c074d16f7f Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 14:01:28 -0600 Subject: [PATCH 80/92] fix: Remove llama_model_is_hybrid_Recurrent public API https://github.com/ggml-org/llama.cpp/pull/13979#discussion_r2141728423 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- include/llama.h | 3 --- src/llama-model.cpp | 4 ---- 2 files changed, 7 deletions(-) diff --git a/include/llama.h b/include/llama.h index 10f58b278d3ae..635508b10f2ff 100644 --- a/include/llama.h +++ b/include/llama.h @@ -572,9 +572,6 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); - // Returns true if the model is hybrid-recurrent (like Jamba, Bamba, etc.) - LLAMA_API bool llama_model_is_hybrid_recurrent(const struct llama_model * model); - // Returns 0 on success LLAMA_API uint32_t llama_model_quantize( const char * fname_inp, diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 6e9dd532237b1..971d1df199d53 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -14405,10 +14405,6 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } -bool llama_model_is_hybrid_recurrent(const llama_model * model) { - return llm_arch_is_hybrid_recurrent(model->arch); -} - const std::vector> & llama_internal_get_tensor_map(const llama_model * model) { return model->tensors_by_name; } From 2e5e45c9ecb14c5b6a42f6a0f61802985f0964d2 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 14:30:21 -0600 Subject: [PATCH 81/92] refactor: Use llama_memory_state_ptr for child states in hybrid memory state Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 4 ++-- src/llama-kv-cache-hybrid-recurrent.h | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index 49a7c35ab8cfa..a2afda7647b00 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -244,9 +244,9 @@ const llama_ubatch & llama_kv_cache_hybrid_recurrent_state::get_ubatch() const { } const llama_kv_cache_unified_state * llama_kv_cache_hybrid_recurrent_state::get_state_attn() const { - return state_attn.get(); + return static_cast(state_attn.get()); } const llama_kv_cache_recurrent_state * llama_kv_cache_hybrid_recurrent_state::get_state_recurrent() const { - return state_recurrent.get(); + return static_cast(state_recurrent.get()); } diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index d6678eb2164fa..93bf72ec34837 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -145,6 +145,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { std::vector heads_attn; std::vector ubatches; - const llama_kv_cache_unified_state_ptr state_attn; - const llama_kv_cache_recurrent_state_ptr state_recurrent; + const llama_memory_state_ptr state_attn; + const llama_memory_state_ptr state_recurrent; }; From a42b9cb8f992281f68e46f1bfcdfad323a76ea91 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Thu, 12 Jun 2025 17:04:27 -0600 Subject: [PATCH 82/92] feat: Overhaul build_recurrent_state / build_inp_s_copy to match attention pattern https://github.com/ggml-org/llama.cpp/pull/13979/files#r2141701738 This is a big overhaul to bring consistency between how inputs and per- layer components are created for attention layers and recurrent layers. The main changes are: - Rename class llm_graph_input_s_copy -> llm_graph_input_rs - Add a corresponding llm_graph_input_rs_hybrid_recurrent - Rename build_inp_s_copy -> build_rs_inp_recurrent - Add a corresponding build_rs_inp_hybrid_recurrent - Rename build_recurrent_state -> build_rs to match build_attn w/ llm_graph_input_rs android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input - Add a corresponding overload of build_rs w/ llm_graph_input_rs_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input - Add a llm_graph_input_attn_kv_hybrid_recurrent analogous to llm_graph_input_attn_kv_unified - Add a build_attn override that takes llm_graph_input_attn_kv_hybrid_recurrent android-build AUTHORS bamba-9b-2.2T.gguf bamba-9b-2.2T.q4_k_m.gguf broken.log build build-rel build-xcframework.sh build.android build.android.bak ci cmake CMakeLists.txt CMakePresets.json CODEOWNERS common common.o CONTRIBUTING.md convert_hf_to_gguf_update.py convert_hf_to_gguf.py convert_llama_ggml_to_gguf.py convert_lora_to_gguf.py debug.log docs examples flake.lock flake.nix ggml ggml-alloc.o ggml-backend.o ggml-metal.o ggml-model-BF16.gguf ggml-model-Q4_K_M.gguf ggml-quants.o ggml.o gguf-py grammar-parser.o grammars include LICENSE licenses llama.log llama.o llamacpp_trace.log main.log Makefile media models mypy.ini pocs poetry.lock prompts pyproject.toml pyrightconfig.json q4_k_m_boot.log q8_0_boot.log quant.log quant2.log README.md requirements requirements.txt sampling.o scripts SECURITY.md src test-grammar-output.tmp test-json-schema-input.tmp tests tools vendor working.log as the first input This makes the two paradigms fully consistent. The main drawback is the code duplication in the build_attn and build_rs implementations where the only difference between implementations is how they cast the memory state. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 203 ++++++++++++++++++++++++++++++++++---------- src/llama-graph.h | 71 ++++++++++++---- src/llama-model.cpp | 66 +++++++------- 3 files changed, 240 insertions(+), 100 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index bb3e6c17c3395..039718c04e401 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -239,7 +239,7 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) { } } -void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { +void llm_graph_input_rs::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); const int64_t n_kv = kv_state->get_n_kv(); @@ -255,6 +255,11 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_rs_hybrid_recurrent::llm_graph_input_rs_hybrid_recurrent( + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_rs(kv_state->get_state_recurrent()) { +} + void llm_graph_input_cross_embd::set_input(const llama_ubatch * ubatch) { GGML_UNUSED(ubatch); @@ -360,6 +365,13 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { } } +llm_graph_input_attn_kv_hybrid_recurrent::llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state) : + llm_graph_input_attn_kv_unified(hparams, cparams, kv_state->get_state_attn()) { +} + void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { if (self_kq_mask) { kv_state->get_base()->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); @@ -962,25 +974,6 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state) const { - if (kv_state == nullptr) { - kv_state = static_cast(mstate); - } - - auto inp = std::make_unique(kv_state); - - const auto n_kv = kv_state->get_n_kv(); - - auto & cur = inp->s_copy; - - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); - ggml_set_input(cur); - - res->add_input(std::move(inp)); - - return cur; -} - ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); @@ -1262,9 +1255,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_build_forward_expand(gf, k_cur); ggml_build_forward_expand(gf, v_cur); - // NOTE: For hybrid caches, this may be a child of mstate, so we use the one - // encapsulated in inp - const auto * kv_state = inp->kv_state; + const auto * kv_state = static_cast(mstate); // store to KV cache { @@ -1296,15 +1287,14 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { - const auto * kv_state = static_cast(mstate); - - auto inp = std::make_unique(hparams, cparams, kv_state->get_state_attn()); +llm_graph_input_attn_kv_hybrid_recurrent * llm_graph_context::build_attn_inp_kv_hybrid_recurrent() const { + auto inp = std::make_unique( + hparams, cparams, static_cast(mstate)); { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Hybrid recurrent is not supported with SWA attention layers"); - const auto n_kv = kv_state->get_state_attn()->get_n_kv(); + const auto n_kv = inp->kv_state->get_n_kv(); inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); //cb(inp->self_kq_mask, "KQ_mask", -1); @@ -1313,7 +1303,57 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_hybrid_re inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; } - return (llm_graph_input_attn_kv_unified *) res->add_input(std::move(inp)); + return (llm_graph_input_attn_kv_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, + ggml_tensor * k_cur, + ggml_tensor * v_cur, + ggml_tensor * kq_b, + ggml_tensor * v_mla, + float kq_scale, + int il) const { + // these nodes are added to the graph together so that they are not reordered + // by doing so, the number of splits in the graph is reduced + ggml_build_forward_expand(gf, q_cur); + ggml_build_forward_expand(gf, k_cur); + ggml_build_forward_expand(gf, v_cur); + + const auto * kv_state = static_cast(mstate)->get_state_attn(); + + // store to KV cache + { + ggml_build_forward_expand(gf, kv_state->cpy_k(ctx0, k_cur, il)); + ggml_build_forward_expand(gf, kv_state->cpy_v(ctx0, v_cur, il)); + } + + const auto & kq_mask = inp->get_kq_mask(); + + ggml_tensor * q = q_cur; + ggml_tensor * k = kv_state->get_k(ctx0, il); + ggml_tensor * v = kv_state->get_v(ctx0, il); + + ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + cb(cur, "kqv_out", il); + + if (wo) { + cur = build_lora_mm(wo, cur); + if (arch == LLM_ARCH_GLM4) { + // GLM4 seems to have numerical issues with half-precision accumulators + ggml_mul_mat_set_prec(cur, GGML_PREC_F32); + } + } + + if (wo_b) { + cur = ggml_add(ctx0, cur, wo_b); + } + + return cur; } llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unified_iswa() const { @@ -1455,19 +1495,90 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies, - const llama_kv_cache_recurrent_state * kv_state) const { +llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const { + const auto * kv_state = static_cast(mstate); + + auto inp = std::make_unique(kv_state); + + const auto n_kv = kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + return (llm_graph_input_rs *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { + + const auto * kv_state = static_cast(mstate); + + const auto n_kv = kv_state->get_n_kv(); + const auto kv_head = kv_state->get_head(); + const auto rs_zero = kv_state->get_rs_z(); + + ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); + + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); + + ggml_tensor * output_states; - if (kv_state == nullptr) { - kv_state = static_cast(mstate); + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // {state_size, kv_size} -> {state_size, n_seqs} + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); + ggml_build_forward_expand(gf, output_states); + } else { + // FIXME: make the gathering operation happen before the copy below + // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) + output_states = states; } + // copy extra states which won't be changed further (between n_seqs and n_kv) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + states_extra, + ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); + + return output_states; +} + +llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const { + auto inp = std::make_unique( + static_cast(mstate)); + + const auto n_kv = inp->kv_state->get_n_kv(); + + auto & cur = inp->s_copy; + + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); + + return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp)); +} + +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { + + const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); const auto rs_zero = kv_state->get_rs_z(); @@ -1485,7 +1596,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state( // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); ggml_build_forward_expand(gf, output_states); } else { // FIXME: make the gathering operation happen before the copy below @@ -1494,7 +1605,7 @@ ggml_tensor * llm_graph_context::build_recurrent_state( } // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, @@ -1504,9 +1615,9 @@ ggml_tensor * llm_graph_context::build_recurrent_state( } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( - ggml_cgraph * gf, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -1516,8 +1627,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = kv_state->get_k_l(il); - ggml_tensor * token_shift = build_recurrent_state( - gf, token_shift_all, state_copy, + ggml_tensor * token_shift = build_rs( + inp, gf, token_shift_all, hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); diff --git a/src/llama-graph.h b/src/llama-graph.h index 329461a1d06d4..77f19a673c3dd 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -189,10 +189,10 @@ class llm_graph_input_cls : public llm_graph_input_i { const llama_cparams & cparams; }; -class llm_graph_input_s_copy : public llm_graph_input_i { +class llm_graph_input_rs : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} - virtual ~llm_graph_input_s_copy() = default; + llm_graph_input_rs(const llama_kv_cache_recurrent_state * kv_state) : kv_state(kv_state) {} + virtual ~llm_graph_input_rs() = default; void set_input(const llama_ubatch * ubatch) override; @@ -201,6 +201,12 @@ class llm_graph_input_s_copy : public llm_graph_input_i { const llama_kv_cache_recurrent_state * kv_state; }; +class llm_graph_input_rs_hybrid_recurrent : public llm_graph_input_rs { +public: + llm_graph_input_rs_hybrid_recurrent(const llama_kv_cache_hybrid_recurrent_state * kv_state); + virtual ~llm_graph_input_rs_hybrid_recurrent() = default; +}; + class llm_graph_input_cross_embd : public llm_graph_input_i { public: llm_graph_input_cross_embd( @@ -258,6 +264,15 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { const llama_kv_cache_unified_state * kv_state; }; +class llm_graph_input_attn_kv_hybrid_recurrent : public llm_graph_input_attn_kv_unified { +public: + llm_graph_input_attn_kv_hybrid_recurrent( + const llama_hparams & hparams, + const llama_cparams & cparams, + const llama_kv_cache_hybrid_recurrent_state * kv_state); + virtual ~llm_graph_input_attn_kv_hybrid_recurrent() = default; +}; + class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { public: llm_graph_input_attn_kv_unified_iswa( @@ -509,7 +524,6 @@ struct llm_graph_context { ggml_tensor * build_inp_out_ids() const; ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; - ggml_tensor * build_inp_s_copy(const llama_kv_cache_recurrent_state * kv_state = nullptr) const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -575,8 +589,6 @@ struct llm_graph_context { float kq_scale, int il) const; - llm_graph_input_attn_kv_unified * build_attn_inp_kv_hybrid_recurrent() const; - llm_graph_input_attn_cross * build_attn_inp_cross() const; ggml_tensor * build_attn( @@ -592,23 +604,48 @@ struct llm_graph_context { float kq_scale, int il) const; + llm_graph_input_attn_kv_hybrid_recurrent * build_attn_inp_kv_hybrid_recurrent() const; + + ggml_tensor * build_attn( + llm_graph_input_attn_kv_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * wo, + ggml_tensor * wo_b, + ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] + ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens] + ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens] + ggml_tensor * kq_b, + ggml_tensor * v_mla, // [n_embd_head_v_mla, n_embd_head_v, n_head_v] + float kq_scale, + int il) const; // // recurrent // - ggml_tensor * build_recurrent_state( - ggml_cgraph * gf, - ggml_tensor * s, - ggml_tensor * state_copy, - int32_t state_size, - int32_t n_seqs, - bool avoid_copies = false, - const llama_kv_cache_recurrent_state * kv_state = nullptr) const; + llm_graph_input_rs * build_rs_inp_recurrent() const; + + ggml_tensor * build_rs( + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; + + llm_graph_input_rs_hybrid_recurrent * build_rs_inp_hybrid_recurrent() const; + + ggml_tensor * build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; ggml_tensor * build_rwkv_token_shift_load( - ggml_cgraph * gf, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, + llm_graph_input_rs * inp, + ggml_cgraph * gf, + const llama_ubatch & ubatch, int il) const; ggml_tensor * build_rwkv_token_shift_store( diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 971d1df199d53..0be0839793dc1 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -9116,7 +9116,7 @@ struct llm_build_mamba : public llm_graph_context { // {n_embd, n_tokens} inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); for (int il = 0; il < n_layer; ++il) { // norm @@ -9125,7 +9125,7 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); + cur = build_mamba_layer(rs_inp, gf, cur, ubatch, il); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -9163,11 +9163,11 @@ struct llm_build_mamba : public llm_graph_context { // TODO: split ggml_tensor * build_mamba_layer( - ggml_cgraph * gf, - ggml_tensor * cur, - ggml_tensor * state_copy, - const llama_ubatch & ubatch, - int il) const { + llm_graph_input_rs * inp, + ggml_cgraph * gf, + ggml_tensor * cur, + const llama_ubatch & ubatch, + int il) const { const auto * kv_state = static_cast(mstate); const auto kv_head = kv_state->get_head(); @@ -9192,12 +9192,12 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * ssm_states_all = kv_state->get_v_l(il); // (ab)using the KV cache to store the states - ggml_tensor * conv = build_recurrent_state( - gf, conv_states_all, state_copy, + ggml_tensor * conv = build_rs( + inp, gf, conv_states_all, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_recurrent_state( - gf, ssm_states_all, state_copy, + ggml_tensor * ssm = build_rs( + inp, gf, ssm_states_all, hparams.n_embd_v_s(), n_seqs); ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); @@ -11909,10 +11909,10 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * build_rwkv6_time_mix( + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, - ggml_tensor * state_copy, const llama_ubatch & ubatch, int il) const { const auto * kv_state = static_cast(mstate); @@ -12036,8 +12036,8 @@ struct llm_build_rwkv6_base : public llm_graph_context { k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); } - ggml_tensor * wkv_state = build_recurrent_state( - gf, kv_state->get_v_l(il), state_copy, + ggml_tensor * wkv_state = build_rs( + inp, gf, kv_state->get_v_l(il), hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -12092,7 +12092,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12102,9 +12102,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -12119,7 +12117,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12189,7 +12187,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12199,9 +12197,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -12213,7 +12209,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -12301,10 +12297,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { } ggml_tensor * build_rwkv7_time_mix( + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, - ggml_tensor * state_copy, ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { @@ -12387,8 +12383,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); - ggml_tensor * wkv_state = build_recurrent_state( - gf, kv_state->get_v_l(il), state_copy, + ggml_tensor * wkv_state = build_rs( + inp, gf, kv_state->get_v_l(il), hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12445,7 +12441,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12455,9 +12451,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -12472,7 +12466,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12538,7 +12532,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); - ggml_tensor * state_copy = build_inp_s_copy(); + auto * rs_inp = build_rs_inp_recurrent(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12548,9 +12542,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, ubatch, il - ); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -12562,7 +12554,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); From beddd622287725bed7be54169bcfe21d04c9a93a Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 16 Jun 2025 13:34:25 -0600 Subject: [PATCH 83/92] fix: Fix resize vs reserve and skip null tensors in size computation https://github.com/ggml-org/llama.cpp/pull/13979/files#r2149469788 Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart Co-Authored-By: @younesbelkada --- src/llama-kv-cache-recurrent.cpp | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache-recurrent.cpp b/src/llama-kv-cache-recurrent.cpp index be19edd316542..802025e22de17 100644 --- a/src/llama-kv-cache-recurrent.cpp +++ b/src/llama-kv-cache-recurrent.cpp @@ -60,8 +60,8 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( return it->second; }; - k_l.reserve(n_layer); - v_l.reserve(n_layer); + k_l.resize(n_layer); + v_l.resize(n_layer); for (int i = 0; i < n_layer; i++) { if (filter && !filter(i)) { @@ -647,7 +647,9 @@ size_t llama_kv_cache_recurrent::size_k_bytes() const { size_t size_k_bytes = 0; for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); + if (k != nullptr) { + size_k_bytes += ggml_nbytes(k); + } } return size_k_bytes; @@ -657,7 +659,9 @@ size_t llama_kv_cache_recurrent::size_v_bytes() const { size_t size_v_bytes = 0; for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); + if (v != nullptr) { + size_v_bytes += ggml_nbytes(v); + } } return size_v_bytes; From 8d2407307f41b08b07ef96effce618d558c998d3 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 16 Jun 2025 13:48:20 -0600 Subject: [PATCH 84/92] fix: Fix initialization of child states Since initially writing this PR, the logic in the child state types changed such that using the "init full" signature and keeping the ubatches on the parent struct no longer worked. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-kv-cache-hybrid-recurrent.cpp | 25 +++++++++++-------------- src/llama-kv-cache-hybrid-recurrent.h | 4 ---- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/src/llama-kv-cache-hybrid-recurrent.cpp b/src/llama-kv-cache-hybrid-recurrent.cpp index a2afda7647b00..9ec205c9df1ed 100644 --- a/src/llama-kv-cache-hybrid-recurrent.cpp +++ b/src/llama-kv-cache-hybrid-recurrent.cpp @@ -100,7 +100,6 @@ llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_full() { llama_memory_state_ptr llama_kv_cache_hybrid_recurrent::init_update(llama_context * lctx, bool optimize) { return std::make_unique( - this, static_cast( kv_attn ->init_update(lctx, optimize).release()), static_cast(kv_recurrent->init_update(lctx, optimize).release())); } @@ -179,16 +178,13 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(lla llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state(llama_kv_cache_hybrid_recurrent * kv) : status(LLAMA_MEMORY_STATUS_SUCCESS), - kv(kv), state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), state_recurrent(new llama_kv_cache_recurrent_state(status, kv->get_kv_recurrent())) {} llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( - llama_kv_cache_hybrid_recurrent * kv, llama_kv_cache_unified_state * state_unified, llama_kv_cache_recurrent_state * state_recurrent) : status(LLAMA_MEMORY_STATUS_NO_UPDATE), - kv(kv), state_attn(state_unified), state_recurrent(state_recurrent) {} @@ -198,20 +194,19 @@ llama_kv_cache_hybrid_recurrent_state::llama_kv_cache_hybrid_recurrent_state( std::vector heads_attn, std::vector ubatches) : status(LLAMA_MEMORY_STATUS_SUCCESS), - kv(kv), sbatch(std::move(sbatch)), - heads_attn(std::move(heads_attn)), ubatches(std::move(ubatches)), - // NOTE: these child states are only used as wrapper APIs for the - // const methods, so we use the "init full" signature since the - // actual state is not used. - state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn())), - state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent())) {} + // note: here we copy the ubatches. not sure if this is ideal + state_attn(new llama_kv_cache_unified_state(kv->get_kv_attn(), {}, std::move(heads_attn), this->ubatches)), + state_recurrent(new llama_kv_cache_recurrent_state(LLAMA_MEMORY_STATUS_SUCCESS, kv->get_kv_recurrent(), {}, this->ubatches)) {} bool llama_kv_cache_hybrid_recurrent_state::next() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); + state_attn ->next(); + state_recurrent->next(); + if (++i_next >= ubatches.size()) { return false; } @@ -222,10 +217,12 @@ bool llama_kv_cache_hybrid_recurrent_state::next() { bool llama_kv_cache_hybrid_recurrent_state::apply() { assert(status == LLAMA_MEMORY_STATUS_SUCCESS); - kv->get_kv_attn() ->apply_ubatch(heads_attn[i_next], ubatches[i_next]); - kv->get_kv_recurrent()->find_slot(ubatches[i_next]); + bool res = true; - return true; + res = res & state_attn ->apply(); + res = res & state_recurrent->apply(); + + return res; } std::vector & llama_kv_cache_hybrid_recurrent_state::out_ids() { diff --git a/src/llama-kv-cache-hybrid-recurrent.h b/src/llama-kv-cache-hybrid-recurrent.h index 93bf72ec34837..17a72c613d7b9 100644 --- a/src/llama-kv-cache-hybrid-recurrent.h +++ b/src/llama-kv-cache-hybrid-recurrent.h @@ -104,7 +104,6 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { // init update explicit llama_kv_cache_hybrid_recurrent_state( - llama_kv_cache_hybrid_recurrent * kv, llama_kv_cache_unified_state * state_unified, llama_kv_cache_recurrent_state * state_recurrent); @@ -135,14 +134,11 @@ class llama_kv_cache_hybrid_recurrent_state : public llama_memory_state_i { private: const llama_memory_status status; - llama_kv_cache_hybrid_recurrent * kv; - llama_sbatch sbatch; // the index of the next ubatch to process size_t i_next = 0; - std::vector heads_attn; std::vector ubatches; const llama_memory_state_ptr state_attn; From 95b669869811b5def318348bf1e0a92c85855aab Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 16 Jun 2025 15:17:28 -0600 Subject: [PATCH 85/92] refactor: Use a common build_recurrent_state method that is cache-agnostic This reduces the code duplication between the different build_rs impls and also retains a similar signature to the previous build_recurrent_state method while standardizing on the input-dispatched build_rs implementation. Branch: HybridRecurrentCache Signed-off-by: Gabe Goodhart --- src/llama-graph.cpp | 89 +++++++++++++++++---------------------------- src/llama-graph.h | 9 +++++ 2 files changed, 42 insertions(+), 56 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 039718c04e401..beebcd7555faf 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1494,32 +1494,15 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } - -llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const { - const auto * kv_state = static_cast(mstate); - - auto inp = std::make_unique(kv_state); - - const auto n_kv = kv_state->get_n_kv(); - - auto & cur = inp->s_copy; - - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); - ggml_set_input(cur); - - return (llm_graph_input_rs *) res->add_input(std::move(inp)); -} - -ggml_tensor * llm_graph_context::build_rs( - llm_graph_input_rs * inp, +ggml_tensor * llm_graph_context::build_recurrent_state( + const llama_kv_cache_recurrent_state * kv_state, ggml_cgraph * gf, ggml_tensor * s, + ggml_tensor * state_copy, int32_t state_size, int32_t n_seqs, bool avoid_copies) const { - const auto * kv_state = static_cast(mstate); - const auto n_kv = kv_state->get_n_kv(); const auto kv_head = kv_state->get_head(); const auto rs_zero = kv_state->get_rs_z(); @@ -1537,7 +1520,7 @@ ggml_tensor * llm_graph_context::build_rs( // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); + output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); ggml_build_forward_expand(gf, output_states); } else { // FIXME: make the gathering operation happen before the copy below @@ -1546,7 +1529,7 @@ ggml_tensor * llm_graph_context::build_rs( } // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, @@ -1555,63 +1538,57 @@ ggml_tensor * llm_graph_context::build_rs( return output_states; } -llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const { - auto inp = std::make_unique( - static_cast(mstate)); +llm_graph_input_rs * llm_graph_context::build_rs_inp_recurrent() const { + const auto * kv_state = static_cast(mstate); - const auto n_kv = inp->kv_state->get_n_kv(); + auto inp = std::make_unique(kv_state); + + const auto n_kv = kv_state->get_n_kv(); auto & cur = inp->s_copy; cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); ggml_set_input(cur); - return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp)); + return (llm_graph_input_rs *) res->add_input(std::move(inp)); } ggml_tensor * llm_graph_context::build_rs( - llm_graph_input_rs_hybrid_recurrent * inp, + llm_graph_input_rs * inp, ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, bool avoid_copies) const { - const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + const auto * kv_state = static_cast(mstate); + return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies); +} - const auto n_kv = kv_state->get_n_kv(); - const auto kv_head = kv_state->get_head(); - const auto rs_zero = kv_state->get_rs_z(); +llm_graph_input_rs_hybrid_recurrent * llm_graph_context::build_rs_inp_hybrid_recurrent() const { + auto inp = std::make_unique( + static_cast(mstate)); - ggml_tensor * states = ggml_reshape_2d(ctx0, s, state_size, kv_state->get_size()); + const auto n_kv = inp->kv_state->get_n_kv(); - // Clear a single state which will then be copied to the other cleared states. - // Note that this is a no-op when the view is zero-sized. - ggml_tensor * state_zero = ggml_view_1d(ctx0, states, state_size*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); - ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); + auto & cur = inp->s_copy; - ggml_tensor * output_states; + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv); + ggml_set_input(cur); - if (!avoid_copies) { - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // {state_size, kv_size} -> {state_size, n_seqs} - output_states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); - ggml_build_forward_expand(gf, output_states); - } else { - // FIXME: make the gathering operation happen before the copy below - // (maybe with an optional lambda function passed as a parameter instead of `avoid_copies`?) - output_states = states; - } + return (llm_graph_input_rs_hybrid_recurrent *) res->add_input(std::move(inp)); +} - // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); - ggml_build_forward_expand(gf, - ggml_cpy(ctx0, - states_extra, - ggml_view_1d(ctx0, s, state_size*(n_kv - n_seqs), (kv_head + n_seqs)*state_size*ggml_element_size(s)))); +ggml_tensor * llm_graph_context::build_rs( + llm_graph_input_rs_hybrid_recurrent * inp, + ggml_cgraph * gf, + ggml_tensor * s, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies) const { - return output_states; + const auto * kv_state = static_cast(mstate)->get_state_recurrent(); + return build_recurrent_state(kv_state, gf, s, inp->s_copy, state_size, n_seqs, avoid_copies); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( diff --git a/src/llama-graph.h b/src/llama-graph.h index 77f19a673c3dd..f705ea81d11c6 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -622,6 +622,15 @@ struct llm_graph_context { // recurrent // + ggml_tensor * build_recurrent_state( + const llama_kv_cache_recurrent_state * kv_state, + ggml_cgraph * gf, + ggml_tensor * s, + ggml_tensor * state_copy, + int32_t state_size, + int32_t n_seqs, + bool avoid_copies = false) const; + llm_graph_input_rs * build_rs_inp_recurrent() const; ggml_tensor * build_rs( From aff19ae21af1a51170856464a9cdea47e7e15ee5 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 17 Jun 2025 13:41:03 +0400 Subject: [PATCH 86/92] more clean ups --- src/llama-graph.cpp | 12 +++++- src/llama-hparams.cpp | 4 +- src/llama-hparams.h | 2 - src/llama-model.cpp | 90 ++++++++++++++++++------------------------- 4 files changed, 50 insertions(+), 58 deletions(-) diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 6f69e80b890ea..33df2c06039cf 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -604,6 +604,11 @@ ggml_tensor * llm_graph_context::build_ffn( case LLM_FFN_PAR: { cur = build_lora_mm(gate, cur); + + if (arch == LLM_ARCH_FALCON_H1) { + cur = ggml_scale(ctx0, cur, hparams.mlp_gate_multiplier); + } + cb(cur, "ffn_gate", il); } break; } @@ -691,6 +696,9 @@ ggml_tensor * llm_graph_context::build_ffn( // GLM4 seems to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } + if (arch == LLM_ARCH_FALCON_H1) { + cur = ggml_scale(ctx0, cur, hparams.mlp_down_multiplier); + } } if (down_b) { @@ -1519,11 +1527,11 @@ ggml_tensor * llm_graph_context::build_recurrent_state( // copy states // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv // {state_size, kv_size} -> {state_size, n_seqs} - output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_seqs, 0)); + output_states = get_state_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); ggml_build_forward_expand(gf, output_states); // copy extra states which won't be changed further (between n_seqs and n_kv) - ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, inp->s_copy, n_kv - n_seqs, n_seqs*inp->s_copy->nb[0])); + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, states_extra, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 17bf141198655..27f24d7fb679b 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -74,7 +74,7 @@ uint32_t llama_hparams::n_embd_k_s() const { // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed // Corresponds to Mamba's conv_states size - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_mamba_d_ssm + 2*ssm_n_group*ssm_d_state); } uint32_t llama_hparams::n_embd_v_s() const { @@ -84,7 +84,7 @@ uint32_t llama_hparams::n_embd_v_s() const { } // corresponds to Mamba's ssm_states size - return ssm_d_state * ssm_d_inner; + return ssm_d_state * ssm_mamba_d_ssm; } bool llama_hparams::recurrent_layer(uint32_t il) const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 694f6b7295f5b..5241631d5fb59 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -146,8 +146,6 @@ struct llama_hparams { // for hybrid state space models std::array recurrent_layer_arr; - bool ssm_dt_b_c_rms = false; - float f_clamp_kqv = 0.0f; float f_max_alibi_bias = 0.0f; float f_logit_scale = 0.0f; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 498f7e1b117d5..5f5167ef29c3b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4327,7 +4327,6 @@ bool llama_model::load_tensors(llama_model_loader & ml) { case LLM_ARCH_FALCON_H1: { // Common - const float layer_norm_epsilon = hparams.f_norm_rms_eps; // TODO layer_norm_epsilon const int64_t hidden_size = hparams.n_embd; // hidden_size const int64_t vocab_size = hparams.vocab_size; // vocab_size @@ -4338,17 +4337,12 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t ssm_intermediate_size = hparams.ssm_mamba_d_ssm > 0 ? hparams.ssm_mamba_d_ssm : int(hparams.mamba_expand * hidden_size); // TODO expand const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size; - const int64_t ssm_head_dim = hparams.ssm_head_dim; // ssm_head_dim - const bool ssm_rms_norm = hparams.mamba_rms_norm; - const int64_t ssm_chunk_size = hparams.chunk_size; const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads; - const int64_t ssm_groups_time_state_size = ssm_n_groups * ssm_state_size; // groups_time_state_size // attn params const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head const int64_t attn_num_key_value_head = hparams.n_head_kv(0); const int64_t attn_head_dim = hparams.attn_head_dim > 0 ? hparams.attn_head_dim : hidden_size / attn_num_attention_head; - const int64_t attn_num_key_value_groups = attn_num_attention_head / attn_num_key_value_head; // ffn params const int64_t ffn_intermediate_size = hparams.n_ff(0); @@ -5564,59 +5558,51 @@ struct llm_build_falcon_h1 : public llm_graph_context { cur = ggml_scale(ctx0, cur, hparams.attention_in_multiplier); // self-attention - { - ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); - cb(Qcur, "Qcur", il); - - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); - cb(Kcur, "Kcur", il); - - ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); - cb(Vcur, "Vcur", il); - - Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); - Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier); + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); - Qcur = ggml_rope_ext( - ctx0, Qcur, inp_pos, nullptr, - n_rot, 0, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow); + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); - Kcur = ggml_rope_ext( - ctx0, Kcur, inp_pos, nullptr, - n_rot, 0, n_ctx_orig, freq_base, freq_scale, - ext_factor, attn_factor, beta_fast, beta_slow - ); + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = ggml_scale(ctx0, Kcur, hparams.key_multiplier); - cb(Qcur, "Qcur", il); - cb(Kcur, "Kcur", il); - cb(Vcur, "Vcur", il); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - //std::printf("Here %d\n", il); - ggml_tensor * attn_out = build_attn(inp_attn, gf, - model.layers[il].wo, NULL, - Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); - //std::printf("Here %d - after\n", il); + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); - attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier); - - cur = build_norm(inpL, - model.layers[il].attn_norm, NULL, - LLM_NORM_RMS, il); - // Mamba2 layer - // std::printf("Here 2\n"); + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, 0, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); - cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier); - ggml_tensor * ssm_out = build_mamba2_layer(inp_rs, gf, cur, ubatch, il); - ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier); - // std::printf("Here\n"); - // // Aggregation - cur = ggml_add(ctx0, attn_out, ssm_out); - } + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + ggml_tensor * attn_out = build_attn(inp_attn, gf, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); + attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier); + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + // Mamba2 layer + cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier); + ggml_tensor * ssm_out = build_mamba2_layer(inp_rs, gf, inpSA, ubatch, il); + ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier); + // // Aggregation + cur = ggml_add(ctx0, attn_out, ssm_out); + cur = ggml_add(ctx0, cur, inpSA); if (il == n_layer - 1) { // skip computing output for unused tokens @@ -5625,7 +5611,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + ggml_tensor * ffn_inp = cur; cb(ffn_inp, "ffn_inp", il); // feed-forward network From 9150fc917fb4a3a7857da469b009281f548852cf Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 17 Jun 2025 14:08:55 +0400 Subject: [PATCH 87/92] fix inference --- src/llama-model.cpp | 44 ++++++++++++++++++++++++-------------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5f5167ef29c3b..b64c3529f1e28 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -5592,18 +5592,24 @@ struct llm_build_falcon_h1 : public llm_graph_context { model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); attn_out = ggml_scale(ctx0, attn_out, hparams.attention_out_multiplier); + cb(attn_out, "attn_out", il); cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // Mamba2 layer cur = ggml_scale(ctx0, cur, hparams.ssm_in_multiplier); - ggml_tensor * ssm_out = build_mamba2_layer(inp_rs, gf, inpSA, ubatch, il); + cb(cur, "ssm_in", il); + + ggml_tensor * ssm_out = build_mamba2_layer(inp_rs, gf, cur, ubatch, il); ssm_out = ggml_scale(ctx0, ssm_out, hparams.ssm_out_multiplier); + cb(ssm_out, "ssm_out", il); + // // Aggregation cur = ggml_add(ctx0, attn_out, ssm_out); - cur = ggml_add(ctx0, cur, inpSA); - + inpSA = ggml_add(ctx0, cur, inpSA); + cb(cur, "layer_out", il); + if (il == n_layer - 1) { // skip computing output for unused tokens ggml_tensor * inp_out_ids = build_inp_out_ids(); @@ -5611,26 +5617,24 @@ struct llm_build_falcon_h1 : public llm_graph_context { inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } - ggml_tensor * ffn_inp = cur; + ggml_tensor * ffn_inp = inpSA; cb(ffn_inp, "ffn_inp", il); // feed-forward network - { - cur = build_norm(ffn_inp, - model.layers[il].ffn_norm, NULL, - LLM_NORM_RMS, il); - cb(cur, "ffn_norm", il); + cur = build_norm(ffn_inp, + model.layers[il].ffn_pre_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); - cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, - LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "ffn_out", il); - } + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL, + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); - cur = ggml_add(ctx0, cur, ffn_inp); + cur = ggml_add(ctx0, cur, inpSA); cur = build_cvec(cur, il); cb(cur, "l_out", il); @@ -5642,7 +5646,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { cur = inpL; cur = build_norm(cur, - model.output_norm, NULL, + model.final_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); @@ -5796,7 +5800,7 @@ struct llm_build_falcon_h1 : public llm_graph_context { // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - // cb(cur, "mamba_out", il); + cb(cur, "mamba_out", il); return cur; } }; From 11857e8fd56f95ca5a23f7a6cb1dc93c698715ec Mon Sep 17 00:00:00 2001 From: Ibrahim Khadraoui Date: Tue, 17 Jun 2025 14:56:21 +0400 Subject: [PATCH 88/92] Commit credits attribution From ebfcb6fe65d5178ed0c41cfb253c1094d806ca01 Mon Sep 17 00:00:00 2001 From: Brahim Farhat Date: Tue, 17 Jun 2025 14:59:15 +0400 Subject: [PATCH 89/92] Commit credits attribution From 82f59d0d539e7a74820810059c55e3eef1f2b129 Mon Sep 17 00:00:00 2001 From: Hamza Yous Date: Tue, 17 Jun 2025 14:59:21 +0400 Subject: [PATCH 90/92] Commit credits attribution From 837be3edf92f994cdccee8358c96808f6076fd42 Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 17 Jun 2025 15:06:49 +0400 Subject: [PATCH 91/92] add convert script --- convert_hf_to_gguf.py | 151 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 150 insertions(+), 1 deletion(-) diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 296c7c967e33a..4af85ba68a4d3 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -604,7 +604,10 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]: from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab)) + vocab_size = max( + self.hparams.get("vocab_size", len(tokenizer.vocab)), + len(tokenizer.vocab) + ) assert max(tokenizer.vocab.values()) < vocab_size tokpre = self.get_vocab_base_pre(tokenizer) @@ -683,6 +686,14 @@ def get_vocab_base_pre(self, tokenizer) -> str: if chkhsh == "9d032fcbd5501f4a38150912590928bfb36091efb5df11b8e2124b0390e3fb1e": # ref: https://huggingface.co/tiiuae/Falcon3-7B-Base res = "falcon3" + if ( + chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86" or + chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896" or + chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b" or + chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6" + ): + # ref: + res = "falcon-h1" if chkhsh == "8e62295832751ca1e8f92f2226f403dea30dc5165e448b5bfa05af5340c64ec7": # ref: https://huggingface.co/BAAI/bge-large-zh-v1.5 res = "bert-bge-large" @@ -4812,6 +4823,144 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield (new_name, data_torch) +@Model.register("FalconH1ForCausalLM") +class FalconH1Model(Mamba2Model): + model_arch = gguf.MODEL_ARCH.FALCON_H1 + + def __init__(self, *args, **kwargs): + # Set the hparam prefixes for Falcon Mamba2 + self.hparam_prefixes = ["mamba"] + + # Initialize the base Mamba2Model + super().__init__(*args, **kwargs) + + # Use Llama conversion for attention + self._transformer_model_class: type[Model] = LlamaModel + + # n_group and d_inner are used during reshape_tensors for mamaba2 + self.d_model = self.find_hparam(["hidden_size", "d_model"]) + self.n_group = self.find_hparam(["n_groups"]) + self.d_inner = self.find_hparam(["expand"]) * self.d_model + + # Initialize any Falcon Mamba2 specific attributes + self.has_attention = True # Falcon Mamba2 has attention components + + # Load Falcon-H1 multipliers from hyperparameters + self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True) + self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True) + self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True) + self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True) + self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True) + self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True) + self.intermediate_size = self.find_hparam(["intermediate_size"]) + + def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any: + prefixed = [] + for pfx in self.hparam_prefixes: + prefixed.extend( + "_".join([pfx, k]) + for k in keys + ) + keys = list(keys) + prefixed + return super().find_hparam(keys, *args, **kwargs) + + def _generate_mup_vector(self, block_id: int) -> torch.Tensor: + zxbcdt_multipliers = self.hparams["ssm_multipliers"] + intermediate_size = self.hparams["mamba_d_ssm"] + groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"] + vector_shape = (2 * intermediate_size + 2 * groups_time_state_size + self.hparams["mamba_n_heads"]) + + mup_vector = torch.ones(1, 1, vector_shape) + mup_vector[:, :, :intermediate_size] *= zxbcdt_multipliers[0] + mup_vector[:, :, intermediate_size:2 * intermediate_size] *= zxbcdt_multipliers[1] + mup_vector[:, :, 2 * intermediate_size:2 * intermediate_size + groups_time_state_size] *= zxbcdt_multipliers[2] + mup_vector[:, :, 2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size] *= zxbcdt_multipliers[3] + mup_vector[:, :, 2 * intermediate_size + 2 * groups_time_state_size:] *= zxbcdt_multipliers[4] + + return mup_vector + + def get_tensors(self) -> Iterator[tuple[str, Tensor]]: + for name, tensor in super().get_tensors(): + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + name = name.removeprefix("model.") + yield name, tensor + + if self.ssm_multipliers is not None: + # Insert MUP vector after mamba.dt_bias + if "mamba.dt_bias" in name: + block_match = re.search(r"(?:model\.layers\.)?(\d+)\.mamba\.dt_bias", name) + if block_match: + block_id = int(block_match.group(1)) + # Generate MUP vector with correct name format + mup_tensor = self._generate_mup_vector(block_id) + mup_name = f"blk.{block_id}.ssm_mup_vec" + logger.debug(f"Inserting MUP vector for block {block_id}: {mup_name}") + yield mup_name, mup_tensor + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + ## General Params ## + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0)) + self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) + self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"]) + + ## Mamba mixer params ## + self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"])) + self.gguf_writer.add_ssm_group_count(self.n_group) + self.gguf_writer.add_ssm_inner_size(self.d_inner) + self.gguf_writer.add_ssm_head_dim(d_head := self.find_hparam(["d_head"])) + self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"])) + + ## Attention params ## + self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) + self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"] if "num_key_value_heads" in self.hparams else self.hparams["num_attention_heads"]) + self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"]) + self.gguf_writer.add_key_length(self.hparams["head_dim"]) + self.gguf_writer.add_value_length(self.hparams["head_dim"]) + self.gguf_writer.add_float32("falcon-h1.key_multiplier", self.hparams["key_multiplier"]) + + ## Other params + self.gguf_writer.add_float32("falcon-h1.lm_head_multiplier", self.hparams["lm_head_multiplier"]) + self.gguf_writer.add_float32("falcon-h1.embedding_multiplier", self.hparams["embedding_multiplier"]) + + ## Validation ## + assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported" + assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}" + + + # Add Falcon Mamba2 specific configuration + self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_chunk_size", self.hparams["mamba_chunk_size"]) + self.gguf_writer.add_uint32("falcon-h1.attention.head_dim", self.hparams["head_dim"]) + self.gguf_writer.add_uint32("falcon-h1.ssm.mamba_d_ssm", self.hparams["mamba_d_ssm"]) + self.gguf_writer.add_uint32("falcon-h1.num_attention_heads", self.find_hparam(["num_attention_heads"])) + self.gguf_writer.add_uint32("falcon-h1.num_key_value_heads", + self.find_hparam(["num_key_value_heads"], optional=True) or + self.find_hparam(["num_attention_heads"])) + + # Add multipliers as metadata instead of tensors + self.gguf_writer.add_float32("falcon-h1.attention_in_multiplier", self.attention_in_multiplier) + self.gguf_writer.add_float32("falcon-h1.attention_out_multiplier", self.attention_out_multiplier) + self.gguf_writer.add_float32("falcon-h1.ssm_in_multiplier", self.ssm_in_multiplier) + self.gguf_writer.add_float32("falcon-h1.ssm_out_multiplier", self.ssm_out_multiplier) + + # Add MLP multipliers + if isinstance(self.mlp_multipliers, (list, tuple)) and len(self.mlp_multipliers) == 2: + self.gguf_writer.add_float32("falcon-h1.mlp_gate_multiplier", self.mlp_multipliers[0]) + self.gguf_writer.add_float32("falcon-h1.mlp_down_multiplier", self.mlp_multipliers[1]) + + # Add has MuP flag if SSM multipliers are present + if self.ssm_multipliers is not None: + self.gguf_writer.add_bool("falcon-h1.ssm.has_mup", True) + + # Add any other Falcon Mamba2 specific configuration + self.gguf_writer.add_bool("falcon-h1.mamba_use_mlp", self.find_hparam(["mamba_use_mlp"], optional=True)) + self.gguf_writer.add_bool("falcon-h1.mamba_norm_before_gate", self.find_hparam(["mamba_norm_before_gate"], optional=True)) + self.gguf_writer.add_bool("falcon-h1.mamba_rms_norm", self.find_hparam(["mamba_rms_norm"], optional=True)) + self.gguf_writer.add_float32("falcon-h1.rope_theta", self.find_hparam(["rope_theta"], optional=True)) + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R From 41f25cf79c8c57b9e500e2e0ca9acb0fbdcce55a Mon Sep 17 00:00:00 2001 From: younesbelkada Date: Tue, 17 Jun 2025 15:12:41 +0400 Subject: [PATCH 92/92] add missing elements in py files --- gguf-py/gguf/constants.py | 42 +++++++++++++++++++++++++++++++++- gguf-py/gguf/gguf_writer.py | 10 ++++++++ gguf-py/gguf/tensor_mapping.py | 17 ++++++++++++++ 3 files changed, 68 insertions(+), 1 deletion(-) diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 42a05ebb568a7..9476f1a1b11c2 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -166,6 +166,7 @@ class SSM: TIME_STEP_RANK = "{arch}.ssm.time_step_rank" GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" + HEAD_DIM = "{arch}.ssm.head_dim" class WKV: HEAD_SIZE = "{arch}.wkv.head_size" @@ -348,6 +349,7 @@ class MODEL_ARCH(IntEnum): BAILINGMOE = auto() DOTS1 = auto() ARCEE = auto() + FALCON_H1 = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -408,6 +410,7 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_MUP_VEC = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() TIME_MIX_W2 = auto() @@ -633,6 +636,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE: "bailingmoe", MODEL_ARCH.DOTS1: "dots1", MODEL_ARCH.ARCEE: "arcee", + MODEL_ARCH.FALCON_H1: "falcon-h1", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -670,7 +674,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_GATE_INP: "blk.{bid}.ffn_gate_inp", MODEL_TENSOR.FFN_GATE_INP_SHEXP: "blk.{bid}.ffn_gate_inp_shexp", MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_PRE_NORM: "blk.{bid}.ffn_pre_norm", MODEL_TENSOR.FFN_POST_NORM: "blk.{bid}.post_ffw_norm", MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", @@ -693,6 +697,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_MUP_VEC: "blk.{bid}.ssm_mup_vec", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2", @@ -2174,6 +2179,41 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.BAILINGMOE: [ MODEL_TENSOR.ROPE_FREQS, ], + MODEL_ARCH.FALCON_H1: [ + # Token embedding + MODEL_TENSOR.TOKEN_EMBD, + + # Input layernorm + MODEL_TENSOR.ATTN_NORM, + + # Attention components + MODEL_TENSOR.ATTN_Q, # Query projection + MODEL_TENSOR.ATTN_K, # Key projection + MODEL_TENSOR.ATTN_V, # Value projection + MODEL_TENSOR.ATTN_OUT, # Output projection + + # SSM components (Mamba2 specific) + MODEL_TENSOR.SSM_MUP_VEC, # Mup vector + MODEL_TENSOR.SSM_IN, # Input projection for SSM + MODEL_TENSOR.SSM_CONV1D, # Convolution layer + MODEL_TENSOR.SSM_DT, # Delta time projection + MODEL_TENSOR.SSM_A, # A parameter (log form) + MODEL_TENSOR.SSM_D, # D parameter + MODEL_TENSOR.SSM_NORM, # Normalization in SSM + MODEL_TENSOR.SSM_OUT, # Output projection + + # Pre-feedforward layernorm + MODEL_TENSOR.FFN_PRE_NORM, + + # Feed-forward network components + MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU) + MODEL_TENSOR.FFN_DOWN, # Down projection + MODEL_TENSOR.FFN_UP, # Up projection + + # Post-feedforward layernorm + MODEL_TENSOR.OUTPUT_NORM, # Final layer norm + MODEL_TENSOR.OUTPUT, # Output projection (lm_head) + ], } # diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index e1e8079f377ff..fde78894d4fac 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -849,6 +849,16 @@ def add_ssm_group_count(self, value: int) -> None: def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) + def add_ssm_head_dim(self, value: int) -> None: + self.add_uint32(Keys.SSM.HEAD_DIM.format(arch=self.arch), value) + + def add_attn_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.HEAD_COUNT.format(arch=self.arch), count) + + def add_key_value_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.HEAD_COUNT_KV.format(arch=self.arch), count) + + def add_tokenizer_model(self, model: str) -> None: self.add_string(Keys.Tokenizer.MODEL, model) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 23ef00c3defa3..79d336539c483 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -286,6 +286,7 @@ class TensorNameMap: # Post feed-forward norm MODEL_TENSOR.FFN_PRE_NORM: ( "model.layers.{bid}.pre_feedforward_layernorm", # gemma2 + "model.layers.{bid}.pre_ff_layernorm.weight", # falcon-h1 ), # Post feed-forward norm @@ -356,6 +357,7 @@ class TensorNameMap: "model.layers.{bid}.block_sparse_moe.experts.w3", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.up_proj", # llama4 "encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe + "model.layers.{bid}.feed_forward.up_proj", # falcon-h1 ), MODEL_TENSOR.FFN_UP_SHEXP: ( @@ -392,6 +394,7 @@ class TensorNameMap: "model.layers.{bid}.mlp.experts.gate_proj", # qwen2moe olmoe (merged) "model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged) "model.layers.{bid}.feed_forward.experts.gate_proj", # llama4 + "model.layers.{bid}.feed_forward.down_proj", # falcon-h1 ), MODEL_TENSOR.FFN_GATE_SHEXP: ( @@ -431,6 +434,14 @@ class TensorNameMap: "transformer_encoder.{bid}.ffn.w3", # neobert ), + MODEL_TENSOR.SSM_MUP_VEC: ( + "model.layers.{bid}.mamba.mup_vector", # falcon-h1 + ), + + MODEL_TENSOR.SSM_NORM: ( + "model.layers.{bid}.mamba.norm", + ), + MODEL_TENSOR.FFN_DOWN_EXP: ( "layers.{bid}.feed_forward.experts.w2", # mixtral (merged) "transformer.decoder_layer.{bid}.moe.linear_1", # Grok (merged) @@ -483,11 +494,13 @@ class TensorNameMap: MODEL_TENSOR.SSM_IN: ( "model.layers.{bid}.in_proj", "backbone.layers.{bid}.mixer.in_proj", + "model.layers.{bid}.mamba.in_proj", # falcon-h1 ), MODEL_TENSOR.SSM_CONV1D: ( "model.layers.{bid}.conv1d", "backbone.layers.{bid}.mixer.conv1d", + "model.layers.{bid}.mamba.conv1d", # falcon-h1 ), MODEL_TENSOR.SSM_X: ( @@ -498,16 +511,19 @@ class TensorNameMap: MODEL_TENSOR.SSM_DT: ( "model.layers.{bid}.dt_proj", "backbone.layers.{bid}.mixer.dt_proj", + "model.layers.{bid}.mamba.dt_proj", # falcon-h1 ), MODEL_TENSOR.SSM_A: ( "model.layers.{bid}.A_log", "backbone.layers.{bid}.mixer.A_log", + "model.layers.{bid}.mamba.A_log", # falcon-h1 ), MODEL_TENSOR.SSM_D: ( "model.layers.{bid}.D", "backbone.layers.{bid}.mixer.D", + "model.layers.{bid}.mamba.D", # falcon-h1 ), MODEL_TENSOR.SSM_NORM: ( @@ -517,6 +533,7 @@ class TensorNameMap: MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", + "model.layers.{bid}.mamba.out_proj", # falcon-h1 ), MODEL_TENSOR.TIME_MIX_W0: (