diff --git a/vllm_ascend/models/qwen2_5_vl.py b/vllm_ascend/models/qwen2_5_vl.py index 6f07afdc61d..7ceefe737f3 100644 --- a/vllm_ascend/models/qwen2_5_vl.py +++ b/vllm_ascend/models/qwen2_5_vl.py @@ -27,6 +27,8 @@ from einops import rearrange from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import ( Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig) + +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention from vllm.config import VllmConfig from vllm.distributed import parallel_state from vllm.distributed import utils as dist_utils @@ -49,91 +51,6 @@ if not vllm_version_is("0.11.0"): from vllm.model_executor.models.vision import conv3d_to_linear_weight -MIN_PAD_SIZE = 64 # min_size to pad weight -MAX_PAD_SIZE = 128 # max_size to pad weight - - -class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention): - - def __init__( - self, - embed_dim: int, - num_heads: int, - projection_size: int, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - super().__init__( - embed_dim, - num_heads, - projection_size, - quant_config, - prefix, - ) - self.embed_dim = embed_dim - self.hidden_size_per_attention_head = dist_utils.divide( - projection_size, num_heads) - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: - # [s, b, 3 * head * head_dim] - seq_len, bs, _ = qkv.shape - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] - q, k, v = qkv.chunk(3, dim=2) - - # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] - new_shape = (seq_len, bs, self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head) - q, k, v = (x.view(*new_shape) for x in (q, k, v)) - return q, k, v - - def forward( - self, - x: torch.Tensor, - cu_seqlens: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - # [s, b, c] --> [s, b, head * 3 * head_dim] - x, _ = self.qkv(x) - - # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] - q, k, v = self.split_qkv(x) - batch_size = q.shape[1] - - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() - for x in (q, k, v)) - q = torch_npu.npu_rotary_mul(q, cos, sin) - k = torch_npu.npu_rotary_mul(k, cos, sin) - - q, k, v = [ - rearrange(x, "b s h d -> (b s) h d").contiguous() - for x in (q, k, v) - ] - - context_layer = torch.empty_like(q) - - # operator requires pta version >= 2.5.1 - torch_npu._npu_flash_attention_unpad( - query=q, - key=k, - value=v, - seq_len=cu_seqlens, - scale_value=self.origin_hidden_size_per_attention_head**-0.5, - num_heads=self.num_attention_heads_per_partition, - num_kv_heads=self.num_attention_heads_per_partition, - out=context_layer) - - context_layer = rearrange(context_layer, - "(b s) h d -> s b (h d)", - b=batch_size).contiguous() - - output, _ = self.proj(context_layer) - return output - class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock): @@ -149,11 +66,11 @@ def __init__( ) -> None: super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer, quant_config, prefix) - self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim, - num_heads=num_heads, - projection_size=dim, - quant_config=quant_config, - prefix=f"{prefix}.attn") + self.attn = MMEncoderAttention(embed_dim=dim, + num_heads=num_heads, + projection_size=dim, + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor: @@ -194,7 +111,7 @@ def __init__( super().__init__(vision_config, norm_eps, quant_config, prefix) norm_layer = partial(RMSNorm, eps=norm_eps) self.interleaved = interleaved - self.enable_pad = False + head_dim = self.hidden_size // self.num_heads self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim // 2) @@ -222,131 +139,6 @@ def __init__( self.hidden_size_per_attention_head = dist_utils.divide( self.hidden_size, self.num_heads) - if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE: - self.enable_pad = True - self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head - self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 - self.half_pad_hidden_size_per_attention_head = ( - MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 - self.hidden_size_per_attention_head = MAX_PAD_SIZE - - def cal_cos_sin(self, rotary_pos_emb): - cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2] - sin = rotary_pos_emb.sin() - if self.enable_pad: - cos = torch.nn.functional.pad( - cos, (0, self.half_pad_hidden_size_per_attention_head)) - sin = torch.nn.functional.pad( - sin, (0, self.half_pad_hidden_size_per_attention_head)) - - if not self.interleaved: - cos_new = torch.cat((cos, cos), dim=-1) - sin_new = torch.cat((sin, sin), dim=-1) - else: - cos_new = rearrange(torch.stack((cos, cos), dim=-1), - "... d two -> ...(d two)", - two=2) - sin_new = rearrange(torch.stack((sin, sin), dim=-1), - "... d two -> ...(d two)", - two=2) - cos_new = cos_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - sin_new = sin_new.reshape(1, -1, 1, - self.hidden_size_per_attention_head) - return cos_new, sin_new - - def pad_qkv_bias(self, bias): - first_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, :self.half_origin_hidden_size_per_attention_head] - second_half = bias.reshape( - -1, 3, self.origin_hidden_size_per_attention_head - )[:, :, self.half_origin_hidden_size_per_attention_head:] - first_half_padded = torch.nn.functional.pad( - first_half, (0, self.half_pad_hidden_size_per_attention_head)) - second_half_padded = torch.nn.functional.pad( - second_half, (0, self.half_pad_hidden_size_per_attention_head)) - bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2) - bias_final = bias_padded.reshape(-1) - return bias_final - - def pad_qkv_weight(self, data): - qkv_weight_first_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, :self.half_origin_hidden_size_per_attention_head, :] - qkv_weight_second_half = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size - )[:, :, self.half_origin_hidden_size_per_attention_head:, :] - - qkv_weight_first_half_padded = torch.nn.functional.pad( - qkv_weight_first_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_second_half_padded = torch.nn.functional.pad( - qkv_weight_second_half, - (0, 0, 0, self.half_pad_hidden_size_per_attention_head)) - qkv_weight_padded = torch.cat( - [qkv_weight_first_half_padded, qkv_weight_second_half_padded], - dim=2) - qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size) - - if is_enable_nz(): - qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_( - qkv_weight_final) - qkv_weight_final_copy = torch_npu.npu_format_cast( - qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND) - return qkv_weight_final_copy - - return qkv_weight_final - - def pad_proj_weight(self, data): - out_weight = torch.nn.functional.pad( - data.reshape(self.hidden_size, -1, - self.half_origin_hidden_size_per_attention_head), - (0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape( - self.hidden_size, -1) - - if is_enable_nz(): - out_weight_copy = torch.empty_like(out_weight).copy_(out_weight) - out_weight_copy = torch_npu.npu_format_cast( - out_weight_copy, ACL_FORMAT_FRACTAL_ND) - return out_weight_copy - - return out_weight - - def pad_qkv_weight_scale_offset(self, data): - reshaped_data = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head, 1) - data1 = reshaped_data[:, :, :self. - half_origin_hidden_size_per_attention_head, :] - data2 = reshaped_data[:, :, self. - half_origin_hidden_size_per_attention_head:, :] - data1_paded = torch.nn.functional.pad( - data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, - 0, 0, 0)) - data2_paded = torch.nn.functional.pad( - data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0, - 0, 0, 0)) - res = torch.cat([data1_paded, data2_paded], dim=2) - res = res.reshape(-1, 1) - return res - - def pad_qkv_deq_scale_quant_bias(self, data): - reshaped_data = data.reshape( - -1, 3, self.origin_hidden_size_per_attention_head) - data1 = reshaped_data[:, :, :self. - half_origin_hidden_size_per_attention_head] - data2 = reshaped_data[:, :, - self.half_origin_hidden_size_per_attention_head:] - - data1_paded = torch.nn.functional.pad( - data1, (0, self.half_pad_hidden_size_per_attention_head)) - data2_paded = torch.nn.functional.pad( - data2, (0, self.half_pad_hidden_size_per_attention_head)) - - res = torch.cat([data1_paded, data2_paded], dim=2) - res = res.reshape(-1) - return res - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [ @@ -377,24 +169,6 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - if ("attn.proj.weight_scale" in name or - "attn.proj.weight_offset" in name) and self.enable_pad: - continue - elif ("attn.proj.deq_scale" in name - or "attn.proj.quant_bias" in name) and self.enable_pad: - continue - elif ("attn.qkv.weight_scale" in name - or "attn.qkv.weight_offset" in name) and self.enable_pad: - param.data = self.pad_qkv_weight_scale_offset(param.data) - elif ("attn.qkv.deq_scale" in name - or "attn.qkv.quant_bias" in name) and self.enable_pad: - param.data = self.pad_qkv_deq_scale_quant_bias(param.data) - elif ("attn.proj.weight" in name) and self.enable_pad: - param.data = self.pad_proj_weight(param.data) - elif ("attn.qkv.weight" in name) and self.enable_pad: - param.data = self.pad_qkv_weight(param.data) - elif ("attn.qkv.bias" in name) and self.enable_pad: - param.data = self.pad_qkv_bias(param.data) loaded_params.add(name) return loaded_params diff --git a/vllm_ascend/ops/multi_modal/mm_encoder_attention.py b/vllm_ascend/ops/multi_modal/mm_encoder_attention.py new file mode 100644 index 00000000000..790bc875830 --- /dev/null +++ b/vllm_ascend/ops/multi_modal/mm_encoder_attention.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.nn.functional as F + +from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention + + +MIN_PAD_SIZE = 64 # min_size to pad weight +MAX_PAD_SIZE = 128 # max_size to pad weight + + +class AscendMMEncoderAttention(MMEncoderAttention): + + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + + self.embed_dim = embed_dim + self.hidden_size_per_attention_head = dist_utils.divide( + projection_size, num_heads) + self.enable_pad = False + self.finish_pad = False + + # TODO(shen-shanshan): Add verification for env vars (enable unpad). + if self.hidden_size_per_attention_head > MIN_PAD_SIZE \ + and self.hidden_size_per_attention_head < MAX_PAD_SIZE: + self.enable_pad = True + self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head + self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2 + self.half_pad_hidden_size_per_attention_head = ( + MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2 + self.hidden_size_per_attention_head = MAX_PAD_SIZE + + def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]: + # [s, b, 3 * head * head_dim] + seq_len, bs, _ = qkv.shape + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim] + q, k, v = qkv.chunk(3, dim=2) + + # 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim] + new_shape = (seq_len, bs, self.num_attention_heads_per_partition, + self.hidden_size_per_attention_head) + q, k, v = (x.view(*new_shape) for x in (q, k, v)) + return q, k, v + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ) -> torch.Tensor: + # [s, b, c] --> [s, b, head * 3 * head_dim] + x, _ = self.qkv(x) + + # [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim] + q, k, v = self.split_qkv(x) + batch_size = q.shape[1] + + q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() + for x in (q, k, v)) + q = torch_npu.npu_rotary_mul(q, cos, sin) + k = torch_npu.npu_rotary_mul(k, cos, sin) + + q, k, v = [ + rearrange(x, "b s h d -> (b s) h d").contiguous() + for x in (q, k, v) + ] + + # TODO(shen-shanshan): use context manager. + if self.enable_pad: + origin_shape = q.shape[-1] + pad_len = MAX_PAD_SIZE - origin_shape + q = F.pad(q, (0, pad_len), mode="constant", value=0) + k = F.pad(k, (0, pad_len), mode="constant", value=0) + v = F.pad(v, (0, pad_len), mode="constant", value=0) + + context_layer = torch.empty_like(q) + + # operator requires pta version >= 2.5.1 + torch_npu._npu_flash_attention_unpad( + query=q, + key=k, + value=v, + seq_len=cu_seqlens, + scale_value=self.origin_hidden_size_per_attention_head**-0.5, + num_heads=self.num_attention_heads_per_partition, + num_kv_heads=self.num_attention_heads_per_partition, + out=context_layer) + + # TODO(shen-shanshan): use context manager. + if self.enable_pad: + context_layer = context_layer[..., :origin_shape] + + context_layer = rearrange(context_layer, + "(b s) h d -> s b (h d)", + b=batch_size).contiguous() + + output, _ = self.proj(context_layer) + return output diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 381510809a8..7789286edf2 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -698,6 +698,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): from vllm_ascend.ops.vocab_parallel_embedding import ( AscendLogitsProcessor, AscendParallelLMHead, AscendVocabParallelEmbedding) + from vllm_ascend.ops.multi_modal.mm_encoder_attention import AscendMMEncoderAttention global REGISTERED_ASCEND_OPS REGISTERED_ASCEND_OPS = { @@ -719,6 +720,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): "GemmaRMSNorm": AscendGemmaRMSNorm, "FusedMoE": AscendFusedMoE, "SharedFusedMoE": AscendSharedFusedMoE, + "MMEncoderAttention": AscendMMEncoderAttention, } mla_to_register = "MultiHeadLatentAttention" if vllm_version_is( "0.11.0") else "MultiHeadLatentAttentionWrapper"