diff --git a/lmdeploy/pytorch/models/cogvlm.py b/lmdeploy/pytorch/models/cogvlm.py
index 45d7a4d01d..8a522f0238 100644
--- a/lmdeploy/pytorch/models/cogvlm.py
+++ b/lmdeploy/pytorch/models/cogvlm.py
@@ -5,6 +5,9 @@
 import torch.distributed as dist
 from torch import nn
 from transformers.modeling_outputs import BaseModelOutputWithPast
+from lmdeploy.pytorch.kernels.ascend.fused_rotary_emb import fused_rotary_emb as fused_rotary_emb_ascend
+from lmdeploy.pytorch.kernels.ascend.paged_attention_fwd import paged_attention_fwd as paged_attention_fwd_ascend
+from lmdeploy.pytorch.kernels.ascend.fill_kv_cache import fill_kv_cache as fill_kv_cache_ascend
 
 from ..kernels import fill_kv_cache, fused_rotary_emb, paged_attention_fwd
 from ..weight_loader.dist_utils import (colwise_split_parallelize_linear,
@@ -238,6 +241,162 @@ def forward(
         )
 
 
+class PatchedVisionExpertAttentionAscend(nn.Module):
+
+    def _contiguous_batching_forward_impl(
+        self,
+        hidden_states: torch.Tensor,
+        token_type_ids: torch.LongTensor = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        world_size: int = 1,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite implementation of Attention.forward.
+
+        Add continuous batching support. Add paged attention support.
+        """
+        context = self.context.context
+        q_start_loc = context.q_start_loc
+        q_seq_length = context.q_seq_length
+        kv_seq_length = context.kv_seq_length
+        block_offsets = context.block_offsets
+        max_q_seq_length = context.max_q_seq_length
+        num_heads = self.config.num_attention_heads // world_size
+        num_kv_heads = getattr(self.config, 'num_multi_query_heads',
+                               self.config.num_attention_heads) // world_size
+        head_dim = self.config.hidden_size // self.config.num_attention_heads
+        hidden_size = num_heads * head_dim
+        only_has_language = context.is_decoding
+        if not context.is_decoding:
+            # for embedding splitting
+            if hasattr(context, 'vision_token_mask') and hasattr(
+                    context, 'language_token_mask'):
+                vision_token_mask = context.vision_token_mask
+                language_token_mask = context.language_token_mask
+                only_has_language = vision_token_mask.numel() == 0
+            else:
+                only_has_language = True
+
+        def __qkv_proj(hidden_states):
+            """qkv_proj."""
+            if only_has_language:
+                mixed_raw_layer = self.language_expert_query_key_value(
+                    hidden_states)
+            else:
+                shape = list(hidden_states.shape)
+                shape[-1] = hidden_size + head_dim * num_kv_heads * 2
+                mixed_raw_layer = torch.empty(shape,
+                                              dtype=hidden_states.dtype,
+                                              device=hidden_states.device)
+                mixed_raw_layer[:, vision_token_mask, :] = \
+                    self.vision_expert_query_key_value(
+                        hidden_states[:, vision_token_mask, :])
+                mixed_raw_layer[:, language_token_mask, :] = \
+                    self.language_expert_query_key_value(
+                        hidden_states[:, language_token_mask, :])
+            query_states, key_states, value_states = torch.split(
+                mixed_raw_layer, [
+                    hidden_size, head_dim * num_kv_heads,
+                    head_dim * num_kv_heads
+                ],
+                dim=-1)
+            return query_states, key_states, value_states
+
+        def __rotary_emb_fn(query_states, key_states, value_states):
+            """rotary embedding func."""
+            scaling_factor = getattr(self.rotary_emb, 'scaling_factor', 1.0)
+            inv_freq = self.rotary_emb.inv_freq
+            query_states, key_states = fused_rotary_emb_ascend(
+                query_states[None],
+                key_states[None],
+                position_ids[None],
+                inv_freq=inv_freq,
+                scaling_factor=scaling_factor,
+                out_q=query_states[None],
+                out_k=key_states[None],
+                context=context)
+            return query_states[0], key_states[0], value_states
+
+        query_states, key_states, value_states = __qkv_proj(hidden_states)
+
+        query_states = query_states.view(-1, num_heads, head_dim)
+        key_states = key_states.view(-1, num_kv_heads, head_dim)
+        value_states = value_states.view(-1, num_kv_heads, head_dim)
+
+        query_states, key_states, value_states = __rotary_emb_fn(
+            query_states, key_states, value_states)
+
+        fill_kv_cache_ascend(
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            q_start_loc,
+            q_seq_length,
+            kv_seq_length=kv_seq_length,
+            max_q_seq_length=max_q_seq_length,
+            block_offsets=block_offsets,
+            context=context)
+
+        context_layer = query_states
+        paged_attention_fwd_ascend(
+            query_states,
+            key_states,
+            value_states,
+            past_key_value[0],
+            past_key_value[1],
+            context_layer,
+            block_offsets,
+            q_start_loc=q_start_loc,
+            q_seqlens=q_seq_length,
+            kv_seqlens=kv_seq_length,
+            max_seqlen=max_q_seq_length,
+            context=context)
+        context_layer = context_layer.reshape(*hidden_states.shape[:-1], -1)
+
+        if only_has_language:
+            attn_output = self.language_expert_dense(context_layer)
+        else:
+            ctx_shape = list(context_layer.shape)
+            ctx_shape[-1] *= world_size
+            attn_output = torch.empty(ctx_shape,
+                                      dtype=hidden_states.dtype,
+                                      device=hidden_states.device)
+            attn_output[:, vision_token_mask, :] = self.vision_expert_dense(
+                context_layer[:, vision_token_mask, :])
+            attn_output[:, language_token_mask, :] = \
+                self.language_expert_dense(
+                    context_layer[:, language_token_mask, :])
+
+        return attn_output, None, past_key_value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        """Rewrite of forward."""
+        world_size = 1
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+        return self._contiguous_batching_forward_impl(
+            hidden_states,
+            position_ids=position_ids,
+            past_key_value=past_key_value,
+            world_size=world_size,
+        )
+
+
 class PatchedCogVLMModel(nn.Module):
 
     def forward(
diff --git a/lmdeploy/pytorch/models/module_map.py b/lmdeploy/pytorch/models/module_map.py
index e0f49715b6..82ea59b600 100644
--- a/lmdeploy/pytorch/models/module_map.py
+++ b/lmdeploy/pytorch/models/module_map.py
@@ -391,6 +391,12 @@
     f'{LMDEPLOY_PYTORCH_MODEL_PATH}.internlm2.PatchedInternLM2AttentionAscend',
 })
 
+# ascend cogvlm
+ASCEND_MODULE_MAP.update({
+    'modeling_cogvlm.VisionExpertAttention':
+    f'{LMDEPLOY_PYTORCH_MODEL_PATH}.cogvlm.PatchedVisionExpertAttentionAscend',
+})
+
 # ascend mixtral
 ASCEND_MODULE_MAP.update({
     'transformers.models.mixtral.modeling_mixtral.MixtralAttention':
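
Note on the continuous-batching layout the patched forward relies on: all requests in a step are flattened into a single token dimension, and `q_start_loc`/`q_seq_length` describe each request's slice of it. A minimal sketch of that indexing (the sequence lengths and hidden size are made up for illustration; only the field names mirror the context object in the patch):

import torch

# Hypothetical step: three requests with 4, 2 and 3 new tokens each.
q_seq_length = torch.tensor([4, 2, 3])
# Start offset of each request inside the flattened token dimension.
q_start_loc = torch.cumsum(q_seq_length, dim=0) - q_seq_length  # [0, 4, 6]
hidden_size = 32  # made-up hidden size for the sketch
# hidden_states as the patched attention sees it: (1, total_tokens, hidden)
hidden_states = torch.randn(1, int(q_seq_length.sum()), hidden_size)

for loc, length in zip(q_start_loc.tolist(), q_seq_length.tolist()):
    per_request = hidden_states[:, loc:loc + length, :]
    print(per_request.shape)  # (1, length, hidden_size)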
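The vision/language expert split in `__qkv_proj` and in the output projection follows a mask-scatter pattern: every token is routed through exactly one expert projection and written back into a shared buffer. A toy reproduction of that pattern, with stand-in `nn.Linear` experts and made-up shapes:

import torch
from torch import nn

# Stand-ins for the two expert projections (the sizes are arbitrary).
vision_proj = nn.Linear(32, 48)
language_proj = nn.Linear(32, 48)

hidden_states = torch.randn(1, 9, 32)
# Boolean masks over the token dimension; each token is exactly one kind.
vision_token_mask = torch.zeros(9, dtype=torch.bool)
vision_token_mask[1:4] = True
language_token_mask = ~vision_token_mask

# Scatter each expert's output back into one shared buffer, as in
# __qkv_proj and the dense projection of the patched attention.
out = torch.empty(1, 9, 48)
out[:, vision_token_mask, :] = vision_proj(
    hidden_states[:, vision_token_mask, :])
out[:, language_token_mask, :] = language_proj(
    hidden_states[:, language_token_mask, :])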
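The `ASCEND_MODULE_MAP` entry maps the original class path (`modeling_cogvlm.VisionExpertAttention`) to the dotted path of its Ascend replacement, which the patcher resolves into a class at load time. Roughly how such a dotted path resolves (a simplified sketch, not lmdeploy's actual patch machinery):

import importlib


def resolve_class(path: str):
    """Split 'pkg.module.ClassName' and import the class it names."""
    module_name, _, class_name = path.rpartition('.')
    return getattr(importlib.import_module(module_name), class_name)


# e.g. resolve_class('lmdeploy.pytorch.models.cogvlm'
#                    '.PatchedVisionExpertAttentionAscend')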