249 changes: 249 additions & 0 deletions convert_hf_to_gguf.py
@@ -586,6 +586,10 @@ def prepare_tensors(self):
gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
# Kimi KDA conv weights should be F32
gguf.MODEL_TENSOR.SSM_CONV1D_Q,
gguf.MODEL_TENSOR.SSM_CONV1D_K,
gguf.MODEL_TENSOR.SSM_CONV1D_V,
)
)
or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
@@ -5013,6 +5017,251 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_factor(1.0)


@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM")
class KimiLinearModel(TextModel):
"""Kimi-Linear model with hybrid MLA+KDA architecture"""
model_arch = gguf.MODEL_ARCH.KIMI_LINEAR

_experts: list[dict[str, Tensor]] | None = None

def set_vocab(self):
try:
self._set_vocab_gpt2()
return
except Exception:
pass

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
tokpre = self.get_vocab_base_pre(tokenizer)

if tokpre == "kimi-k2":
# Build the merges list using an approach similar to HunYuanMoE
merges = []
vocab = {}
mergeable_ranks = tokenizer.model._mergeable_ranks
for token, rank in mergeable_ranks.items():
vocab[QwenModel.token_bytes_to_string(token)] = rank
if len(token) == 1:
continue
merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
if len(merged) == 2:
merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
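# note: QwenModel.bpe(..., max_rank=rank) re-runs byte-pair merging using only merges of lower rank;
# when exactly two pieces remain, they are the pair that merges into `token` at this rank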
# Build token list
vocab_size = self.hparams["vocab_size"]
special_tokens = tokenizer.special_tokens
reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in {**vocab, **special_tokens}.items()}
tokens: list[str] = []
toktypes: list[int] = []

for i in range(vocab_size):
if i not in reverse_vocab:
tokens.append(f"[PAD{i}]")
toktypes.append(gguf.TokenType.UNUSED)
else:
token = reverse_vocab[i]
tokens.append(token)
if i in special_tokens.values():
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.NORMAL)

self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_token_merges(merges)

special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
special_vocab.add_to_gguf(self.gguf_writer)
# override eos id in config.json with tiktoken eos id
self.gguf_writer.add_eos_token_id(tokenizer.eos_id)
else:
raise NotImplementedError(f"Pre-tokenizer {tokpre!r} is not supported yet!")

def set_gguf_parameters(self):
# note: To enable MLA KV cache, attention needs to be converted into MQA (ie: GQA with 1 group)
self.hparams["num_key_value_heads"] = 1

super().set_gguf_parameters()
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

if (score_func := self.find_hparam(["moe_router_activation_func"], optional=True)) is not None:
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
else:
raise ValueError(f"Unsupported expert score gating function value: {score_func}")

# KDA & MLA params
# Get ssm_d_conv from linear_attn_config.short_conv_kernel_size
linear_attn_config = self.hparams.get("linear_attn_config", {})
# n_head == 0 for KDA layers, n_head > 0 for MLA layers
# the full_attn_layers list is used to distinguish between the layer types
_num_kv_heads = list()
_full_attn_layers = linear_attn_config["full_attn_layers"]
for il in range(self.hparams["num_hidden_layers"]):
if il + 1 in _full_attn_layers:
_num_kv_heads.append(self.hparams["num_key_value_heads"])
else:
_num_kv_heads.append(0)
assert len(_num_kv_heads) == self.hparams["num_hidden_layers"]
self.gguf_writer.add_head_count_kv(_num_kv_heads)
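# e.g. (hypothetical config): num_hidden_layers = 6 with full_attn_layers = [3, 6] yields
# head_count_kv = [0, 0, 1, 0, 0, 1] -- 0 marks a KDA layer, 1 marks an MLA layer
# (num_key_value_heads was forced to 1 above)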

if (ssm_d_conv := linear_attn_config.get("short_conv_kernel_size")) is not None:
self.gguf_writer.add_ssm_conv_kernel(ssm_d_conv)
if (kda_head_dim := linear_attn_config.get("head_dim")) is not None:
self.gguf_writer.add_kda_head_dim(kda_head_dim)

# MLA params - use add_* methods that handle arch substitution
# Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
if (q_lora_rank := self.find_hparam(["q_lora_rank", "n_lora_q"], optional=False)) is not None:
self.gguf_writer.add_q_lora_rank(q_lora_rank)
if (kv_lora_rank := self.find_hparam(["kv_lora_rank", "n_lora_kv"], optional=False)) is not None:
self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

# MLA head dimensions
# Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
qk_rope_head_dim = self.hparams.get("qk_rope_head_dim")
v_head_dim = self.hparams.get("v_head_dim")
# To enable MLA KV cache, MLA needs to be converted into MQA with larger heads, which is then decompressed back into MHA
self.gguf_writer.add_key_length(self.hparams["kv_lora_rank"] + self.hparams["qk_rope_head_dim"])
self.gguf_writer.add_value_length(self.hparams["kv_lora_rank"])
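# e.g. (hypothetical values): kv_lora_rank = 512, qk_rope_head_dim = 64 -> key_length = 576, value_length = 512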

# Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
if "n_embd_head_k_mla" in self.hparams:
self.gguf_writer.add_key_length_mla(self.hparams["n_embd_head_k_mla"])
elif qk_nope_head_dim is not None and qk_rope_head_dim is not None:
n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
self.gguf_writer.add_key_length_mla(n_embd_head_k_mla)

# n_embd_head_v_mla = v_head_dim
if "n_embd_head_v_mla" in self.hparams:
self.gguf_writer.add_value_length_mla(self.hparams["n_embd_head_v_mla"])
elif v_head_dim is not None:
self.gguf_writer.add_value_length_mla(v_head_dim)

# Rotation - use qk_rope_head_dim for Kimi
if (rope_dim := self.find_hparam(["qk_rope_head_dim", "n_rot"], optional=True)) is not None:
self.gguf_writer.add_rope_dimension_count(rope_dim)
else:
# Default to head_dim
head_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(head_dim)

if (n_experts := self.find_hparam(["num_experts"], optional=False)) is not None:
self.gguf_writer.add_expert_count(n_experts)
if (n_experts_used := self.find_hparam(["num_experts_per_token"], optional=False)) is not None:
self.gguf_writer.add_expert_used_count(n_experts_used)

# moe_intermediate_size (1024 for Kimi)
if (moe_intermediate_size := self.find_hparam(["moe_intermediate_size"], optional=False)) is not None:
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)

# num_shared_experts (1 for Kimi)
if (num_shared_experts := self.find_hparam(["num_shared_experts"], optional=False)) is not None:
self.gguf_writer.add_expert_shared_count(num_shared_experts)

# first_k_dense_replace (1 for Kimi - first layer uses dense MLP)
if (first_k_dense_replace := self.find_hparam(["first_k_dense_replace"])) is not None:
self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)

# Routed scaling factor (expert_weights_scale = 2.446 for Kimi)
if (routed_scaling_factor := self.find_hparam(["routed_scaling_factor"], optional=False)) is not None:
self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)

def prepare_tensors(self):
super().prepare_tensors()
if self._experts is not None:
experts = [k for d in self._experts for k in d.keys()]
if len(experts) > 0:
raise ValueError(f"Unprocessed experts: {experts}")

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
logger.info(f"Processing {name}: shape before = {tuple(data_torch.shape)}")

# Handle KDA conv1d weights
# HuggingFace/vLLM stores as [d_inner, d_conv] (2D), memory layout: conv_step changes fastest
# llama.cpp expects ggml ne = [d_conv, 1, d_inner, 1], memory layout: ne[0]=d_conv changes fastest
# GGUF reverses numpy shape when writing, so numpy (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1]
# Memory layouts match: both have conv_step (d_conv) changing fastest
if name.endswith((".q_conv1d.weight", ".k_conv1d.weight", ".v_conv1d.weight")):
# HF shape: [d_inner, d_conv] e.g. [4096, 4]
# Target numpy shape: (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1]
if data_torch.ndim == 2:
d_inner, d_conv = data_torch.shape
# Reshape to (1, d_inner, 1, d_conv) - memory layout preserved (d_conv fastest)
data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")
elif data_torch.ndim == 3:
# Already 3D [d_inner, 1, d_conv] from unsqueeze
d_inner, _, d_conv = data_torch.shape
data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, 1, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")

# Kimi specific bias
if name.endswith("e_score_correction_bias"):
name = name.replace("e_score_correction_bias", "e_score_correction.bias")

# Handle A_log: HF stores it as [1, 1, num_heads, 1]
# llama.cpp expects ggml ne = [1, num_heads, 1, 1]
# GGUF reverses numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1]
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
if name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
logger.info("Changed dt_bias to dt_proj.bias")

# process the experts separately
if name.find("block_sparse_moe.experts") != -1:
n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts"))
assert bid is not None

if self._experts is None:
self._experts = [{} for _ in range(self.block_count)]

self._experts[bid][name] = data_torch

if len(self._experts[bid]) >= n_experts * 3:
# merge the experts into a single 3d tensor
tensors = []
# w1: gate, w2: down, w3: up
for wid, tname in [("w1", gguf.MODEL_TENSOR.FFN_GATE_EXP),
("w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP),
("w3", gguf.MODEL_TENSOR.FFN_UP_EXP)]:
datas: list[Tensor] = []
for xid in range(n_experts):
ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
datas.append(self._experts[bid][ename])
del self._experts[bid][ename]

data_torch = torch.stack(datas, dim=0)
new_name = self.format_tensor_name(tname, bid)
tensors.append((new_name, data_torch))
return tensors
return []
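# e.g. (hypothetical sizes): stacking n_experts per-expert [moe_intermediate_size, hidden_size] w1 weights
# yields a single FFN_GATE_EXP tensor of shape [n_experts, moe_intermediate_size, hidden_size]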

# note: MLA with the absorption optimization needs kv_b_proj split into k_b_proj and v_b_proj, with k_b_proj transposed
if name.endswith("kv_b_proj.weight"):
name_kb = name.replace("kv_b_proj", "k_b_proj")
name_vb = name.replace("kv_b_proj", "v_b_proj")
n_head_kv = self.hparams["num_key_value_heads"]
v_head_dim = self.hparams["v_head_dim"]
qk_nope_head_dim = self.hparams["qk_nope_head_dim"]
logger.info("Split kv_b n_head_kv %d\n" % n_head_kv)
assert data_torch.shape[0] == n_head_kv * (v_head_dim + qk_nope_head_dim)
kv_b = data_torch.view(n_head_kv, v_head_dim + qk_nope_head_dim, data_torch.shape[-1])
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)
return [(self.map_tensor_name(name_kb), k_b), (self.map_tensor_name(name_vb), v_b)]

mapped_name = self.map_tensor_name(name)
logger.info(f"Returning {mapped_name}: shape after = {tuple(data_torch.shape)}")
return [(mapped_name, data_torch)]


@ModelBase.register("InternLM2ForCausalLM")
class InternLM2Model(TextModel):
model_arch = gguf.MODEL_ARCH.INTERNLM2
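For reference, a minimal standalone sketch (not part of this diff; all tensor sizes are hypothetical) of the kv_b_proj split performed in modify_tensors() above, showing the resulting k_b_proj / v_b_proj shapes:

import torch

# hypothetical MLA dimensions, chosen only to illustrate the split
n_head_kv, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512

# kv_b_proj.weight packs, per head, qk_nope_head_dim rows of k followed by v_head_dim rows of v
# (the layout the split in the diff assumes), over the compressed kv_lora_rank input
kv_b = torch.randn(n_head_kv * (qk_nope_head_dim + v_head_dim), kv_lora_rank)
kv_b = kv_b.view(n_head_kv, qk_nope_head_dim + v_head_dim, kv_lora_rank)
k_b, v_b = torch.split(kv_b, [qk_nope_head_dim, v_head_dim], dim=1)
k_b = k_b.transpose(1, 2)  # k_b_proj is stored transposed for the absorption optimization

print(k_b.shape)  # torch.Size([4, 512, 128])
print(v_b.shape)  # torch.Size([4, 128, 128])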
65 changes: 65 additions & 0 deletions gguf-py/gguf/constants.py
@@ -207,6 +207,9 @@ class SSM:
GROUP_COUNT = "{arch}.ssm.group_count"
DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms"

class KDA:
HEAD_DIM = "{arch}.kda.head_dim"

class WKV:
HEAD_SIZE = "{arch}.wkv.head_size"

@@ -459,6 +462,7 @@ class MODEL_ARCH(IntEnum):
MIMO2 = auto()
LLAMA_EMBED = auto()
MAINCODER = auto()
KIMI_LINEAR = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -549,6 +553,14 @@ class MODEL_TENSOR(IntEnum):
SSM_NORM = auto()
SSM_OUT = auto()
SSM_BETA_ALPHA = auto() # qwen3next
SSM_CONV1D_Q = auto() # Kimi Linear
SSM_CONV1D_K = auto() # Kimi Linear
SSM_CONV1D_V = auto() # Kimi Linear
SSM_F_A = auto() # Kimi Linear
SSM_F_B = auto() # Kimi Linear
SSM_BETA = auto() # Kimi Linear
SSM_G_A = auto() # Kimi Linear
SSM_G_B = auto() # Kimi Linear
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
@@ -880,6 +892,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -967,6 +980,14 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear
MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear
MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear
MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear
MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@@ -3377,6 +3398,47 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.KIMI_LINEAR: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_Q_A,
MODEL_TENSOR.ATTN_Q_B,
MODEL_TENSOR.ATTN_KV_A_MQA,
MODEL_TENSOR.ATTN_KV_B,
MODEL_TENSOR.ATTN_K_B,
MODEL_TENSOR.ATTN_V_B,
MODEL_TENSOR.ATTN_Q_A_NORM,
MODEL_TENSOR.ATTN_KV_A_NORM,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.SSM_CONV1D_Q,
MODEL_TENSOR.SSM_CONV1D_K,
MODEL_TENSOR.SSM_CONV1D_V,
MODEL_TENSOR.SSM_F_A,
MODEL_TENSOR.SSM_F_B,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_G_A,
MODEL_TENSOR.SSM_G_B,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.FFN_EXP_PROBS_B,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
],
# TODO
}

@@ -3704,6 +3766,9 @@ class VisionProjectorType:
KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT
KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS

# KDA
KEY_KDA_HEAD_DIM = Keys.KDA.HEAD_DIM

# tokenization
KEY_TOKENIZER_MODEL = Keys.Tokenizer.MODEL
KEY_TOKENIZER_PRE = Keys.Tokenizer.PRE
3 changes: 3 additions & 0 deletions gguf-py/gguf/gguf_writer.py
@@ -980,6 +980,9 @@ def add_ssm_group_count(self, value: int) -> None:
def add_ssm_dt_b_c_rms(self, value: bool) -> None:
self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value)

def add_kda_head_dim(self, value: int) -> None:
self.add_uint32(Keys.KDA.HEAD_DIM.format(arch=self.arch), value)
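# e.g. for arch "kimi-linear" (as registered in constants.py) this stores the value
# under the key "kimi-linear.kda.head_dim" as a uint32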

def add_tokenizer_model(self, model: str) -> None:
self.add_string(Keys.Tokenizer.MODEL, model)
