Skip to content

Commit 4e92646

Browse files
author
shixingliang.sxl
committed
support qwen3vl
Signed-off-by: shixingliang.sxl <shixingliang.sxl@antgroup.com>
1 parent 9bfe152 commit 4e92646

3 files changed

Lines changed: 205 additions & 0 deletions

File tree

slime/backends/megatron_utils/megatron_to_hf/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@
77
from .qwen2 import convert_qwen2_to_hf
88
from .qwen3_next import convert_qwen3_next_to_hf
99
from .qwen3moe import convert_qwen3moe_to_hf
10+
from .qwen3vl import convert_qwen3vl_to_hf
11+
from .qwen3vlmoe import convert_qwen3vlmoe_to_hf
1015

1116

1217
# TODO unify w/ `convert_to_hf`
@@ -37,6 +42,13 @@ def _convert_to_hf_core(args, model_name, name, param):
3742
converted_named_tensors = convert_glm4moe_to_hf(args, name, param)
3843
elif "glm4" in model_name:
3944
converted_named_tensors = convert_glm4_to_hf(args, name, param)
45+
    elif "qwen3vlmoe" in model_name:
        converted_named_tensors = convert_qwen3vlmoe_to_hf(args, name, param)
    elif "qwen3vl" in model_name:
        converted_named_tensors = convert_qwen3vl_to_hf(args, name, param)
