242 changes: 8 additions & 234 deletions vllm_ascend/models/qwen2_5_vl.py
@@ -27,6 +27,8 @@
from einops import rearrange
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

from vllm.attention.layers.mm_encoder_attention import MMEncoderAttention
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -49,91 +51,6 @@
if not vllm_version_is("0.11.0"):
from vllm.model_executor.models.vision import conv3d_to_linear_weight

MIN_PAD_SIZE = 64 # min_size to pad weight
MAX_PAD_SIZE = 128 # max_size to pad weight


class AscendQwen2_5_VisionAttention(Qwen2_5_VisionAttention):

def __init__(
self,
embed_dim: int,
num_heads: int,
projection_size: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__(
embed_dim,
num_heads,
projection_size,
quant_config,
prefix,
)
self.embed_dim = embed_dim
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads)
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
self.hidden_size_per_attention_head = MAX_PAD_SIZE

def split_qkv(self, qkv: torch.Tensor) -> tuple[torch.Tensor, ...]:
# [s, b, 3 * head * head_dim]
seq_len, bs, _ = qkv.shape

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head * head_dim]
q, k, v = qkv.chunk(3, dim=2)

# 3 * [s, b, head * head_dim] -> 3 * [s, b, head, head_dim]
new_shape = (seq_len, bs, self.num_attention_heads_per_partition,
self.hidden_size_per_attention_head)
q, k, v = (x.view(*new_shape) for x in (q, k, v))
return q, k, v

def forward(
self,
x: torch.Tensor,
cu_seqlens: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
) -> torch.Tensor:
# [s, b, c] --> [s, b, head * 3 * head_dim]
x, _ = self.qkv(x)

# [s, b, 3 * head * head_dim] -> 3 * [s, b, head, head_dim]
q, k, v = self.split_qkv(x)
batch_size = q.shape[1]

q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous()
for x in (q, k, v))
q = torch_npu.npu_rotary_mul(q, cos, sin)
k = torch_npu.npu_rotary_mul(k, cos, sin)

q, k, v = [
rearrange(x, "b s h d -> (b s) h d").contiguous()
for x in (q, k, v)
]

context_layer = torch.empty_like(q)

# operator requires pta version >= 2.5.1
torch_npu._npu_flash_attention_unpad(
query=q,
key=k,
value=v,
seq_len=cu_seqlens,
scale_value=self.origin_hidden_size_per_attention_head**-0.5,
num_heads=self.num_attention_heads_per_partition,
num_kv_heads=self.num_attention_heads_per_partition,
out=context_layer)

context_layer = rearrange(context_layer,
"(b s) h d -> s b (h d)",
b=batch_size).contiguous()

output, _ = self.proj(context_layer)
return output


class AscendQwen2_5_VisionBlock(Qwen2_5_VisionBlock):

@@ -149,11 +66,11 @@ def __init__(
) -> None:
super().__init__(dim, num_heads, mlp_hidden_dim, act_fn, norm_layer,
quant_config, prefix)
self.attn = AscendQwen2_5_VisionAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")
self.attn = MMEncoderAttention(embed_dim=dim,
num_heads=num_heads,
projection_size=dim,
quant_config=quant_config,
prefix=f"{prefix}.attn")

def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor,
cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
@@ -194,7 +111,7 @@ def __init__(
super().__init__(vision_config, norm_eps, quant_config, prefix)
norm_layer = partial(RMSNorm, eps=norm_eps)
self.interleaved = interleaved
self.enable_pad = False

head_dim = self.hidden_size // self.num_heads
self.rotary_pos_emb = AscendQwen2_5_VisionRotaryEmbedding(head_dim //
2)
@@ -222,131 +139,6 @@ def __init__(
self.hidden_size_per_attention_head = dist_utils.divide(
self.hidden_size, self.num_heads)

if self.hidden_size_per_attention_head > MIN_PAD_SIZE and self.hidden_size_per_attention_head < MAX_PAD_SIZE:
self.enable_pad = True
self.origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head
self.half_origin_hidden_size_per_attention_head = self.hidden_size_per_attention_head // 2
self.half_pad_hidden_size_per_attention_head = (
MAX_PAD_SIZE - self.hidden_size_per_attention_head) // 2
self.hidden_size_per_attention_head = MAX_PAD_SIZE

def cal_cos_sin(self, rotary_pos_emb):
cos = rotary_pos_emb.cos() # [seqlen, rotary_dim / 2]
sin = rotary_pos_emb.sin()
if self.enable_pad:
cos = torch.nn.functional.pad(
cos, (0, self.half_pad_hidden_size_per_attention_head))
sin = torch.nn.functional.pad(
sin, (0, self.half_pad_hidden_size_per_attention_head))

if not self.interleaved:
cos_new = torch.cat((cos, cos), dim=-1)
sin_new = torch.cat((sin, sin), dim=-1)
else:
cos_new = rearrange(torch.stack((cos, cos), dim=-1),
"... d two -> ...(d two)",
two=2)
sin_new = rearrange(torch.stack((sin, sin), dim=-1),
"... d two -> ...(d two)",
two=2)
cos_new = cos_new.reshape(1, -1, 1,
self.hidden_size_per_attention_head)
sin_new = sin_new.reshape(1, -1, 1,
self.hidden_size_per_attention_head)
return cos_new, sin_new
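The `interleaved` flag in the removed `cal_cos_sin` only changes how the half-size cos/sin tables are expanded to the full rotary dimension. A tiny, self-contained illustration of the two layouts (the values are made up for demonstration):

```python
import torch
from einops import rearrange

# cos over rotary_dim/2 positions; values are arbitrary for illustration.
cos = torch.tensor([0.0, 1.0, 2.0])

# interleaved=False: the half-table is concatenated with itself.
non_interleaved = torch.cat((cos, cos), dim=-1)
print(non_interleaved.tolist())  # [0.0, 1.0, 2.0, 0.0, 1.0, 2.0]

# interleaved=True: each value is repeated pairwise.
interleaved = rearrange(torch.stack((cos, cos), dim=-1),
                        "... d two -> ... (d two)", two=2)
print(interleaved.tolist())      # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0]
```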

def pad_qkv_bias(self, bias):
first_half = bias.reshape(
-1, 3, self.origin_hidden_size_per_attention_head
)[:, :, :self.half_origin_hidden_size_per_attention_head]
second_half = bias.reshape(
-1, 3, self.origin_hidden_size_per_attention_head
)[:, :, self.half_origin_hidden_size_per_attention_head:]
first_half_padded = torch.nn.functional.pad(
first_half, (0, self.half_pad_hidden_size_per_attention_head))
second_half_padded = torch.nn.functional.pad(
second_half, (0, self.half_pad_hidden_size_per_attention_head))
bias_padded = torch.cat([first_half_padded, second_half_padded], dim=2)
bias_final = bias_padded.reshape(-1)
return bias_final

def pad_qkv_weight(self, data):
qkv_weight_first_half = data.reshape(
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
)[:, :, :self.half_origin_hidden_size_per_attention_head, :]
qkv_weight_second_half = data.reshape(
-1, 3, self.origin_hidden_size_per_attention_head, self.hidden_size
)[:, :, self.half_origin_hidden_size_per_attention_head:, :]

qkv_weight_first_half_padded = torch.nn.functional.pad(
qkv_weight_first_half,
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
qkv_weight_second_half_padded = torch.nn.functional.pad(
qkv_weight_second_half,
(0, 0, 0, self.half_pad_hidden_size_per_attention_head))
qkv_weight_padded = torch.cat(
[qkv_weight_first_half_padded, qkv_weight_second_half_padded],
dim=2)
qkv_weight_final = qkv_weight_padded.reshape(-1, self.hidden_size)

if is_enable_nz():
qkv_weight_final_copy = torch.empty_like(qkv_weight_final).copy_(
qkv_weight_final)
qkv_weight_final_copy = torch_npu.npu_format_cast(
qkv_weight_final_copy, ACL_FORMAT_FRACTAL_ND)
return qkv_weight_final_copy

return qkv_weight_final

def pad_proj_weight(self, data):
out_weight = torch.nn.functional.pad(
data.reshape(self.hidden_size, -1,
self.half_origin_hidden_size_per_attention_head),
(0, self.half_pad_hidden_size_per_attention_head, 0, 0)).reshape(
self.hidden_size, -1)

if is_enable_nz():
out_weight_copy = torch.empty_like(out_weight).copy_(out_weight)
out_weight_copy = torch_npu.npu_format_cast(
out_weight_copy, ACL_FORMAT_FRACTAL_ND)
return out_weight_copy

return out_weight

def pad_qkv_weight_scale_offset(self, data):
reshaped_data = data.reshape(
-1, 3, self.origin_hidden_size_per_attention_head, 1)
data1 = reshaped_data[:, :, :self.
half_origin_hidden_size_per_attention_head, :]
data2 = reshaped_data[:, :, self.
half_origin_hidden_size_per_attention_head:, :]
data1_paded = torch.nn.functional.pad(
data1, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
0, 0, 0))
data2_paded = torch.nn.functional.pad(
data2, (0, 0, 0, self.half_pad_hidden_size_per_attention_head, 0,
0, 0, 0))
res = torch.cat([data1_paded, data2_paded], dim=2)
res = res.reshape(-1, 1)
return res

def pad_qkv_deq_scale_quant_bias(self, data):
reshaped_data = data.reshape(
-1, 3, self.origin_hidden_size_per_attention_head)
data1 = reshaped_data[:, :, :self.
half_origin_hidden_size_per_attention_head]
data2 = reshaped_data[:, :,
self.half_origin_hidden_size_per_attention_head:]

data1_paded = torch.nn.functional.pad(
data1, (0, self.half_pad_hidden_size_per_attention_head))
data2_paded = torch.nn.functional.pad(
data2, (0, self.half_pad_hidden_size_per_attention_head))

res = torch.cat([data1_paded, data2_paded], dim=2)
res = res.reshape(-1)
return res

def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping: list[tuple[str, str, Union[str, int]]] = [
@@ -377,24 +169,6 @@ def load_weights(self, weights: Iterable[Tuple[str,
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
if ("attn.proj.weight_scale" in name or
"attn.proj.weight_offset" in name) and self.enable_pad:
continue
elif ("attn.proj.deq_scale" in name
or "attn.proj.quant_bias" in name) and self.enable_pad:
continue
elif ("attn.qkv.weight_scale" in name
or "attn.qkv.weight_offset" in name) and self.enable_pad:
param.data = self.pad_qkv_weight_scale_offset(param.data)
elif ("attn.qkv.deq_scale" in name
or "attn.qkv.quant_bias" in name) and self.enable_pad:
param.data = self.pad_qkv_deq_scale_quant_bias(param.data)
elif ("attn.proj.weight" in name) and self.enable_pad:
param.data = self.pad_proj_weight(param.data)
elif ("attn.qkv.weight" in name) and self.enable_pad:
param.data = self.pad_qkv_weight(param.data)
elif ("attn.qkv.bias" in name) and self.enable_pad:
param.data = self.pad_qkv_bias(param.data)
loaded_params.add(name)
return loaded_params
