3 changes: 3 additions & 0 deletions slime/backends/megatron_utils/megatron_to_hf/__init__.py
@@ -7,6 +7,7 @@
from .qwen2 import convert_qwen2_to_hf
from .qwen3_next import convert_qwen3_next_to_hf
from .qwen3moe import convert_qwen3moe_to_hf
from .qwen3_vl import convert_qwen3vl_to_hf


# TODO unify w/ `convert_to_hf`
@@ -41,6 +42,8 @@ def _convert_to_hf_core(args, model_name, name, param):
converted_named_tensors = convert_qwen3moe_to_hf(args, name, param)
elif "qwen3next" in model_name:
converted_named_tensors = convert_qwen3_next_to_hf(args, name, param)
elif "qwen3vl" in model_name:
converted_named_tensors = convert_qwen3vl_to_hf(args, name, param)
elif "qwen2" in model_name or "qwen3" in model_name:
converted_named_tensors = convert_qwen2_to_hf(args, name, param)
elif "llama" in model_name:
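
Note on dispatch order: since "qwen3" is a substring of "qwen3vl", the new branch must sit above the generic "qwen2"/"qwen3" check, as it does here. A minimal sketch of the pitfall (the model name is illustrative):

    # Substring dispatch is order-sensitive: "qwen3" also matches "qwen3vl".
    model_name = "qwen3vl-8b"
    if "qwen3vl" in model_name:
        branch = "convert_qwen3vl_to_hf"  # taken: the specific check comes first
    elif "qwen3" in model_name:
        branch = "convert_qwen2_to_hf"  # would swallow qwen3vl if ordered first
    print(branch)  # convert_qwen3vl_to_hf
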
100 changes: 100 additions & 0 deletions slime/backends/megatron_utils/megatron_to_hf/qwen3_vl.py
@@ -0,0 +1,100 @@
import re
import torch


def convert_qwen3vl_to_hf(args, name, param):
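    # Megatron nests LM weights under a "language_model." namespace; strip it so
    # the branches below all see a uniform "module.module." prefix.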
if name.startswith("module.module.language_model."):
name = "module.module." + name[len("module.module.language_model."):]

    # Defensive: collapse any accidentally repeated "module." prefixes.
while name.startswith("module.module.module."):
name = name.replace("module.module.module.", "module.module.", 1)

if name.startswith("module.module.vision_model."):
hf_name = "model.visual." + name[len("module.module.vision_model."):]
return [(hf_name, param)]

if name == "module.module.embedding.word_embeddings.weight":
return [("model.language_model.embed_tokens.weight", param)]

if name == "module.module.output_layer.weight":
        # lm_head.weight sits at the top level of the HF state dict, not under model.*
return [("lm_head.weight", param)]

if name == "module.module.decoder.final_layernorm.weight":
return [("model.language_model.norm.weight", param)]

    # Head size: prefer an explicit kv_channels setting; otherwise derive it
    # from hidden_size and the number of attention heads.
    head_dim = getattr(args, "kv_channels", None)
    if head_dim is None:
        head_dim = args.hidden_size // args.num_attention_heads
    # Query heads served by each KV group (grouped-query attention).
    value_num_per_group = args.num_attention_heads // args.num_query_groups

decoder_layers_pattern = r"module\.module\.decoder\.layers\.(\d+)\.(.+)"
match = re.match(decoder_layers_pattern, name)
if match:
layer_idx, rest = match.groups()
# Everything goes under model.language_model.layers.{i}.*
base = f"model.language_model.layers.{layer_idx}"

if rest == "self_attention.linear_proj.weight":
return [(f"{base}.self_attn.o_proj.weight", param)]

elif rest == "self_attention.linear_qkv.weight":
            # Megatron stores the fused QKV weight per query group as
            # [q_0 .. q_{value_num_per_group-1}, k, v] along the head axis;
            # split it back into separate q/k/v projections.
param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
q_param, k_param, v_param = torch.split(
param,
split_size_or_sections=[value_num_per_group, 1, 1],
dim=1
)
q_param = q_param.reshape(-1, args.hidden_size)
k_param = k_param.reshape(-1, args.hidden_size)
v_param = v_param.reshape(-1, args.hidden_size)
return [
(f"{base}.self_attn.q_proj.weight", q_param),
(f"{base}.self_attn.k_proj.weight", k_param),
(f"{base}.self_attn.v_proj.weight", v_param),
]

elif rest == "self_attention.linear_qkv.bias":
            # Same per-group layout as the fused QKV weight; split the bias into
            # segments of [value_num_per_group * head_dim, head_dim, head_dim].
param = param.view(args.num_query_groups, -1)
q_bias, k_bias, v_bias = torch.split(
param,
split_size_or_sections=[value_num_per_group * head_dim, head_dim, head_dim],
dim=1,
)
q_bias = q_bias.contiguous().flatten()
k_bias = k_bias.contiguous().flatten()
v_bias = v_bias.contiguous().flatten()
return [
(f"{base}.self_attn.q_proj.bias", q_bias),
(f"{base}.self_attn.k_proj.bias", k_bias),
(f"{base}.self_attn.v_proj.bias", v_bias),
]

elif rest == "mlp.linear_fc1.weight":
            # linear_fc1 fuses the SwiGLU gate and up projections; split along dim 0.
gate_weight, up_weight = param.chunk(2, dim=0)
return [
(f"{base}.mlp.gate_proj.weight", gate_weight),
(f"{base}.mlp.up_proj.weight", up_weight),
]

elif rest == "mlp.linear_fc2.weight":
return [(f"{base}.mlp.down_proj.weight", param)]

elif rest == "self_attention.linear_qkv.layer_norm_weight":
return [(f"{base}.input_layernorm.weight", param)]

elif rest == "mlp.linear_fc1.layer_norm_weight":
return [(f"{base}.post_attention_layernorm.weight", param)]

# qk norm
elif rest == "self_attention.q_layernorm.weight":
return [(f"{base}.self_attn.q_norm.weight", param)]
elif rest == "self_attention.k_layernorm.weight":
return [(f"{base}.self_attn.k_norm.weight", param)]

raise ValueError(f"Unknown parameter name: {name}")
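
A minimal shape sanity check for the fused-QKV split (dimensions are illustrative, not a real Qwen3-VL config; the import path assumes the file is saved as qwen3_vl.py, matching the import in __init__.py):

    from types import SimpleNamespace

    import torch

    from slime.backends.megatron_utils.megatron_to_hf.qwen3_vl import convert_qwen3vl_to_hf

    # 8 query heads in 4 KV groups; head_dim = 64 // 8 = 8.
    args = SimpleNamespace(hidden_size=64, num_attention_heads=8, num_query_groups=4, kv_channels=None)
    head_dim = args.hidden_size // args.num_attention_heads
    # Fused rows: (num_heads + 2 * num_groups) * head_dim = (8 + 8) * 8 = 128.
    fused = torch.randn((args.num_attention_heads + 2 * args.num_query_groups) * head_dim, args.hidden_size)

    out = dict(convert_qwen3vl_to_hf(args, "module.module.decoder.layers.0.self_attention.linear_qkv.weight", fused))
    assert out["model.language_model.layers.0.self_attn.q_proj.weight"].shape == (64, 64)
    assert out["model.language_model.layers.0.self_attn.k_proj.weight"].shape == (32, 64)
    assert out["model.language_model.layers.0.self_attn.v_proj.weight"].shape == (32, 64)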