4052
elif "qwen3moe" in model_name:
4153
converted_named_tensors = convert_qwen3moe_to_hf(args, name, param)
4254
elif "qwen3next" in model_name:
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import re
2+
import torch
3+
4+
5+
def convert_qwen3vl_to_hf(args, name, param):
    """Translate one Megatron-named Qwen3-VL parameter to HuggingFace naming.

    Args:
        args: Megatron args namespace; reads ``kv_channels``, ``hidden_size``,
            ``num_attention_heads`` and ``num_query_groups``.
        name: Megatron parameter name (``module.module.`` prefixed).
        param: the parameter tensor.

    Returns:
        A list of ``(hf_name, tensor)`` pairs.  Fused tensors (the QKV
        projection, the gate/up MLP projection) are split into several pairs.

    Raises:
        ValueError: if ``name`` matches no known parameter.
    """
    # Whole-model singletons: plain one-to-one renames.
    direct_names = {
        "module.module.language_model.embedding.word_embeddings.weight": "model.language_model.embed_tokens.weight",
        "module.module.language_model.output_layer.weight": "lm_head.weight",
        "module.module.language_model.decoder.final_layernorm.weight": "model.language_model.norm.weight",
    }
    if name in direct_names:
        return [(direct_names[name], param)]

    # head_dim falls back to hidden_size / num_heads when kv_channels is
    # absent or None on the args namespace.
    head_dim = getattr(args, "kv_channels", None)
    if head_dim is None:
        head_dim = args.hidden_size // args.num_attention_heads
    q_heads_per_group = args.num_attention_heads // args.num_query_groups

    layer_match = re.match(r"module\.module\.language_model\.decoder\.layers\.(\d+)\.(.+)", name)
    if layer_match:
        layer_idx, rest = layer_match.groups()
        prefix = f"model.language_model.layers.{layer_idx}"
        # One-to-one renames within a decoder layer (incl. qk-norm weights).
        per_layer = {
            "self_attention.linear_proj.weight": f"{prefix}.self_attn.o_proj.weight",
            "mlp.linear_fc2.weight": f"{prefix}.mlp.down_proj.weight",
            "self_attention.linear_qkv.layer_norm_weight": f"{prefix}.input_layernorm.weight",
            "mlp.linear_fc1.layer_norm_weight": f"{prefix}.post_attention_layernorm.weight",
            "self_attention.q_layernorm.weight": f"{prefix}.self_attn.q_norm.weight",
            "self_attention.k_layernorm.weight": f"{prefix}.self_attn.k_norm.weight",
        }
        if rest in per_layer:
            return [(per_layer[rest], param)]
        if rest == "self_attention.linear_qkv.weight":
            # Fused QKV is laid out per query group: q_heads_per_group query
            # heads, then one K head, then one V head, each head_dim rows.
            grouped = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
            q_part, k_part, v_part = torch.split(grouped, [q_heads_per_group, 1, 1], dim=1)
            return [
                (f"{prefix}.self_attn.q_proj.weight", q_part.reshape(-1, args.hidden_size)),
                (f"{prefix}.self_attn.k_proj.weight", k_part.reshape(-1, args.hidden_size)),
                (f"{prefix}.self_attn.v_proj.weight", v_part.reshape(-1, args.hidden_size)),
            ]
        if rest == "mlp.linear_fc1.weight":
            # Fused gate/up projection: first half is gate, second half is up.
            gate_w, up_w = param.chunk(2, dim=0)
            return [
                (f"{prefix}.mlp.gate_proj.weight", gate_w),
                (f"{prefix}.mlp.up_proj.weight", up_w),
            ]

    # Vision tower singletons: patch embed, pos embed and the merger head.
    vision_direct = {
        "module.module.vision_model.patch_embed.proj.weight": "model.visual.patch_embed.proj.weight",
        "module.module.vision_model.patch_embed.proj.bias": "model.visual.patch_embed.proj.bias",
        "module.module.vision_model.pos_embed.weight": "model.visual.pos_embed.weight",
        "module.module.vision_model.merger.norm.weight": "model.visual.merger.norm.weight",
        "module.module.vision_model.merger.norm.bias": "model.visual.merger.norm.bias",
        "module.module.vision_model.merger.linear_fc1.weight": "model.visual.merger.linear_fc1.weight",
        "module.module.vision_model.merger.linear_fc1.bias": "model.visual.merger.linear_fc1.bias",
        "module.module.vision_model.merger.linear_fc2.weight": "model.visual.merger.linear_fc2.weight",
        "module.module.vision_model.merger.linear_fc2.bias": "model.visual.merger.linear_fc2.bias",
    }
    if name in vision_direct:
        return [(vision_direct[name], param)]

    # Indexed vision sub-modules keep their inner path verbatim.
    for pattern, hf_prefix in (
        (r"module\.module\.vision_model\.deepstack_merger_list\.(\d+)\.(.+)", "model.visual.deepstack_merger_list"),
        (r"module\.module\.vision_model\.blocks\.(\d+)\.(.+)", "model.visual.blocks"),
    ):
        sub_match = re.match(pattern, name)
        if sub_match:
            idx, inner = sub_match.groups()
            return [(f"{hf_prefix}.{idx}.{inner}", param)]

    raise ValueError(f"Unknown parameter name: {name}")
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import re
2+
3+
import torch
4+
5+
6+
def convert_qwen3vlmoe_to_hf(args, name, param):
    """Convert one Megatron-named Qwen3-VL-MoE parameter to HuggingFace naming.

    Args:
        args: Megatron args namespace; reads ``kv_channels``, ``hidden_size``,
            ``num_attention_heads`` and ``num_query_groups``.
        name: Megatron parameter name (``module.module.`` prefixed).
        param: the parameter tensor.

    Returns:
        A list of ``(hf_name, tensor)`` pairs.  Fused tensors (the QKV
        projection, the gate/up projections of dense and expert MLPs) are
        split into several pairs.

    Raises:
        ValueError: if ``name`` matches no known parameter.
    """
    if name == "module.module.language_model.embedding.word_embeddings.weight":
        return [("model.language_model.embed_tokens.weight", param)]
    if name == "module.module.language_model.output_layer.weight":
        return [("lm_head.weight", param)]
    if name == "module.module.language_model.decoder.final_layernorm.weight":
        return [("model.language_model.norm.weight", param)]

    # head_dim falls back to hidden_size / num_heads when kv_channels is
    # absent or None on the args namespace.
    try:
        head_dim = args.kv_channels if args.kv_channels is not None else args.hidden_size // args.num_attention_heads
    except AttributeError:
        head_dim = args.hidden_size // args.num_attention_heads
    value_num_per_group = args.num_attention_heads // args.num_query_groups

    decoder_layers_pattern = r"module\.module\.language_model\.decoder\.layers\.(\d+)\.(.+)"
    match = re.match(decoder_layers_pattern, name)
    if match:
        layer_idx, rest = match.groups()

        # Experts: Megatron stores per-expert weights with the expert index
        # fused into the weight name, e.g. "mlp.experts.linear_fc1.weight7".
        # Fix: dots escaped so "mlp" / "experts" match literally.
        expert_pattern = r"mlp\.experts\.(.+)\.weight(\d+)"
        expert_match = re.match(expert_pattern, rest)
        if expert_match:
            rest, expert_idx = expert_match.groups()
            # Fix: use the "model.language_model." prefix, consistent with
            # every other decoder-layer mapping in this converter (the
            # original wrote experts under "model.layers.", splitting a
            # layer's weights across two different HF module paths).
            if rest == "linear_fc1":
                # Fused gate/up projection: first half gate, second half up.
                gate_weight, up_weight = param.chunk(2, dim=0)
                return [
                    (f"model.language_model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight", gate_weight),
                    (f"model.language_model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight", up_weight),
                ]
            elif rest == "linear_fc2":
                return [
                    (f"model.language_model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight", param),
                ]
            else:
                raise ValueError(f"Unknown expert parameter name: {name}")

        if rest == "self_attention.linear_proj.weight":
            return [(f"model.language_model.layers.{layer_idx}.self_attn.o_proj.weight", param)]
        elif rest == "self_attention.linear_qkv.weight":
            # Fused QKV is laid out per query group: value_num_per_group query
            # heads, then one K head, then one V head, each head_dim rows.
            param = param.view(args.num_query_groups, -1, head_dim, args.hidden_size)
            q_param, k_param, v_param = torch.split(param, split_size_or_sections=[value_num_per_group, 1, 1], dim=1)
            q_param = q_param.reshape(-1, args.hidden_size)
            k_param = k_param.reshape(-1, args.hidden_size)
            v_param = v_param.reshape(-1, args.hidden_size)
            return [
                (f"model.language_model.layers.{layer_idx}.self_attn.q_proj.weight", q_param),
                (f"model.language_model.layers.{layer_idx}.self_attn.k_proj.weight", k_param),
                (f"model.language_model.layers.{layer_idx}.self_attn.v_proj.weight", v_param),
            ]
        elif rest == "mlp.linear_fc1.weight":
            # Dense-layer fused gate/up projection.
            gate_weight, up_weight = param.chunk(2, dim=0)
            return [
                (f"model.language_model.layers.{layer_idx}.mlp.gate_proj.weight", gate_weight),
                (f"model.language_model.layers.{layer_idx}.mlp.up_proj.weight", up_weight),
            ]
        elif rest == "mlp.linear_fc2.weight":
            return [(f"model.language_model.layers.{layer_idx}.mlp.down_proj.weight", param)]
        elif rest == "self_attention.linear_qkv.layer_norm_weight":
            return [(f"model.language_model.layers.{layer_idx}.input_layernorm.weight", param)]
        elif rest == "mlp.linear_fc1.layer_norm_weight":
            return [(f"model.language_model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
        elif rest == "pre_mlp_layernorm.weight":
            # MoE layers carry the post-attention norm as pre_mlp_layernorm.
            # Fix: "model.language_model." prefix (was "model.layers.").
            return [(f"model.language_model.layers.{layer_idx}.post_attention_layernorm.weight", param)]
        elif rest == "mlp.router.weight":
            # Fix: "model.language_model." prefix (was "model.layers.").
            return [(f"model.language_model.layers.{layer_idx}.mlp.gate.weight", param)]
        # qk norm
        elif rest == "self_attention.q_layernorm.weight":
            return [(f"model.language_model.layers.{layer_idx}.self_attn.q_norm.weight", param)]
        elif rest == "self_attention.k_layernorm.weight":
            return [(f"model.language_model.layers.{layer_idx}.self_attn.k_norm.weight", param)]

    # patch embed / pos embed / merger head: one-to-one renames
    vision_prefix_table = {
        "module.module.vision_model.patch_embed.proj.weight": "model.visual.patch_embed.proj.weight",
        "module.module.vision_model.patch_embed.proj.bias": "model.visual.patch_embed.proj.bias",
        "module.module.vision_model.pos_embed.weight": "model.visual.pos_embed.weight",
        "module.module.vision_model.merger.norm.weight": "model.visual.merger.norm.weight",
        "module.module.vision_model.merger.norm.bias": "model.visual.merger.norm.bias",
        "module.module.vision_model.merger.linear_fc1.weight": "model.visual.merger.linear_fc1.weight",
        "module.module.vision_model.merger.linear_fc1.bias": "model.visual.merger.linear_fc1.bias",
        "module.module.vision_model.merger.linear_fc2.weight": "model.visual.merger.linear_fc2.weight",
        "module.module.vision_model.merger.linear_fc2.bias": "model.visual.merger.linear_fc2.bias",
    }
    if name in vision_prefix_table:
        return [(vision_prefix_table[name], param)]

    # deepstack_merger_list: keep the inner path verbatim
    deepstack_merger_pattern = r"module\.module\.vision_model\.deepstack_merger_list\.(\d+)\.(.+)"
    deepstack_match = re.match(deepstack_merger_pattern, name)
    if deepstack_match:
        idx, rest = deepstack_match.groups()
        return [(f"model.visual.deepstack_merger_list.{idx}.{rest}", param)]

    # vision transformer blocks: keep the inner path verbatim
    vision_model_block_pattern = r"module\.module\.vision_model\.blocks\.(\d+)\.(.+)"
    vision_model_match = re.match(vision_model_block_pattern, name)
    if vision_model_match:
        block_idx, rest = vision_model_match.groups()
        return [(f"model.visual.blocks.{block_idx}.{rest}", param)]

    raise ValueError(f"Unknown parameter name: {name}")

0 commit comments

Comments
 (0)