From 172f98ea26d2126f23cbe9f39650cb7201fdcfc3 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 31 Oct 2025 19:58:39 +0800 Subject: [PATCH 01/15] support qwen3vl dense --- lmdeploy/archs.py | 6 +- lmdeploy/pytorch/config.py | 16 +- lmdeploy/pytorch/configurations/default.py | 9 +- lmdeploy/pytorch/engine/model_agent.py | 3 +- lmdeploy/pytorch/models/module_map.py | 12 + lmdeploy/pytorch/models/qwen3.py | 2 +- lmdeploy/pytorch/models/qwen3_vl.py | 893 +++++++++++++++++++++ lmdeploy/vl/model/builder.py | 1 + lmdeploy/vl/model/qwen3.py | 128 +++ 9 files changed, 1061 insertions(+), 9 deletions(-) create mode 100644 lmdeploy/pytorch/models/qwen3_vl.py create mode 100644 lmdeploy/vl/model/qwen3.py diff --git a/lmdeploy/archs.py b/lmdeploy/archs.py index faf5e88157..444a2026a3 100644 --- a/lmdeploy/archs.py +++ b/lmdeploy/archs.py @@ -109,9 +109,9 @@ def check_vl_llm(config: dict) -> bool: 'LlavaLlamaForCausalLM', 'LlavaMistralForCausalLM', 'CogVLMForCausalLM', 'InternLMXComposer2ForCausalLM', 'InternVLChatModel', 'MiniCPMV', 'LlavaForConditionalGeneration', 'LlavaNextForConditionalGeneration', 'Phi3VForCausalLM', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', - 'MllamaForConditionalGeneration', 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', - 'Llama4ForConditionalGeneration', 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', - 'Glm4vForConditionalGeneration' + 'Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration', 'MllamaForConditionalGeneration', + 'MolmoForCausalLM', 'Gemma3ForConditionalGeneration', 'Llama4ForConditionalGeneration', + 'InternVLForConditionalGeneration', 'InternS1ForConditionalGeneration', 'Glm4vForConditionalGeneration' ]) if arch == 'QWenLMHeadModel' and 'visual' in config: return True diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index ac3459e045..1c21ef4d67 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -28,9 +28,16 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): config.dtype = torch.float16 return config - torch_dtype = getattr(config.hf_config, 'dtype', None) - if torch_dtype is None: - torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + if hasattr(config.hf_config, 'text_config'): + torch_dtype = getattr(config.hf_config.text_config, 'dtype', None) + + if torch_dtype is None: + torch_dtype = getattr(config.hf_config.text_config, 'torch_dtype', None) + else: + torch_dtype = getattr(config.hf_config, 'dtype', None) + + if torch_dtype is None: + torch_dtype = getattr(config.hf_config, 'torch_dtype', None) # deal with case when torch_dtype is not string but torch.dtype if isinstance(torch_dtype, torch.dtype): @@ -283,7 +290,10 @@ def from_hf_config(cls, assert tp % model_config.num_key_value_heads == 0 # should after setting `hf_config` and `model_arch` attributes + print(f'?????? 
dtype: {dtype}') + print(f'model_config: {model_config}') model_config = _update_torch_dtype(model_config, dtype) + print(f'after update, model_config: {model_config}') # update eos_token_id to list if isinstance(model_config.eos_token_id, int): diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py index e30ae7c089..da07a3b487 100644 --- a/lmdeploy/pytorch/configurations/default.py +++ b/lmdeploy/pytorch/configurations/default.py @@ -15,7 +15,14 @@ def condition(cls, hf_config): def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" head_dim = getattr(hf_config, 'head_dim', None) - head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads + + if hasattr(hf_config, 'text_config') and hasattr(hf_config, 'vision_config'): + # for multi-modal models config with separate text and vision configs + hf_config = hf_config.text_config + head_dim = hf_config.head_dim + else: + head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads + # head_dim should not be None hf_config.head_dim = head_dim num_attention_heads = hf_config.num_attention_heads diff --git a/lmdeploy/pytorch/engine/model_agent.py b/lmdeploy/pytorch/engine/model_agent.py index d19485e83e..77821a1cdc 100644 --- a/lmdeploy/pytorch/engine/model_agent.py +++ b/lmdeploy/pytorch/engine/model_agent.py @@ -248,7 +248,8 @@ def model_forward( output = model(**input_dict) # InternVL-3.5-Flash will change the seqlen, model_metas during forward - model_metas = context.model_metas + if context.model_metas is not None and context.model_metas[0] is not None: + model_metas = context.model_metas seq_length = context.q_seqlens[:len(inputs.seq_length)] return dict(hidden_states=output, model_metas=model_metas, seq_length=seq_length) diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 498e2c6554..2559549d1d 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -147,6 +147,18 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen2_5_vl.Qwen2_5_VLForConditionalGeneration', }) +# qwen3_vl +MODULE_MAP.update({ + 'Qwen3VLForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl.Qwen3VLForConditionalGeneration', +}) + +# # qwen3_vl_moe +# MODULE_MAP.update({ +# 'Qwen3VLMoeForConditionalGeneration': +# f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration', +# }) + # starcoder2 MODULE_MAP.update({ 'Starcoder2ForCausalLM': f'{LMDEPLOY_PYTORCH_MODEL_PATH}.starcoder2.Starcoder2ForCausalLM', diff --git a/lmdeploy/pytorch/models/qwen3.py b/lmdeploy/pytorch/models/qwen3.py index c362df2fe8..fac4ca579f 100644 --- a/lmdeploy/pytorch/models/qwen3.py +++ b/lmdeploy/pytorch/models/qwen3.py @@ -47,7 +47,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim, - sliding_window=config.sliding_window, + sliding_window=(config.sliding_window if hasattr(config, 'sliding_window') else None), ) # o_proj diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py new file mode 100644 index 0000000000..4080dc7523 --- /dev/null +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -0,0 +1,893 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update + +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager +from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor +from lmdeploy.pytorch.nn import LayerNorm, RMSNorm +from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3VLVisionRotaryEmbedding +from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention +from .qwen3 import Qwen3DecoderLayer as Qwen3VLTextDecoderLayer +from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin +from .utils.model import DeployModelMixin, vlm_model + + +def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int], + position_ids: torch.Tensor, rotary_emb_func: Callable): + _mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device) + _mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids + cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids) + _cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device) + _sin = torch.zeros_like(_cos) + mrope_section = mrope_section * 2 + + def _apply_split(src, dst): + start = 0 + for i, m in enumerate(src.split(mrope_section, dim=-1)): + dst[:, start:start + mrope_section[i]] = m[i % 3] + start += mrope_section[i] + + _apply_split(cos, _cos) + _apply_split(sin, _sin) + + return _cos, _sin + + +class Qwen3VLTextRotaryEmbedding(nn.Module): + inv_freq: torch.Tensor # fix linting for `register_buffer` + + def __init__(self, config: PretrainedConfig, device=None): + super().__init__() + if hasattr(config, 'rope_scaling') and config.rope_scaling is not None: + self.rope_type = config.rope_scaling.get('rope_type', 'default') + else: + self.rope_type = 'default' + self.max_seq_len_cached = config.max_position_embeddings + self.original_max_seq_len = config.max_position_embeddings + + self.config = config + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + + inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) + self.register_buffer('inv_freq', inv_freq, persistent=False) + self.original_inv_freq = self.inv_freq + + self.mrope_section = config.rope_scaling.get('mrope_section', [24, 20, 20]) + + def apply_interleaved_mrope(self, freqs, mrope_section): + """Apply interleaved MRoPE to 3D rotary embeddings. + + Reorganizes frequency layout from chunked [TTT...HHH...WWW] to + interleaved [THTHWHTHW...TT], preserving frequency continuity. + args: + x: (3, bs, seq_len, head_dim // 2) + mrope_section: (3,) + returns: + x_t: (bs, seq_len, head_dim // 2) + """ + freqs_t = freqs[0] # just overwrite the first dimension T + for dim, offset in enumerate((1, 2), start=1): # H, W + length = mrope_section[dim] * 3 + idx = slice(offset, length, 3) + freqs_t[..., idx] = freqs[dim, ..., idx] + return freqs_t + + @torch.no_grad() + @dynamic_rope_update # power user: used with advanced RoPE types (e.g. 
dynamic rope) + def forward(self, x, position_ids): + # In contrast to other models, Qwen3VL has different position ids for the grids + # So we expand the inv_freq to shape (3, ...) + if position_ids.ndim == 2: + position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1) + inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1) + position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions) + + device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != 'mps' else 'cpu' + with torch.autocast(device_type=device_type, enabled=False): # Force float32 + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3) + freqs = self.apply_interleaved_mrope(freqs, self.mrope_section) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() * self.attention_scaling + sin = emb.sin() * self.attention_scaling + + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class Qwen3VLTextModel(nn.Module): + """Text part of Qwen3VL. + + not a pure text-only model, as DeepStack integrates visual features into the early hidden states. + """ + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.mrope_section = config.rope_scaling['mrope_section'] + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + Qwen3VLTextDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + # TODO: zhouxinyu, add triton kernel for interleaved mrope + self.rotary_emb = Qwen3VLTextRotaryEmbedding(config, device=device) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mrope_position_ids: torch.LongTensor = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + ): + """visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, + *optional*): + + The mask of the visual positions. deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): The deepstack + visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). The feature is extracted from the + different visual encoder layers, and fed to the decoder hidden states. 
It's from the paper DeepStack ( + https://arxiv.org/abs/2406.04) + """ + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + if mrope_position_ids is None: + cos, sin = self.rotary_emb(hidden_states, position_ids) + else: + cos, sin = self.rotary_emb(hidden_states, mrope_position_ids) + + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[idx], + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class Qwen3VLVisionPatchEmbed(nn.Module): + + def __init__(self, config, dtype: torch.dtype = None, device: torch.device = None) -> None: + super().__init__() + self.patch_size = config.patch_size + self.temporal_patch_size = config.temporal_patch_size + self.in_channels = config.in_channels + self.embed_dim = config.hidden_size + + kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size] + self.proj = nn.Conv3d(self.in_channels, + self.embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + dtype=dtype, + device=device) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, + self.patch_size) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) + return hidden_states + + +class Qwen3VLVisionMLP(nn.Module): + """Vision mlp.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + from transformers.activations import ACT2FN + hidden_dim = config.hidden_size + intermediate_size = config.intermediate_size + quantization_config = getattr(config, 'quantization_config', None) + # gate up + self.linear_fc1 = build_colwise_linear( + hidden_dim, + intermediate_size, + bias=True, + dtype=dtype, + device=device, + quant_config=quantization_config, + is_tp=True, + ) + + # silu and mul + if config.hidden_act in ['gelu', 'gelu_fast', 'quick_gelu', 'gelu_python']: + self.act = nn.GELU() + else: + self.act = ACT2FN[config.hidden_act] + + # down + self.linear_fc2 = build_rowwise_linear(intermediate_size, + hidden_dim, + bias=True, + quant_config=quantization_config, + dtype=dtype, + device=device, + is_tp=True) + + def forward(self, x): + """forward.""" + return 
self.linear_fc2(self.act(self.linear_fc1(x))) + + +class Qwen3VLVisionBlock(nn.Module): + """Vision block.""" + + def __init__(self, + config: PretrainedConfig, + layer_idx: int, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.layer_idx = layer_idx + self.norm1 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device) + self.norm2 = LayerNorm(config.hidden_size, eps=1e-6, dtype=dtype, device=device) + + self.attn = Qwen3VLVisionAttention(config, dtype=dtype, device=device) + + self.mlp = Qwen3VLVisionMLP(config, dtype=dtype, device=device) + + def forward(self, + hidden_states, + cu_seqlens, + rotary_pos_emb, + residual: Optional[torch.Tensor] = None) -> torch.Tensor: + if residual is None: + residual = hidden_states + hidden_states = self.norm1(hidden_states) + else: + hidden_states, residual = self.norm1(hidden_states, residual) + + hidden_states = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) + + hidden_states, residual = self.norm2(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen3VLVisionPatchMerger(nn.Module): + + def __init__(self, + config: PretrainedConfig, + use_postshuffle_norm=False, + dtype: torch.dtype = None, + device: torch.device = None) -> None: + super().__init__() + self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, + eps=1e-6, + dtype=dtype, + device=device) + self.linear_fc1 = build_colwise_linear( + self.hidden_size, + self.hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ) + self.act_fn = nn.GELU() + self.linear_fc2 = build_rowwise_linear( + self.hidden_size, + config.out_hidden_size, + bias=True, + dtype=dtype, + device=device, + is_tp=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x.view(-1, self.hidden_size) if self.use_postshuffle_norm else x).view(-1, self.hidden_size) + x = self.linear_fc2(self.act_fn(self.linear_fc1(x))) + return x + + +@vlm_model +class Qwen3VLVisionModel(nn.Module): + """Vision transformer.""" + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.config = config + self.spatial_merge_size = config.spatial_merge_size + + self.patch_embed = Qwen3VLVisionPatchEmbed(config=config, dtype=dtype, device=device) + + self.pos_embed = nn.Embedding(config.num_position_embeddings, config.hidden_size, dtype=dtype, device=device) + self.num_grid_per_side = int(config.num_position_embeddings**0.5) + + head_dim = config.hidden_size // config.num_heads + self.rotary_pos_emb = Qwen3VLVisionRotaryEmbedding(head_dim // 2, device=device) + + self.blocks = nn.ModuleList( + [Qwen3VLVisionBlock(config, layer_idx, dtype=dtype, device=device) for layer_idx in range(config.depth)]) + self.merger = Qwen3VLVisionPatchMerger(config=config, use_postshuffle_norm=False, dtype=dtype, device=device) + + self.deepstack_visual_indexes = config.deepstack_visual_indexes + self.deepstack_merger_list = nn.ModuleList([ + Qwen3VLVisionPatchMerger(config=config, use_postshuffle_norm=True, dtype=dtype, device=device) + for _ in range(len(config.deepstack_visual_indexes)) + ]) + + def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: + merge_size = self.spatial_merge_size + + max_hw = int(grid_thw[:, 1:].max().item()) + freq_table 
= self.rotary_pos_emb(max_hw) # (max_hw, dim // 2) + device = freq_table.device + + total_tokens = int(torch.prod(grid_thw, dim=1).sum().item()) + pos_ids = torch.empty((total_tokens, 2), dtype=torch.long, device=device) + + offset = 0 + for num_frames, height, width in grid_thw: + merged_h, merged_w = height // merge_size, width // merge_size + + block_rows = torch.arange(merged_h, device=device) # block row indices + block_cols = torch.arange(merged_w, device=device) # block col indices + intra_row = torch.arange(merge_size, device=device) # intra-block row offsets + intra_col = torch.arange(merge_size, device=device) # intra-block col offsets + + # Compute full-resolution positions + row_idx = block_rows[:, None, None, None] * merge_size + intra_row[None, None, :, None] + col_idx = block_cols[None, :, None, None] * merge_size + intra_col[None, None, None, :] + + row_idx = row_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + col_idx = col_idx.expand(merged_h, merged_w, merge_size, merge_size).reshape(-1) + + coords = torch.stack((row_idx, col_idx), dim=-1) + + if num_frames > 1: + coords = coords.repeat(num_frames, 1) + + num_tokens = coords.shape[0] + pos_ids[offset:offset + num_tokens] = coords + offset += num_tokens + + embeddings = freq_table[pos_ids] # lookup rotary embeddings + embeddings = embeddings.flatten(1) + return embeddings + + def fast_pos_embed_interpolate(self, grid_thw): + grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2] + + idx_list = [[] for _ in range(4)] + weight_list = [[] for _ in range(4)] + + for t, h, w in zip(grid_ts, grid_hs, grid_ws): + h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h) + w_idxs = torch.linspace(0, self.num_grid_per_side - 1, w) + + h_idxs_floor = h_idxs.int() + w_idxs_floor = w_idxs.int() + h_idxs_ceil = (h_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + w_idxs_ceil = (w_idxs.int() + 1).clip(max=self.num_grid_per_side - 1) + + dh = h_idxs - h_idxs_floor + dw = w_idxs - w_idxs_floor + + base_h = h_idxs_floor * self.num_grid_per_side + base_h_ceil = h_idxs_ceil * self.num_grid_per_side + + indices = [ + (base_h[None].T + w_idxs_floor[None]).flatten(), + (base_h[None].T + w_idxs_ceil[None]).flatten(), + (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), + (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), + ] + + weights = [ + ((1 - dh)[None].T * (1 - dw)[None]).flatten(), + ((1 - dh)[None].T * dw[None]).flatten(), + (dh[None].T * (1 - dw)[None]).flatten(), + (dh[None].T * dw[None]).flatten(), + ] + + for i in range(4): + idx_list[i].extend(indices[i].tolist()) + weight_list[i].extend(weights[i].tolist()) + + idx_tensor = torch.tensor(idx_list, dtype=torch.long, device=self.pos_embed.weight.device) + weight_tensor = torch.tensor(weight_list, + dtype=self.pos_embed.weight.dtype, + device=self.pos_embed.weight.device) + pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] + patch_pos_embeds = pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] + + patch_pos_embeds = patch_pos_embeds.split([h * w for h, w in zip(grid_hs, grid_ws)]) + + patch_pos_embeds_permute = [] + merge_size = self.config.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = (pos_embed.view(t, h // merge_size, merge_size, w // merge_size, merge_size, + -1).permute(0, 1, 3, 2, 4, 5).flatten(0, 4)) + patch_pos_embeds_permute.append(pos_embed) + patch_pos_embeds = torch.cat(patch_pos_embeds_permute) 
+ return patch_pos_embeds + + def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor, + pos_embeds: torch.Tensor) -> torch.Tensor: + """forward.""" + hidden_states = self.patch_embed(hidden_states) + hidden_states = hidden_states + pos_embeds + cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) + + residual = None + deepstack_feature_lists = [] + for layer_num, blk in enumerate(self.blocks): + hidden_states, residual = blk(hidden_states, + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + residual=residual) + if layer_num in self.deepstack_visual_indexes: + deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( + (hidden_states + residual)) + deepstack_feature_lists.append(deepstack_feature) + + hidden_states = hidden_states + residual + + return self.merger(hidden_states), deepstack_feature_lists + + +class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__() + self.config = config + self.ctx_mgr = ctx_mgr + + # build preprocessor + self.input_processor = Qwen3VLInputProcessor(self.config) + + # build vision model + self.visual = Qwen3VLVisionModel( + config.vision_config, + dtype=dtype, + device=device, + ) + + # build text model + self.language_model = Qwen3VLTextModel(config.text_config, dtype=dtype, device=device) + + # build lm_head + self.lm_head = build_rowwise_linear(config.text_config.hidden_size, + config.text_config.vocab_size, + bias=False, + dtype=dtype, + device=device) + + def _prepare_multimodal_inputs(self, input_ids: torch.Tensor, pixel_values: torch.Tensor, image_mask: torch.Tensor, + grid_thw: torch.Tensor, vis_cu_seqlens: torch.Tensor, vis_pos_emb: torch.Tensor, + pos_embeds: torch.Tensor): + """Prepare multimodal inputs for language model.""" + inputs_embeds = self.get_input_embeddings()(input_ids) + if pixel_values is None: + return inputs_embeds, None, None + + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype)) + + # get image embeds and deepstack visual embeds + image_embeds, deepstack_visual_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb, + pos_embeds=pos_embeds) + + # split image embeds per sample + split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) + + # mask and scatter to create final input embeddings + expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + final_inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) + + visual_pos_masks = expanded_image_mask[..., 0] + + return final_inputs_embeds, visual_pos_masks, deepstack_visual_embeds + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + past_key_values: List[List[torch.Tensor]], + attn_metadata: Any = None, + inputs_embeds: torch.Tensor = None, + mrope_position_ids: torch.Tensor = None, + pixel_values: torch.Tensor = None, + vis_cu_seqlens: torch.Tensor = None, + vis_pos_emb: 
torch.Tensor = None, + image_mask: torch.Tensor = None, + pos_embeds: torch.Tensor = None, + grid_thw: torch.Tensor = None, + **kwargs, + ): + """Model forward, return logits.""" + + visual_pos_masks = None + deepstack_visual_embeds = None + if inputs_embeds is None: + inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self._prepare_multimodal_inputs( + input_ids, pixel_values, image_mask, grid_thw, vis_cu_seqlens, vis_pos_emb, pos_embeds) + + hidden_states = self.language_model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + # args for deepstack + visual_pos_masks=visual_pos_masks, + deepstack_visual_embeds=deepstack_visual_embeds, + ) + return hidden_states + + def get_logits(self, hidden_states: torch.Tensor): + """Compute logits of the model output.""" + return self.lm_head(hidden_states) + + def update_weights(self): + """Update weights.""" + if self.config.tie_word_embeddings: + self.lm_head.weight = self.language_model.embed_tokens.weight + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.language_model.get_input_embeddings() + + def prepare_inputs_for_generation( + self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None, + ): + """Prepare input.""" + + # get input_ids, position_ids and attention metadatas + input_ids = context.input_ids + position_ids = context.position_ids + attn_metadata = context.attn_metadata + + pixel_values = None + vis_cu_seqlens = None + vis_pos_emb = None + image_mask = None + grid_thw = None + pos_embeds = None + if context.input_multimodals is not None: + image_data = [input_mm.get('image', []) for input_mm in context.input_multimodals] + if len(image_data) > 0: + # flatten batch + image_data = [data for im_data in image_data for data in im_data] + pixel_values = torch.cat([data.data for data in image_data]) + image_token_id = image_data[0].meta['image_token_id'] + image_mask = input_ids == image_token_id + grid_thw = torch.cat([data.meta['grid_thw'] for data in image_data]).cpu() + vis_pos_emb = self.visual.rot_pos_emb(grid_thw) + pos_embeds = self.visual.fast_pos_embed_interpolate(grid_thw) + vis_cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], + grid_thw[:, 0]).to(pixel_values.device) + vis_cu_seqlens = vis_cu_seqlens.cumsum(dim=0, dtype=torch.int32) + vis_pos_emb = vis_pos_emb.repeat(1, 2) + vis_pos_emb = (vis_pos_emb.cos(), vis_pos_emb.sin()) + + mrope_position_ids = getattr(context, 'mrope_position_ids', None) + + # process vision embeddings + vision_embeddings = context.input_embeddings + vision_embedding_indexing = context.input_embedding_indexing + if vision_embeddings is not None and len(vision_embeddings) > 0: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + inputs_embeds[:, vision_embedding_indexing, :] = vision_embeddings.to(inputs_embeds) + + # inputs of forward + return dict( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + attn_metadata=attn_metadata, + inputs_embeds=inputs_embeds, + mrope_position_ids=mrope_position_ids, + pixel_values=pixel_values, + vis_cu_seqlens=vis_cu_seqlens, + vis_pos_emb=vis_pos_emb, + image_mask=image_mask, + grid_thw=grid_thw, + pos_embeds=pos_embeds, + ) + + def rename_weight(self, name: str) -> str: + """Rename weight.""" + if 
name.startswith('model.language_model.'): + return 'language_model.' + name[len('model.language_model.'):] + elif name.startswith('model.visual.'): + return 'visual.' + name[len('model.visual.'):] + elif name.startswith('model.'): + return name[len('model.'):] + return name + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.qkv.' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + def make_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Make cudagraph buffers from forward inputs.""" + max_tokens = graph_meta.max_tokens + + input_buffers = super().make_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + mrope_position_ids = kwargs.get('mrope_position_ids', None) + if mrope_position_ids is not None: + input_buffers['mrope_position_ids'] = mrope_position_ids.new_zeros(3, max_tokens) + + return input_buffers + + def fill_buffers_cudagraph(self, graph_meta: CudaGraphMeta, **kwargs): + """Fill cudagraph buffers from forward inputs.""" + + new_inputs = super().fill_buffers_cudagraph(graph_meta=graph_meta, **kwargs) + + input_ids = kwargs.get('input_ids') + num_tokens = input_ids.size(-1) + new_batch_size = graph_meta.max_batchs + + is_decoding = graph_meta.is_decoding + input_buffers = graph_meta.input_buffers + mrope_position_ids = kwargs.get('mrope_position_ids', None) + if mrope_position_ids is not None: + input_buffers['mrope_position_ids'][:, :num_tokens] = mrope_position_ids + if is_decoding: + new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'][:, :new_batch_size] + else: + new_inputs['mrope_position_ids'] = input_buffers['mrope_position_ids'] + + return new_inputs + + def _get_model_metas(self, context: StepContext): + """Get model metas.""" + model_metas = context.model_metas + if model_metas is None: + batch_size = context.q_seqlens.numel() + return [dict(mrope_delta=0)] * batch_size + return [dict(mrope_delta=0) if meta is None else meta for meta in model_metas] + + def _update_model_meta_decoding(self, context: StepContext): + """Update model meta for decoding.""" + model_metas = self._get_model_metas(context) + position_ids = context.position_ids + + mrope_deltas = [meta['mrope_delta'] for meta in model_metas] + mrope_deltas = position_ids.new_tensor(mrope_deltas) + mrope_position_ids = position_ids + mrope_deltas[None] + mrope_position_ids = mrope_position_ids.expand(3, -1) + + context.mrope_position_ids = mrope_position_ids + return model_metas + + def 
_get_multimodal_pos_ids(self, grid_thw: list, device: torch.device): + """Get mrope ids.""" + t, h, w = grid_thw + h //= 2 + w //= 2 + stride = torch.tensor([h * w, w, 1], device=device)[:, None] + size = torch.tensor([t, h, w], device=device)[:, None] + pos_ids = torch.arange(t * h * w, device=device)[None].expand(3, -1) + pos_ids = pos_ids // stride % size + return pos_ids + + def _update_model_meta_prefilling(self, context: StepContext): + """Update model meta for prefilling.""" + model_metas = self._get_model_metas(context) + input_multimodals = context.input_multimodals + if input_multimodals is None: + input_multimodals = [None] * len(model_metas) + position_ids = context.position_ids + batched_pos_ids = position_ids[0].split(context.q_seqlens.tolist()) + mrope_position_ids = [] + new_model_metas = [] + for pos_ids, model_meta, input_mm in zip(batched_pos_ids, model_metas, input_multimodals): + images = [] + if input_mm is not None: + images = input_mm.get('image', []) + if model_meta is None or 'mrope_delta' not in model_meta: + mrope_delta = 0 + else: + mrope_delta = model_meta['mrope_delta'] + + pos_start = pos_ids[0].item() + mrope_pos_ids = pos_ids + mrope_delta + mrope_pos_ids = mrope_pos_ids[None].expand(3, -1).clone() + for img in images: + grid_thw = img.meta['grid_thw'][0].tolist() + _, h, w = grid_thw + h //= 2 + w //= 2 + num_pad = img.end - img.start - max(h, w) + mrope_delta -= num_pad + fill_start = img.start - pos_start + fill_end = img.end - pos_start + img_pos_ids = self._get_multimodal_pos_ids(grid_thw, pos_ids.device) + img_pos_ids += mrope_pos_ids[:, fill_start:fill_start + 1] + mrope_pos_ids[:, fill_end:] -= num_pad + mrope_pos_ids[:, fill_start:fill_end] = img_pos_ids + + mrope_position_ids.append(mrope_pos_ids) + new_model_metas.append(dict(mrope_delta=mrope_delta)) + + mrope_position_ids = torch.cat(mrope_position_ids, dim=1) + context.mrope_position_ids = mrope_position_ids + + return new_model_metas + + def update_model_metas(self, + past_key_values: List[List[torch.Tensor]], + inputs_embeds: Optional[torch.Tensor] = None, + context: StepContext = None): + """Update model meta.""" + if context.is_decoding: + return self._update_model_meta_decoding(context) + else: + return self._update_model_meta_prefilling(context) + + def get_input_processor(self) -> BaseModelInputProcessor: + """Get input processor.""" + return self.input_processor + + +InputMultiModalType = List[Dict[str, Any]] + + +class Qwen3VLInputProcessor(BaseModelInputProcessor): + """Qwen3 input processor.""" + + def __init__(self, config: PretrainedConfig) -> None: + self.config = config + + def preprocess_input(self, + input_ids: List[int], + input_multimodals: List[Dict[str, Any]] = None, + **kwargs) -> PreprocessInputResult: + """Prepare multimodal input.""" + if input_multimodals is None or len(input_multimodals) == 0: + return input_ids, input_multimodals + + input_imgs = [] + for input_mm in input_multimodals: + pixel_values = input_mm['pixel_values'] + image_grid_thw = input_mm['image_grid_thw'] + offset = input_mm['offset'] + start = offset + image_token_id = input_mm['image_token_id'] + num_pad = input_mm['image_tokens'] + if isinstance(num_pad, torch.Tensor): + num_pad = num_pad.item() + + mm_data = MultiModalTensor(data=pixel_values, + start=start, + end=start + num_pad, + meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) + input_imgs.append(mm_data) + + result = PreprocessInputResult( + input_ids=input_ids, + input_multimodals=dict(image=input_imgs), + ) + return 
result diff --git a/lmdeploy/vl/model/builder.py b/lmdeploy/vl/model/builder.py index 089b66a211..91db56c30a 100644 --- a/lmdeploy/vl/model/builder.py +++ b/lmdeploy/vl/model/builder.py @@ -28,6 +28,7 @@ from .phi3_vision import Phi3VisionModel # noqa F401 from .qwen import QwenVisionModel # noqa F401 from .qwen2 import Qwen2VLModel # noqa F401 +from .qwen3 import Qwen3VLModel # noqa F401 from .xcomposer2 import Xcomposer2VisionModel # noqa F401 from .yi import YiVisionModel # noqa F401 diff --git a/lmdeploy/vl/model/qwen3.py b/lmdeploy/vl/model/qwen3.py new file mode 100644 index 0000000000..40f2bf485c --- /dev/null +++ b/lmdeploy/vl/model/qwen3.py @@ -0,0 +1,128 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + +import torch + +from lmdeploy.vl.model.base import VISION_MODELS, VisonModel + + +def check_transformers(): + try: + from transformers import Qwen3VLForConditionalGeneration, Qwen3VLMoeForConditionalGeneration # noqa: F401 + except ImportError: + raise ImportError('please install latest transformers by ' + 'pip install git+https://github.com/huggingface/transformers.git') + + +@VISION_MODELS.register_module() +class Qwen3VLModel(VisonModel): + """Qwen3VL model.""" + + _arch = ['Qwen3VLForConditionalGeneration', 'Qwen3VLMoeForConditionalGeneration'] + + def build_preprocessor(self): + check_transformers() + from transformers import AutoProcessor + self.processor = AutoProcessor.from_pretrained(self.model_path) + tokenizer = self.processor.tokenizer + image_token = self.processor.image_token + self.image_token_id = tokenizer.encode(image_token)[-1] + + def preprocess(self, messages: List[Dict]) -> List[Dict]: + """Refer to `super().preprocess()` for spec.""" + images = self.collect_images(messages) + optional_keys = {'resized_height', 'resized_width', 'min_pixels', 'max_pixels'} + outputs = [] + for image, params in images: + image = image.convert('RGB') + + item = dict(type='image', image=image) + item.update({key: params[key] for key in params.keys() if key in optional_keys}) + result = self.processor.image_processor(images=image, videos=None, return_tensors='pt') + merge_length = self.processor.image_processor.merge_size**2 + image_tokens = result['image_grid_thw'].prod(dim=1) // merge_length + result.update(dict(image_size=image.size, image_tokens=image_tokens, image_token_id=self.image_token_id)) + outputs.append(result) + messages.append(dict(role='preprocess', content=outputs)) + return messages + + def build_model(self): + # TODO: implement for turbomind + pass + + @torch.no_grad() + def forward(self, messages: List[Dict], max_batch_size: int = 1) -> List[Dict]: + """Extract image feature. ONLY implement it when the backend is + turbomind engine. 
+ + Args: + messages(List[Dict]): the outputs of `preprocess` + max_batch_size(int): the max batch size when forwarding vision + model + Return: + the message list with forwarding results included + """ + # TODO: implement for turbomind + pass + + @staticmethod + def proc_messages(messages, chat_template, sequence_start): + """Apply chat template to get the prompt.""" + prompt_messages = [] + IMAGE_TOKEN = '' + for message in messages: + if isinstance(message['content'], str): + prompt_messages.append(message) + continue + elif message['role'] in ['images', 'preprocess', 'forward']: + continue + n_images = len([1 for x in message['content'] if x['type'] == 'image']) + content = [item['text'] for item in message['content'] if item['type'] == 'text'] + prompt = content[0] + if IMAGE_TOKEN in prompt and '<|vision_start|>' not in prompt: + prompt = prompt.replace(IMAGE_TOKEN, f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>') + else: + # Qwen2-VL-2B-Instruct will concat image and user prompt + # according to their order in the content list + # we insert image token before user prompt by default. The + # user can use custom image token position if they want the + # same decorated prompt as Qwen2-VL + prompt = f'<|vision_start|>{IMAGE_TOKEN}<|vision_end|>' * \ + n_images + prompt + prompt_messages.append(dict(role=message['role'], content=prompt)) + prompt = chat_template.messages2prompt(prompt_messages, sequence_start) + return prompt, IMAGE_TOKEN + + @staticmethod + def get_mrope_info(seq_len: int, + grid_thws: List[Tuple[int, int, int]] = None, + ranges: List[Tuple[int, int]] = None): + mrope_position_ids = [torch.arange(ranges[0][0]).expand(3, -1)] + st_idx = ranges[0][0] + for i, (grid_thw, embedding_range) in enumerate(zip(grid_thws, ranges)): + llm_grid_t, llm_grid_h, llm_grid_w = grid_thw + llm_grid_h //= 2 + llm_grid_w //= 2 + t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten() + mrope_position_ids.append(torch.stack([t_index, h_index, w_index]) + st_idx) + st_idx += max(llm_grid_h, llm_grid_w) + if i < len(ranges) - 1: + text_len = ranges[i + 1][0] - ranges[i][1] + else: + text_len = seq_len - embedding_range[1] + mrope_position_ids.append(torch.arange(text_len).expand(3, -1) + st_idx) + st_idx += text_len + mrope_position_ids = torch.cat(mrope_position_ids, dim=-1) + mrope_position_delta = torch.tensor([st_idx - seq_len], dtype=torch.long) + return mrope_position_ids, mrope_position_delta + + def to_pytorch(self, messages, chat_template, tokenizer, sequence_start, **kwargs): + """Return to the information needed by pytorch engine.""" + prompt, IMAGE_TOKEN = self.proc_messages(messages, chat_template, sequence_start) + return self.to_pytorch_aux(messages, prompt, IMAGE_TOKEN, tokenizer, sequence_start) + + def to_turbomind(self, messages, chat_template, tokenizer, sequence_start, **kwargs): + # TODO: implement for turbomind + pass From e472bd0ecd95e190e015c83c072064063f4dc199 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 31 Oct 2025 20:01:08 +0800 Subject: [PATCH 02/15] cleanups --- lmdeploy/pytorch/config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 1c21ef4d67..d0ac7c91a6 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -290,10 +290,7 @@ def 
from_hf_config(cls, assert tp % model_config.num_key_value_heads == 0 # should after setting `hf_config` and `model_arch` attributes - print(f'?????? dtype: {dtype}') - print(f'model_config: {model_config}') model_config = _update_torch_dtype(model_config, dtype) - print(f'after update, model_config: {model_config}') # update eos_token_id to list if isinstance(model_config.eos_token_id, int): From a4c10aa9826d412dccd472bb5626f46aa1e76b86 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 31 Oct 2025 20:10:25 +0800 Subject: [PATCH 03/15] cleanups --- lmdeploy/pytorch/models/qwen3_vl.py | 23 +---------------------- 1 file changed, 1 insertion(+), 22 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 4080dc7523..3b40f18968 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch from torch import nn @@ -21,27 +21,6 @@ from .utils.model import DeployModelMixin, vlm_model -def _apply_mrope_selection(hidden_states: torch.Tensor, mrope_position_ids: torch.Tensor, mrope_section: List[int], - position_ids: torch.Tensor, rotary_emb_func: Callable): - _mrope_position_ids = torch.zeros(3, position_ids.shape[-1], dtype=position_ids.dtype, device=position_ids.device) - _mrope_position_ids[:, :mrope_position_ids.shape[-1]] = mrope_position_ids - cos, sin = rotary_emb_func(hidden_states, _mrope_position_ids) - _cos = torch.zeros(cos.shape[1], cos.shape[-1], dtype=cos.dtype, device=cos.device) - _sin = torch.zeros_like(_cos) - mrope_section = mrope_section * 2 - - def _apply_split(src, dst): - start = 0 - for i, m in enumerate(src.split(mrope_section, dim=-1)): - dst[:, start:start + mrope_section[i]] = m[i % 3] - start += mrope_section[i] - - _apply_split(cos, _cos) - _apply_split(sin, _sin) - - return _cos, _sin - - class Qwen3VLTextRotaryEmbedding(nn.Module): inv_freq: torch.Tensor # fix linting for `register_buffer` From 73a5e95ad6c39dda42c1f934b1b27ea955d74b42 Mon Sep 17 00:00:00 2001 From: zxy Date: Fri, 31 Oct 2025 20:19:06 +0800 Subject: [PATCH 04/15] reuse input processor --- lmdeploy/pytorch/models/qwen3_vl.py | 42 ++--------------------------- 1 file changed, 2 insertions(+), 40 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 3b40f18968..c7fcdde089 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -7,14 +7,14 @@ from transformers.configuration_utils import PretrainedConfig from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update -from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor, PreprocessInputResult +from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.multimodal.data_type import MultiModalTensor from lmdeploy.pytorch.nn import LayerNorm, RMSNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3VLVisionRotaryEmbedding +from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3VLInputProcessor from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention 
from .qwen3 import Qwen3DecoderLayer as Qwen3VLTextDecoderLayer from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin @@ -832,41 +832,3 @@ def get_input_processor(self) -> BaseModelInputProcessor: InputMultiModalType = List[Dict[str, Any]] - - -class Qwen3VLInputProcessor(BaseModelInputProcessor): - """Qwen3 input processor.""" - - def __init__(self, config: PretrainedConfig) -> None: - self.config = config - - def preprocess_input(self, - input_ids: List[int], - input_multimodals: List[Dict[str, Any]] = None, - **kwargs) -> PreprocessInputResult: - """Prepare multimodal input.""" - if input_multimodals is None or len(input_multimodals) == 0: - return input_ids, input_multimodals - - input_imgs = [] - for input_mm in input_multimodals: - pixel_values = input_mm['pixel_values'] - image_grid_thw = input_mm['image_grid_thw'] - offset = input_mm['offset'] - start = offset - image_token_id = input_mm['image_token_id'] - num_pad = input_mm['image_tokens'] - if isinstance(num_pad, torch.Tensor): - num_pad = num_pad.item() - - mm_data = MultiModalTensor(data=pixel_values, - start=start, - end=start + num_pad, - meta=dict(grid_thw=image_grid_thw, image_token_id=image_token_id)) - input_imgs.append(mm_data) - - result = PreprocessInputResult( - input_ids=input_ids, - input_multimodals=dict(image=input_imgs), - ) - return result From 01b7a97a6d926549fac05ebd6e657adaed398086 Mon Sep 17 00:00:00 2001 From: zxy Date: Mon, 3 Nov 2025 17:43:22 +0800 Subject: [PATCH 05/15] support qwen3vl moe, add docs --- README.md | 2 + README_ja.md | 2 + README_zh-CN.md | 2 + docs/en/supported_models/supported_models.md | 2 + .../supported_models/supported_models.md | 2 + lmdeploy/pytorch/models/module_map.py | 10 +- lmdeploy/pytorch/models/qwen3_moe.py | 2 +- lmdeploy/pytorch/models/qwen3_vl_moe.py | 253 ++++++++++++++++++ 8 files changed, 269 insertions(+), 6 deletions(-) create mode 100644 lmdeploy/pytorch/models/qwen3_vl_moe.py diff --git a/README.md b/README.md index 0589d9bf5a..275fe71bec 100644 --- a/README.md +++ b/README.md @@ -162,6 +162,8 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
+ • Qwen3-VL (2B - 32B)
+ • Qwen3-VL-MOE (30B, 235B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
diff --git a/README_ja.md b/README_ja.md index 75d05390ad..c7c8b42898 100644 --- a/README_ja.md +++ b/README_ja.md @@ -148,6 +148,8 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
+ • Qwen3-VL (2B - 32B)
+ • Qwen3-VL-MOE (30B, 235B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
diff --git a/README_zh-CN.md b/README_zh-CN.md index ddec4838b9..fca38477ae 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -163,6 +163,8 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
+ • Qwen3-VL (2B - 32B)
+ • Qwen3-VL-MOE (30B, 235B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
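For context, a minimal sketch of how the newly listed Qwen3-VL checkpoints could be served through LMDeploy's existing pipeline API once this series lands. The model id and image URL below are placeholders rather than values taken from this patch, and only the PyTorch engine path is wired up here (the TurboMind hooks remain TODO):

from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.vl import load_image

# Placeholder model id and image URL, not taken from this patch series.
pipe = pipeline('Qwen/Qwen3-VL-8B-Instruct',
                backend_config=PytorchEngineConfig(session_len=8192))
image = load_image('https://example.com/demo.jpg')
response = pipe(('Describe this image.', image))
print(response.text)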
  • diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index aa28854d8a..4c53d5cf09 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -87,6 +87,8 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | +| QWen3-VL | 2B - 32B | MLLM | Yes | No | No | No | No | +| QWen3-VL-MOE | 30B, 235B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 8e9e3fef20..207265addb 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -87,6 +87,8 @@ | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | +| QWen3-VL | 2B - 32B | MLLM | Yes | No | No | No | No | +| QWen3-VL-MOE | 30B, 235B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py index 2559549d1d..5441e1c5d3 100644 --- a/lmdeploy/pytorch/models/module_map.py +++ b/lmdeploy/pytorch/models/module_map.py @@ -153,11 +153,11 @@ f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl.Qwen3VLForConditionalGeneration', }) -# # qwen3_vl_moe -# MODULE_MAP.update({ -# 'Qwen3VLMoeForConditionalGeneration': -# f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration', -# }) +# qwen3_vl_moe +MODULE_MAP.update({ + 'Qwen3VLMoeForConditionalGeneration': + f'{LMDEPLOY_PYTORCH_MODEL_PATH}.qwen3_vl_moe.Qwen3VLMoeForConditionalGeneration', +}) # starcoder2 MODULE_MAP.update({ diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py index 464953f264..0b212edfb5 100644 --- a/lmdeploy/pytorch/models/qwen3_moe.py +++ b/lmdeploy/pytorch/models/qwen3_moe.py @@ -52,7 +52,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim, - sliding_window=config.sliding_window, + sliding_window=(config.sliding_window if hasattr(config, 'sliding_window') else None), ) # o_proj diff --git a/lmdeploy/pytorch/models/qwen3_vl_moe.py b/lmdeploy/pytorch/models/qwen3_vl_moe.py new file mode 100644 index 0000000000..b74517150a --- /dev/null +++ b/lmdeploy/pytorch/models/qwen3_vl_moe.py @@ -0,0 +1,253 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+ +from typing import Any, Dict, Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.configuration_utils import PretrainedConfig + +from lmdeploy.pytorch.model_inputs import StepContextManager +from lmdeploy.pytorch.nn import RMSNorm +from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight + +from .qwen3_moe import Qwen3MoeDecoderLayer as Qwen3VLMoeTextDecoderLayer +from .qwen3_vl import Qwen3VLForConditionalGeneration +from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding + + +class Qwen3VLMoeTextModel(nn.Module): + """Text part of Qwen3VLMoe. + + not a pure text-only model, as DeepStack integrates visual features into the early hidden states. + """ + + def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): + super().__init__() + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.mrope_section = config.rope_scaling['mrope_section'] + + self.embed_tokens = nn.Embedding(config.vocab_size, + config.hidden_size, + self.padding_idx, + dtype=dtype, + device=device) + + # build all decode layers + self.layers = nn.ModuleList([ + Qwen3VLMoeTextDecoderLayer(config, layer_idx, dtype=dtype, device=device) + for layer_idx in range(config.num_hidden_layers) + ]) + + # build norm + self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + + # build rotary embedding + # TODO: zhouxinyu, add triton kernel for interleaved mrope + self.rotary_emb = Qwen3VLMoeTextRotaryEmbedding(config, device=device) + + def forward( + self, + input_ids: torch.LongTensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + attn_metadata: Any = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + mrope_position_ids: torch.LongTensor = None, + # args for deepstack + visual_pos_masks: Optional[torch.Tensor] = None, + deepstack_visual_embeds: Optional[list[torch.Tensor]] = None, + ): + """visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, + *optional*): + + The mask of the visual positions. deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): The deepstack + visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). The feature is extracted from the + different visual encoder layers, and fed to the decoder hidden states. 
It's from the paper DeepStack ( + https://arxiv.org/abs/2406.04) + """ + + # token embedding + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + # rotary embedding + if mrope_position_ids is None: + cos, sin = self.rotary_emb(hidden_states, position_ids) + else: + cos, sin = self.rotary_emb(hidden_states, mrope_position_ids) + + cos, sin = cos[0], sin[0] + rotary_pos_emb = (cos, sin) + + # decoding + residual = None + for idx, decoder_layer in enumerate(self.layers): + past_key_value = past_key_values[idx] + hidden_states, residual = decoder_layer( + hidden_states, + rotary_pos_emb=rotary_pos_emb, + past_key_value=past_key_value, + residual=residual, + attn_metadata=attn_metadata, + ) + + # add visual features to the hidden states of first several layers + if deepstack_visual_embeds is not None and idx in range(len(deepstack_visual_embeds)): + hidden_states = self._deepstack_process( + hidden_states, + visual_pos_masks, + deepstack_visual_embeds[idx], + ) + + # norm + hidden_states, _ = self.norm(hidden_states, residual) + + return hidden_states + + def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor, + visual_embeds: torch.Tensor): + visual_pos_masks = visual_pos_masks.to(hidden_states.device) + visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) + local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds + hidden_states[visual_pos_masks, :] = local_this + return hidden_states + + def get_input_embeddings(self): + """Get input embeddings.""" + return self.embed_tokens + + +class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): + """ModelForCausalLM.""" + + packed_modules_mapping = { + 'qkv_proj': [ + 'q_proj', + 'k_proj', + 'v_proj', + ], + 'gate_up_proj': [ + 'gate_proj', + 'up_proj', + ], + } + + def __init__(self, + config: PretrainedConfig, + ctx_mgr: StepContextManager, + dtype: torch.dtype = None, + device: torch.device = None): + super().__init__(config=config, ctx_mgr=ctx_mgr, dtype=dtype, device=device) + + self.language_model = Qwen3VLMoeTextModel(config.text_config, dtype=dtype, device=device) + + def _load_weight_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + expert_params_mapping: List): + """Load weight experts.""" + + for (param_name, weight_name, expert_id, shard_id) in expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, expert_id=expert_id, shard_id=shard_id) + else: + param = params_dict[name] + load_weight(param, loaded_weight) + + # modify from vllm qwen3vlmoe fused expert loading + def _load_weight_fused_experts(self, name: str, loaded_weight: torch.Tensor, params_dict: Dict[str, nn.Parameter], + fused_expert_params_mapping: List): + """Load weight of fused expert weights.""" + num_experts = self.config.text_config.num_experts + + for (param_name, weight_name) in fused_expert_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + + loaded_weight = loaded_weight.transpose(-1, -2) # no bias + if 'gate_up' in name: + loaded_weight = loaded_weight.chunk(2, dim=-2) + w1 = loaded_weight[0] + w3 = loaded_weight[1] + for expert_id in range(num_experts): + load_weight(param, w1[expert_id], expert_id=expert_id, shard_id='gate') + load_weight(param, w3[expert_id], 
expert_id=expert_id, shard_id='up') + elif 'down' in name: + w2 = loaded_weight + for expert_id in range(num_experts): + load_weight(param, w2[expert_id], expert_id=expert_id, shard_id='down') + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights.""" + # modify from vllm + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ('.qkv_proj', '.q_proj', 'q'), + ('.qkv_proj', '.k_proj', 'k'), + ('.qkv_proj', '.v_proj', 'v'), + ('.gate_up_proj', '.gate_proj', 0), + ('.gate_up_proj', '.up_proj', 1), + ] + + # expert mapping + num_experts = self.config.text_config.num_experts + expert_params_mapping = [] + for exp_id in range(num_experts): + # (param_name, weight_name, expert_id, shard_id) + gate_param = ('.experts.gate_up', f'.experts.{exp_id}.gate_proj', exp_id, 'gate') + up_param = ('.experts.gate_up', f'.experts.{exp_id}.up_proj', exp_id, 'up') + down_param = ('.experts.down', f'.experts.{exp_id}.down_proj', exp_id, 'down') + expert_params_mapping += [gate_param, up_param, down_param] + + # fused expert mapping + fused_expert_params_mapping = [ + # (param_name, weight_name) + ('.experts.gate_up.weight', '.experts.gate_up_proj'), + ('.experts.down.weight', '.experts.down_proj'), + ] + + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + if 'rotary_emb.inv_freq' in name: + continue + if ('rotary_emb.cos_cached' in name or 'rotary_emb.sin_cached' in name): + continue + if self.config.tie_word_embeddings and 'lm_head.weight' in name: + continue + name = name.replace('.block_sparse_moe.', '.mlp.') + if '.experts' in name: + is_fused_expert = ('experts.gate_up_proj' in name or 'experts.down_proj' in name) + if is_fused_expert: + self._load_weight_fused_experts(name, + loaded_weight, + params_dict, + fused_expert_params_mapping=fused_expert_params_mapping) + else: + self._load_weight_experts(name, + loaded_weight, + params_dict, + expert_params_mapping=expert_params_mapping) + else: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + load_weight(param, loaded_weight, shard_id=shard_id) + break + else: + if '.qkv.' in name: + param = params_dict[name] + q, k, v = param.weight_spliter(loaded_weight) + load_weight(param, q, shard_id='q') + load_weight(param, k, shard_id='k') + load_weight(param, v, shard_id='v') + else: + param = params_dict[name] + load_weight(param, loaded_weight) From c979730cae7b6baf087b26d72637900a811f1ced Mon Sep 17 00:00:00 2001 From: zxy Date: Mon, 3 Nov 2025 17:53:55 +0800 Subject: [PATCH 06/15] format --- lmdeploy/pytorch/models/qwen3_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index c7fcdde089..0d602ff60c 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -129,8 +129,8 @@ def forward( The mask of the visual positions. deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). The feature is extracted from the - different visual encoder layers, and fed to the decoder hidden states. It's from the paper DeepStack ( - https://arxiv.org/abs/2406.04) + different visual encoder layers, and fed to the decoder hidden states. It's from the paper DeepStack( + https://arxiv.org/abs/2406.04334). 
""" # token embedding From d0347513234f22e77d8af9e249bc4be7ef9e4100 Mon Sep 17 00:00:00 2001 From: zxy Date: Mon, 3 Nov 2025 17:55:02 +0800 Subject: [PATCH 07/15] Revert "format" This reverts commit c979730cae7b6baf087b26d72637900a811f1ced. --- lmdeploy/pytorch/models/qwen3_vl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 0d602ff60c..c7fcdde089 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -129,8 +129,8 @@ def forward( The mask of the visual positions. deepstack_visual_embeds (`list[torch.Tensor]`, *optional*): The deepstack visual embeddings. The shape is (num_layers, visual_seqlen, embed_dim). The feature is extracted from the - different visual encoder layers, and fed to the decoder hidden states. It's from the paper DeepStack( - https://arxiv.org/abs/2406.04334). + different visual encoder layers, and fed to the decoder hidden states. It's from the paper DeepStack ( + https://arxiv.org/abs/2406.04) """ # token embedding From db0c654e78bb11d5cf5f7a8b3bf3efe03eae5059 Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 4 Nov 2025 11:27:46 +0800 Subject: [PATCH 08/15] fix docs --- README.md | 3 +-- README_ja.md | 3 +-- README_zh-CN.md | 3 +-- docs/en/supported_models/supported_models.md | 3 +-- docs/zh_cn/supported_models/supported_models.md | 3 +-- 5 files changed, 5 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index f849706ac6..d4922e29e9 100644 --- a/README.md +++ b/README.md @@ -162,8 +162,7 @@ LMDeploy is a toolkit for compressing, deploying, and serving LLM, developed by
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
-  • Qwen3-VL (2B - 32B)
-  • Qwen3-VL-MOE (30B, 235B)
+  • Qwen3-VL (2B - 235B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
diff --git a/README_ja.md b/README_ja.md
index c7c8b42898..5dda14c041 100644
--- a/README_ja.md
+++ b/README_ja.md
@@ -148,8 +148,7 @@ LMDeploy TurboMindエンジンは卓越した推論能力を持ち、さまざ
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
-  • Qwen3-VL (2B - 32B)
-  • Qwen3-VL-MOE (30B, 235B)
+  • Qwen3-VL (2B - 235B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
diff --git a/README_zh-CN.md b/README_zh-CN.md
index 0a00cc1364..2e5f124d20 100644
--- a/README_zh-CN.md
+++ b/README_zh-CN.md
@@ -163,8 +163,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力，在各种规模的模型
  • Qwen-VL (7B)
  • Qwen2-VL (2B, 7B, 72B)
  • Qwen2.5-VL (3B, 7B, 72B)
-  • Qwen3-VL (2B - 32B)
-  • Qwen3-VL-MOE (30B, 235B)
+  • Qwen3-VL (2B - 235B)
  • DeepSeek-VL (7B)
  • DeepSeek-VL2 (3B, 16B, 27B)
  • InternVL-Chat (v1.1-v1.5)
  • diff --git a/docs/en/supported_models/supported_models.md b/docs/en/supported_models/supported_models.md index 4c53d5cf09..d7f2bffa05 100644 --- a/docs/en/supported_models/supported_models.md +++ b/docs/en/supported_models/supported_models.md @@ -87,8 +87,7 @@ The following tables detail the models supported by LMDeploy's TurboMind engine | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes\* | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | -| QWen3-VL | 2B - 32B | MLLM | Yes | No | No | No | No | -| QWen3-VL-MOE | 30B, 235B | MLLM | Yes | No | No | No | No | +| QWen3-VL | 2B - 235B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | diff --git a/docs/zh_cn/supported_models/supported_models.md b/docs/zh_cn/supported_models/supported_models.md index 207265addb..73dd304e98 100644 --- a/docs/zh_cn/supported_models/supported_models.md +++ b/docs/zh_cn/supported_models/supported_models.md @@ -87,8 +87,7 @@ | Qwen3 | 0.6B - 235B | LLM | Yes | Yes | Yes\* | - | Yes | | QWen2-VL | 2B, 7B | MLLM | Yes | Yes | No | No | Yes | | QWen2.5-VL | 3B - 72B | MLLM | Yes | No | No | No | No | -| QWen3-VL | 2B - 32B | MLLM | Yes | No | No | No | No | -| QWen3-VL-MOE | 30B, 235B | MLLM | Yes | No | No | No | No | +| QWen3-VL | 2B - 235B | MLLM | Yes | No | No | No | No | | DeepSeek-MoE | 16B | LLM | Yes | No | No | No | No | | DeepSeek-V2 | 16B, 236B | LLM | Yes | No | No | No | No | | DeepSeek-V2.5 | 236B | LLM | Yes | No | No | No | No | From 1821f65dc0f906dad75ebf4685be823b8c09067c Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 4 Nov 2025 14:05:49 +0800 Subject: [PATCH 09/15] fix --- lmdeploy/pytorch/models/qwen3.py | 2 +- lmdeploy/pytorch/models/qwen3_moe.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3.py b/lmdeploy/pytorch/models/qwen3.py index fac4ca579f..381bfb72cb 100644 --- a/lmdeploy/pytorch/models/qwen3.py +++ b/lmdeploy/pytorch/models/qwen3.py @@ -47,7 +47,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim, - sliding_window=(config.sliding_window if hasattr(config, 'sliding_window') else None), + sliding_window=getattr(config, 'sliding_window', None), ) # o_proj diff --git a/lmdeploy/pytorch/models/qwen3_moe.py b/lmdeploy/pytorch/models/qwen3_moe.py index 0b212edfb5..d66ad10ebf 100644 --- a/lmdeploy/pytorch/models/qwen3_moe.py +++ b/lmdeploy/pytorch/models/qwen3_moe.py @@ -52,7 +52,7 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: head_dim, num_kv_heads=num_key_value_heads, v_head_size=head_dim, - sliding_window=(config.sliding_window if hasattr(config, 'sliding_window') else None), + sliding_window=getattr(config, 'sliding_window', None), ) # o_proj From f5c59e86544db046314eda4f775e5bb0aed565bd Mon Sep 17 00:00:00 2001 From: zxy Date: Tue, 4 Nov 2025 14:44:21 +0800 Subject: [PATCH 10/15] improve config check conditions --- lmdeploy/pytorch/config.py | 17 +++++++++-------- lmdeploy/pytorch/configurations/default.py | 13 +++++++------ 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index d0ac7c91a6..7724f7d00b 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -28,16 +28,17 @@ def 
_update_torch_dtype(config: 'ModelConfig', dtype: str): config.dtype = torch.float16 return config - if hasattr(config.hf_config, 'text_config'): - torch_dtype = getattr(config.hf_config.text_config, 'dtype', None) + language_hf_config = config.hf_config - if torch_dtype is None: - torch_dtype = getattr(config.hf_config.text_config, 'torch_dtype', None) - else: - torch_dtype = getattr(config.hf_config, 'dtype', None) + # for multi-modal models, get the language model config to determine dtype + if hasattr(config.hf_config, 'text_config'): + language_hf_config = config.hf_config.text_config + elif hasattr(config.hf_config, 'llm_config'): + language_hf_config = config.hf_config.llm_config - if torch_dtype is None: - torch_dtype = getattr(config.hf_config, 'torch_dtype', None) + torch_dtype = getattr(language_hf_config, 'dtype', None) + if torch_dtype is None: + torch_dtype = getattr(language_hf_config, 'torch_dtype', None) # deal with case when torch_dtype is not string but torch.dtype if isinstance(torch_dtype, torch.dtype): diff --git a/lmdeploy/pytorch/configurations/default.py b/lmdeploy/pytorch/configurations/default.py index da07a3b487..4d06cd10ce 100644 --- a/lmdeploy/pytorch/configurations/default.py +++ b/lmdeploy/pytorch/configurations/default.py @@ -14,14 +14,15 @@ def condition(cls, hf_config): @classmethod def build(cls, hf_config, model_path: str = None, **kwargs): """build.""" - head_dim = getattr(hf_config, 'head_dim', None) - if hasattr(hf_config, 'text_config') and hasattr(hf_config, 'vision_config'): - # for multi-modal models config with separate text and vision configs + # for multi-modal models, get the language model config to build model config + if hasattr(hf_config, 'text_config'): hf_config = hf_config.text_config - head_dim = hf_config.head_dim - else: - head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads + elif hasattr(hf_config, 'llm_config'): + hf_config = hf_config.llm_config + + head_dim = getattr(hf_config, 'head_dim', None) + head_dim = head_dim or hf_config.hidden_size // hf_config.num_attention_heads # head_dim should not be None hf_config.head_dim = head_dim From bd68377c4a2ccf83a86c4e8257299d4729ddfbcc Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 6 Nov 2025 11:37:15 +0800 Subject: [PATCH 11/15] fix config --- lmdeploy/pytorch/config.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/lmdeploy/pytorch/config.py b/lmdeploy/pytorch/config.py index 7724f7d00b..da1d27f8c7 100644 --- a/lmdeploy/pytorch/config.py +++ b/lmdeploy/pytorch/config.py @@ -28,17 +28,9 @@ def _update_torch_dtype(config: 'ModelConfig', dtype: str): config.dtype = torch.float16 return config - language_hf_config = config.hf_config - - # for multi-modal models, get the language model config to determine dtype - if hasattr(config.hf_config, 'text_config'): - language_hf_config = config.hf_config.text_config - elif hasattr(config.hf_config, 'llm_config'): - language_hf_config = config.hf_config.llm_config - - torch_dtype = getattr(language_hf_config, 'dtype', None) + torch_dtype = getattr(config.llm_config, 'dtype', None) if torch_dtype is None: - torch_dtype = getattr(language_hf_config, 'torch_dtype', None) + torch_dtype = getattr(config.llm_config, 'torch_dtype', None) # deal with case when torch_dtype is not string but torch.dtype if isinstance(torch_dtype, torch.dtype): From 004b646769d7e6abf6fdfe41bfb6cd2619cb6d0c Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 6 Nov 2025 11:43:57 +0800 Subject: [PATCH 12/15] some optimizations 
--- lmdeploy/pytorch/models/qwen3_vl.py | 114 ++++++++++++---------------- 1 file changed, 47 insertions(+), 67 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index c7fcdde089..2a6ac06b85 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -162,11 +162,13 @@ def forward( # add visual features to the hidden states of first several layers if deepstack_visual_embeds is not None and idx in range(len(deepstack_visual_embeds)): + hidden_states = hidden_states + residual hidden_states = self._deepstack_process( hidden_states, visual_pos_masks, deepstack_visual_embeds[idx], ) + residual = None # norm hidden_states, _ = self.norm(hidden_states, residual) @@ -232,19 +234,16 @@ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: is_tp=True, ) - # silu and mul - if config.hidden_act in ['gelu', 'gelu_fast', 'quick_gelu', 'gelu_python']: - self.act = nn.GELU() - else: - self.act = ACT2FN[config.hidden_act] + # gelu_pytorch_tanh + self.act = ACT2FN[config.hidden_act] # down self.linear_fc2 = build_rowwise_linear(intermediate_size, hidden_dim, bias=True, - quant_config=quantization_config, dtype=dtype, device=device, + quant_config=quantization_config, is_tp=True) def forward(self, x): @@ -270,21 +269,16 @@ def __init__(self, self.mlp = Qwen3VLVisionMLP(config, dtype=dtype, device=device) def forward(self, - hidden_states, - cu_seqlens, - rotary_pos_emb, - residual: Optional[torch.Tensor] = None) -> torch.Tensor: - if residual is None: - residual = hidden_states - hidden_states = self.norm1(hidden_states) - else: - hidden_states, residual = self.norm1(hidden_states, residual) - - hidden_states = self.attn(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) - - hidden_states, residual = self.norm2(hidden_states, residual) - hidden_states = self.mlp(hidden_states) - return hidden_states, residual + hidden_states: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb: Optional[torch.Tensor] = None) -> torch.Tensor: + hidden_states = hidden_states + self.attn( + self.norm1(hidden_states), + cu_seqlens=cu_seqlens, + rotary_pos_emb=rotary_pos_emb, + ) + hidden_states = hidden_states + self.mlp(self.norm2(hidden_states)) + return hidden_states class Qwen3VLVisionPatchMerger(nn.Module): @@ -297,10 +291,10 @@ def __init__(self, super().__init__() self.hidden_size = config.hidden_size * (config.spatial_merge_size**2) self.use_postshuffle_norm = use_postshuffle_norm - self.norm = nn.LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, - eps=1e-6, - dtype=dtype, - device=device) + self.norm = LayerNorm(self.hidden_size if use_postshuffle_norm else config.hidden_size, + eps=1e-6, + dtype=dtype, + device=device) self.linear_fc1 = build_colwise_linear( self.hidden_size, self.hidden_size, @@ -456,21 +450,17 @@ def forward(self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, rotary_ hidden_states = hidden_states + pos_embeds cu_seqlens = torch.nn.functional.pad(cu_seqlens, (1, 0), value=0) - residual = None deepstack_feature_lists = [] for layer_num, blk in enumerate(self.blocks): - hidden_states, residual = blk(hidden_states, - cu_seqlens=cu_seqlens, - rotary_pos_emb=rotary_pos_emb, - residual=residual) + hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb) if layer_num in self.deepstack_visual_indexes: - deepstack_feature = self.deepstack_merger_list[self.deepstack_visual_indexes.index(layer_num)]( - 
(hidden_states + residual)) + deepstack_merge_idx = self.deepstack_visual_indexes.index(layer_num) + deepstack_feature = self.deepstack_merger_list[deepstack_merge_idx](hidden_states) deepstack_feature_lists.append(deepstack_feature) - hidden_states = hidden_states + residual + hidden_states = self.merger(hidden_states) - return self.merger(hidden_states), deepstack_feature_lists + return hidden_states, deepstack_feature_lists class Qwen3VLForConditionalGeneration(nn.Module, DeployModelMixin, CudaGraphMixin): @@ -517,37 +507,6 @@ def __init__(self, dtype=dtype, device=device) - def _prepare_multimodal_inputs(self, input_ids: torch.Tensor, pixel_values: torch.Tensor, image_mask: torch.Tensor, - grid_thw: torch.Tensor, vis_cu_seqlens: torch.Tensor, vis_pos_emb: torch.Tensor, - pos_embeds: torch.Tensor): - """Prepare multimodal inputs for language model.""" - inputs_embeds = self.get_input_embeddings()(input_ids) - if pixel_values is None: - return inputs_embeds, None, None - - dtype = inputs_embeds.dtype - pixel_values = pixel_values.to(dtype) - vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype)) - - # get image embeds and deepstack visual embeds - image_embeds, deepstack_visual_embeds = self.visual(pixel_values, - cu_seqlens=vis_cu_seqlens, - rotary_pos_emb=vis_pos_emb, - pos_embeds=pos_embeds) - - # split image embeds per sample - split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() - image_embeds = torch.split(image_embeds, split_sizes) - image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) - - # mask and scatter to create final input embeddings - expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) - final_inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) - - visual_pos_masks = expanded_image_mask[..., 0] - - return final_inputs_embeds, visual_pos_masks, deepstack_visual_embeds - def forward( self, input_ids: torch.Tensor, @@ -569,8 +528,29 @@ def forward( visual_pos_masks = None deepstack_visual_embeds = None if inputs_embeds is None: - inputs_embeds, visual_pos_masks, deepstack_visual_embeds = self._prepare_multimodal_inputs( - input_ids, pixel_values, image_mask, grid_thw, vis_cu_seqlens, vis_pos_emb, pos_embeds) + inputs_embeds = self.get_input_embeddings()(input_ids) + + if pixel_values is not None: + dtype = inputs_embeds.dtype + pixel_values = pixel_values.to(dtype) + vis_pos_emb = (vis_pos_emb[0].to(dtype), vis_pos_emb[1].to(dtype)) + + # get image embeds and deepstack visual embeds + image_embeds, deepstack_visual_embeds = self.visual(pixel_values, + cu_seqlens=vis_cu_seqlens, + rotary_pos_emb=vis_pos_emb, + pos_embeds=pos_embeds) + + # split image embeds per sample + split_sizes = (grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist() + image_embeds = torch.split(image_embeds, split_sizes) + image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, dtype) + + # mask and scatter to create final input embeddings + expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) + inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) + + visual_pos_masks = expanded_image_mask[..., 0] hidden_states = self.language_model( input_ids=input_ids, From 0be838ccc92344dd324d9af982c53a211958135c Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 6 Nov 2025 12:03:28 +0800 Subject: [PATCH 13/15] reuse qwen3, qwen3-moe --- lmdeploy/pytorch/models/qwen3_vl.py | 30 +++------------------- lmdeploy/pytorch/models/qwen3_vl_moe.py | 
33 +++++-------------------- 2 files changed, 10 insertions(+), 53 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 2a6ac06b85..4a88ac58f1 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -9,14 +9,14 @@ from lmdeploy.pytorch.engine.input_process import BaseModelInputProcessor from lmdeploy.pytorch.model_inputs import StepContext, StepContextManager -from lmdeploy.pytorch.nn import LayerNorm, RMSNorm +from lmdeploy.pytorch.nn import LayerNorm from lmdeploy.pytorch.nn.linear import build_colwise_linear, build_rowwise_linear from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight from .qwen2_5_vl import Qwen2_5_VisionRotaryEmbedding as Qwen3VLVisionRotaryEmbedding from .qwen2_5_vl import Qwen2_5_VLInputProcessor as Qwen3VLInputProcessor from .qwen2_5_vl import Qwen2_5_VLVisionAttention as Qwen3VLVisionAttention -from .qwen3 import Qwen3DecoderLayer as Qwen3VLTextDecoderLayer +from .qwen3 import Qwen3model from .utils.cudagraph import CudaGraphMeta, CudaGraphMixin from .utils.model import DeployModelMixin, vlm_model @@ -81,32 +81,14 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) -class Qwen3VLTextModel(nn.Module): +class Qwen3VLTextModel(Qwen3model): """Text part of Qwen3VL. not a pure text-only model, as DeepStack integrates visual features into the early hidden states. """ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): - super().__init__() - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.mrope_section = config.rope_scaling['mrope_section'] - - self.embed_tokens = nn.Embedding(config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=dtype, - device=device) - - # build all decode layers - self.layers = nn.ModuleList([ - Qwen3VLTextDecoderLayer(config, layer_idx, dtype=dtype, device=device) - for layer_idx in range(config.num_hidden_layers) - ]) - - # build norm - self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + super().__init__(config=config, dtype=dtype, device=device) # build rotary embedding # TODO: zhouxinyu, add triton kernel for interleaved mrope @@ -183,10 +165,6 @@ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torc hidden_states[visual_pos_masks, :] = local_this return hidden_states - def get_input_embeddings(self): - """Get input embeddings.""" - return self.embed_tokens - class Qwen3VLVisionPatchEmbed(nn.Module): diff --git a/lmdeploy/pytorch/models/qwen3_vl_moe.py b/lmdeploy/pytorch/models/qwen3_vl_moe.py index b74517150a..e48762daf2 100644 --- a/lmdeploy/pytorch/models/qwen3_vl_moe.py +++ b/lmdeploy/pytorch/models/qwen3_vl_moe.py @@ -7,40 +7,21 @@ from transformers.configuration_utils import PretrainedConfig from lmdeploy.pytorch.model_inputs import StepContextManager -from lmdeploy.pytorch.nn import RMSNorm from lmdeploy.pytorch.weight_loader.model_weight_loader import load_weight -from .qwen3_moe import Qwen3MoeDecoderLayer as Qwen3VLMoeTextDecoderLayer +from .qwen3_moe import Qwen3MoeModel from .qwen3_vl import Qwen3VLForConditionalGeneration from .qwen3_vl import Qwen3VLTextRotaryEmbedding as Qwen3VLMoeTextRotaryEmbedding -class Qwen3VLMoeTextModel(nn.Module): - """Text part of Qwen3VLMoe. +class Qwen3VLMoeTextModel(Qwen3MoeModel): + """Text part of Qwen3VL. 
not a pure text-only model, as DeepStack integrates visual features into the early hidden states. """ def __init__(self, config: PretrainedConfig, dtype: torch.dtype = None, device: torch.device = None): - super().__init__() - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - self.mrope_section = config.rope_scaling['mrope_section'] - - self.embed_tokens = nn.Embedding(config.vocab_size, - config.hidden_size, - self.padding_idx, - dtype=dtype, - device=device) - - # build all decode layers - self.layers = nn.ModuleList([ - Qwen3VLMoeTextDecoderLayer(config, layer_idx, dtype=dtype, device=device) - for layer_idx in range(config.num_hidden_layers) - ]) - - # build norm - self.norm = RMSNorm(config.hidden_size, config.rms_norm_eps, dtype=dtype, device=device) + super().__init__(config=config, dtype=dtype, device=device) # build rotary embedding # TODO: zhouxinyu, add triton kernel for interleaved mrope @@ -96,11 +77,13 @@ def forward( # add visual features to the hidden states of first several layers if deepstack_visual_embeds is not None and idx in range(len(deepstack_visual_embeds)): + hidden_states = hidden_states + residual hidden_states = self._deepstack_process( hidden_states, visual_pos_masks, deepstack_visual_embeds[idx], ) + residual = None # norm hidden_states, _ = self.norm(hidden_states, residual) @@ -115,10 +98,6 @@ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torc hidden_states[visual_pos_masks, :] = local_this return hidden_states - def get_input_embeddings(self): - """Get input embeddings.""" - return self.embed_tokens - class Qwen3VLMoeForConditionalGeneration(Qwen3VLForConditionalGeneration): """ModelForCausalLM.""" From 6eba480ff5c6ab4cd6f49953cd2925a7ef0f9e75 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 6 Nov 2025 18:53:50 +0800 Subject: [PATCH 14/15] fix mrope acc bug --- lmdeploy/pytorch/models/qwen3_vl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index 4a88ac58f1..f7ed3a231e 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -125,6 +125,7 @@ def forward( if mrope_position_ids is None: cos, sin = self.rotary_emb(hidden_states, position_ids) else: + mrope_position_ids = mrope_position_ids.unsqueeze(1) cos, sin = self.rotary_emb(hidden_states, mrope_position_ids) cos, sin = cos[0], sin[0] From 9f061e192d7e5b39e93ebb4aba39c3dd6c9fc6f5 Mon Sep 17 00:00:00 2001 From: zxy Date: Thu, 6 Nov 2025 20:13:47 +0800 Subject: [PATCH 15/15] fix moe, optimize deepstack process --- lmdeploy/pytorch/models/qwen3_vl.py | 7 ++++--- lmdeploy/pytorch/models/qwen3_vl_moe.py | 6 ++++-- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/lmdeploy/pytorch/models/qwen3_vl.py b/lmdeploy/pytorch/models/qwen3_vl.py index f7ed3a231e..6844c6b8d0 100644 --- a/lmdeploy/pytorch/models/qwen3_vl.py +++ b/lmdeploy/pytorch/models/qwen3_vl.py @@ -162,8 +162,9 @@ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torc visual_embeds: torch.Tensor): visual_pos_masks = visual_pos_masks.to(hidden_states.device) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) - local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds - hidden_states[visual_pos_masks, :] = local_this + local = torch.zeros_like(hidden_states) + local.masked_scatter_(visual_pos_masks, visual_embeds) + hidden_states += local return hidden_states @@ -529,7 +530,7 @@ def forward( 
expanded_image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds) inputs_embeds = inputs_embeds.masked_scatter(expanded_image_mask, image_embeds) - visual_pos_masks = expanded_image_mask[..., 0] + visual_pos_masks = expanded_image_mask hidden_states = self.language_model( input_ids=input_ids, diff --git a/lmdeploy/pytorch/models/qwen3_vl_moe.py b/lmdeploy/pytorch/models/qwen3_vl_moe.py index e48762daf2..1dc7e32de9 100644 --- a/lmdeploy/pytorch/models/qwen3_vl_moe.py +++ b/lmdeploy/pytorch/models/qwen3_vl_moe.py @@ -58,6 +58,7 @@ def forward( if mrope_position_ids is None: cos, sin = self.rotary_emb(hidden_states, position_ids) else: + mrope_position_ids = mrope_position_ids.unsqueeze(1) cos, sin = self.rotary_emb(hidden_states, mrope_position_ids) cos, sin = cos[0], sin[0] @@ -94,8 +95,9 @@ def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torc visual_embeds: torch.Tensor): visual_pos_masks = visual_pos_masks.to(hidden_states.device) visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype) - local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds - hidden_states[visual_pos_masks, :] = local_this + local = torch.zeros_like(hidden_states) + local.masked_scatter_(visual_pos_masks, visual_embeds) + hidden_states += local return hidden_states
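
Note on the deepstack change in PATCH 15: `_deepstack_process` now scatters the visual features into a zero buffer and adds it, rather than cloning the masked slice. Below is a minimal, self-contained sketch of that injection pattern; the shapes and values are illustrative, only the `Tensor.masked_scatter_` semantics are taken from PyTorch, and `visual_pos_masks` is assumed to be the hidden-dim-expanded boolean mask that PATCH 15 passes in.

import torch

# Illustrative shapes: seq_len tokens, hidden_dim features per token.
seq_len, hidden_dim = 6, 4
hidden_states = torch.zeros(seq_len, hidden_dim)

# Boolean mask of visual-token positions, expanded over the hidden dimension
# (this mirrors the `expanded_image_mask` that PATCH 15 passes as `visual_pos_masks`).
is_visual = torch.tensor([False, True, True, False, True, False])
visual_pos_masks = is_visual.unsqueeze(-1).expand(seq_len, hidden_dim)

# One row of visual features per visual position (3 rows here).
visual_embeds = torch.ones(int(is_visual.sum()), hidden_dim)

# Scatter the features into a zero buffer at the masked positions, then add,
# so non-visual positions keep their hidden states unchanged.
local = torch.zeros_like(hidden_states)
local.masked_scatter_(visual_pos_masks, visual_embeds)
hidden_states += local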
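
PATCH 10 resolves the model dtype from the language sub-config of multimodal checkpoints by probing `text_config` first and `llm_config` second, then falling back from `dtype` to `torch_dtype`. The sketch below only restates that probing order; `get_language_config` and `get_torch_dtype` are illustrative helper names, not lmdeploy APIs, and PATCH 11 later simplifies the call site to `config.llm_config`.

# Illustrative helpers only; the names are not part of lmdeploy.
def get_language_config(hf_config):
    # Multimodal configs nest the language model under `text_config` or `llm_config`.
    for attr in ('text_config', 'llm_config'):
        sub_config = getattr(hf_config, attr, None)
        if sub_config is not None:
            return sub_config
    return hf_config


def get_torch_dtype(hf_config):
    language_config = get_language_config(hf_config)
    # Check `dtype` first and fall back to `torch_dtype`, as the patch does.
    torch_dtype = getattr(language_config, 'dtype', None)
    if torch_dtype is None:
        torch_dtype = getattr(language_config, 'torch_dtype', None)
    return torch_dtype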