diff --git a/slime/backends/megatron_utils/megatron_to_hf/__init__.py b/slime/backends/megatron_utils/megatron_to_hf/__init__.py index 28af98ca4..053dc9bed 100644 --- a/slime/backends/megatron_utils/megatron_to_hf/__init__.py +++ b/slime/backends/megatron_utils/megatron_to_hf/__init__.py @@ -1,6 +1,7 @@ from .deepseekv3 import convert_deepseekv3_to_hf from .glm4 import convert_glm4_to_hf from .glm4moe import convert_glm4moe_to_hf +from .kimi_vl import convert_kimi_k25_to_hf, convert_kimivl_to_hf from .llama import convert_llama_to_hf from .mimo import convert_mimo_to_hf from .processors import quantize_params, remove_padding @@ -50,6 +51,10 @@ def _convert_to_hf_core(args, model_name, name, param): converted_named_tensors = convert_llama_to_hf(args, name, param) elif "mimo" in model_name: converted_named_tensors = convert_mimo_to_hf(args, name, param) + elif "kimivl" in model_name: + converted_named_tensors = convert_kimivl_to_hf(args, name, param) + elif "kimi_k25" in model_name: + converted_named_tensors = convert_kimi_k25_to_hf(args, name, param) else: raise ValueError(f"Unsupported model: {model_name}") diff --git a/slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py b/slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py new file mode 100644 index 000000000..d75eac6d5 --- /dev/null +++ b/slime/backends/megatron_utils/megatron_to_hf/kimi_vl.py @@ -0,0 +1,138 @@ +import re + +import torch + + +def convert_kimivl_to_hf(args, name, param): + if name.startswith("module.module.vision_model."): + hf_name = "vision_tower." + name[len("module.module.vision_model.") :] + return [(hf_name, param)] + + if name.startswith("module.module.multi_modal_projector."): + hf_name = "multi_modal_projector." + name[len("module.module.multi_modal_projector.") :] + return [(hf_name, param)] + + return convert_language_model_to_hf(args, name, param) + + +def convert_kimi_k25_to_hf(args, name, param): + if name.startswith("module.module.vision_tower."): + hf_name = "vision_tower." + name[len("module.module.vision_tower.") :] + return [(hf_name, param)] + + if name.startswith("module.module.mm_projector."): + hf_name = "mm_projector." + name[len("module.module.mm_projector.") :] + return [(hf_name, param)] + + return convert_language_model_to_hf(args, name, param) + + +def convert_language_model_to_hf(args, name, param): + if name == "module.module.language_model.embedding.word_embeddings.weight": + return [("language_model.model.embed_tokens.weight", param)] + if name == "module.module.language_model.output_layer.weight": + return [("language_model.lm_head.weight", param)] + if name == "module.module.language_model.decoder.final_layernorm.weight": + return [("language_model.model.norm.weight", param)] + + try: + head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads + except AttributeError: + head_dim = args.hidden_size // args.num_attention_heads + value_num_per_group = args.num_attention_heads // args.num_query_groups + + decoder_layers_pattern = r"module\.module\.language_model\.decoder\.layers\.(\d+)\.(.+)" + match = re.match(decoder_layers_pattern, name) + if match: + layer_idx, rest = match.groups() + + # experts + expert_pattern = r"mlp.experts\.(.+)\.weight(\d+)" + match = re.match(expert_pattern, rest) + if match: + rest, expert_idx = match.groups() + if rest == "linear_fc1": + gate_weight, up_weight = param.chunk(2, dim=0) + outputs = [ + ( + f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", + gate_weight, + ), + (f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight), + ] + return outputs + elif rest == "linear_fc2": + outputs = [ + (f"language_model.model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param), + ] + return outputs + else: + raise ValueError(f"Unknown expert parameter name: {name}") + + # shared expert + shared_expert_pattern = r"mlp.shared_experts\.(.+)" + match = re.match(shared_expert_pattern, rest) + if match: + rest = match.groups()[0] + if rest == "linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"language_model.model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", gate_weight), + (f"language_model.model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", up_weight), + ] + elif rest == "linear_fc2.weight": + return [(f"language_model.model.layers.{layer_idx}.mlp.shared_experts.down_proj.weight", param)] + else: + raise ValueError(f"Unknown shared expert parameter name: {name}") + + if rest == "self_attention.linear_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight", param)] + elif rest == "self_attention.linear_q_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_proj.weight", param)] + elif rest == "self_attention.linear_q_down_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_proj.weight", param)] + elif rest == "self_attention.linear_q_up_proj.layer_norm_weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_a_layernorm.weight", param)] + elif rest == "self_attention.linear_q_up_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.q_b_proj.weight", param)] + elif rest == "self_attention.linear_qkv.bias": + param = param.view(args.num_query_groups, -1) + q_bias, k_bias, v_bias = torch.split( + param, + split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim], + dim=1, + ) + q_bias = q_bias.contiguous().flatten() + k_bias = k_bias.contiguous().flatten() + v_bias = v_bias.contiguous().flatten() + return [ + (f"language_model.model.layers.{layer_idx}.self_attn.q_proj.bias", q_bias), + (f"language_model.model.layers.{layer_idx}.self_attn.k_proj.bias", k_bias), + (f"language_model.model.layers.{layer_idx}.self_attn.v_proj.bias", v_bias), + ] + elif rest == "mlp.linear_fc1.weight": + gate_weight, up_weight = param.chunk(2, dim=0) + return [ + (f"language_model.model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight), + (f"language_model.model.layers.{layer_idx}.mlp.up_proj.weight", up_weight), + ] + elif rest == "mlp.linear_fc2.weight": + return [(f"language_model.model.layers.{layer_idx}.mlp.down_proj.weight", param)] + elif rest == "self_attention.linear_qkv.layer_norm_weight" or rest == "input_layernorm.weight": + return [(f"language_model.model.layers.{layer_idx}.input_layernorm.weight", param)] + elif rest == "mlp.linear_fc1.layer_norm_weight": + return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + elif rest == "self_attention.linear_kv_down_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_proj_with_mqa.weight", param)] + elif rest == "self_attention.linear_kv_up_proj.layer_norm_weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_a_layernorm.weight", param)] + elif rest == "self_attention.linear_kv_up_proj.weight": + return [(f"language_model.model.layers.{layer_idx}.self_attn.kv_b_proj.weight", param)] + elif rest == "pre_mlp_layernorm.weight": + return [(f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight", param)] + elif rest == "mlp.router.weight": + return [(f"language_model.model.layers.{layer_idx}.mlp.gate.weight", param)] + elif rest == "mlp.router.expert_bias": + return [(f"language_model.model.layers.{layer_idx}.mlp.gate.e_score_correction_bias", param)] + + raise ValueError(f"Unknown parameter name: {name}")