diff --git a/examples/afd/ffn.sh b/examples/afd/ffn.sh
new file mode 100644
index 000000000000..a0c03f4f78b9
--- /dev/null
+++ b/examples/afd/ffn.sh
@@ -0,0 +1,4 @@
+#export NCCL_SOCKET_IFNAME=eno1
+#export GLOO_SOCKET_IFNAME=eno1
+
+python fserve.py --model="/data2/models/deepseek-v2-lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --additional-config='{"role":"ffn", "afd_size":"2a2f"}'
diff --git a/examples/afd/fserve.py b/examples/afd/fserve.py
new file mode 100644
index 000000000000..4d0730b16ba5
--- /dev/null
+++ b/examples/afd/fserve.py
@@ -0,0 +1,63 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Standalone launcher for the FFN-side AFD workers."""
+
+import re
+
+import torch.multiprocessing as mp
+
+from vllm.engine.arg_utils import EngineArgs
+from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
+from vllm.entrypoints.utils import cli_env_setup
+from vllm.utils import (
+    FlexibleArgumentParser,
+    get_distributed_init_method,
+    get_ip,
+    get_open_port,
+)
+from vllm.v1.worker.gpu_worker import AFDWorker
+
+
+def create_worker(
+    vllm_config,
+    rank,
+    distributed_init_method,
+    is_driver_worker: bool = True,
+):
+    worker = AFDWorker(
+        vllm_config=vllm_config,
+        local_rank=rank,
+        rank=rank,
+        distributed_init_method=distributed_init_method,
+        is_driver_worker=is_driver_worker,
+    )
+
+    worker.init_device()
+    worker.load_model()
+    print("ffn worker instantiated")
+    worker.model_runner.execute_model()
+
+
+if __name__ == "__main__":
+    cli_env_setup()
+    mp.set_start_method("spawn")
+    parser = FlexibleArgumentParser(description="vLLM AFD FFN server.")
+    parser = make_arg_parser(parser)
+    args = parser.parse_args()
+    validate_parsed_serve_args(args)
+    engine_args = EngineArgs.from_cli_args(args)
+    vllm_config = engine_args.create_engine_config()
+    afd_size = vllm_config.additional_config.get("afd_size")
+    if afd_size is None or afd_size == "":
+        raise ValueError("afd_size must be specified in additional_config")
+
+    attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
+    distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())
+
+    processes = []
+    for rank in range(ffn_size):
+        p = mp.Process(
+            target=create_worker, args=(vllm_config, rank, distributed_init_method)
+        )
+        processes.append(p)
+        p.start()
diff --git a/examples/afd/offline_attn.py b/examples/afd/offline_attn.py
new file mode 100644
index 000000000000..d424383fec80
--- /dev/null
+++ b/examples/afd/offline_attn.py
@@ -0,0 +1,21 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from vllm import LLM, SamplingParams
+
+prompts = [
+    "1 3 5 7 9",
+]
+
+sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
+llm = LLM(
+    model="/data2/models/deepseek-v2-lite",
+    enforce_eager=True,
+    additional_config={"role": "attn"},
+)
+
+outputs = llm.generate(prompts, sampling_params)
+
+for output in outputs:
+    prompt = output.prompt
+    generated_text = output.outputs[0].text
+    print(f"prompt: {prompt!r}, generated text: {generated_text!r}")
diff --git a/examples/afd/online_attn.sh b/examples/afd/online_attn.sh
new file mode 100644
index 000000000000..0a3c39dea6dd
--- /dev/null
+++ b/examples/afd/online_attn.sh
@@ -0,0 +1,4 @@
+#export NCCL_SOCKET_IFNAME=eno1
+#export GLOO_SOCKET_IFNAME=eno1
+
+vllm serve /data2/models/deepseek-v2-lite --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager 
--additional-config='{"role":"attn", "afd_size":"2a2f"}'
diff --git a/examples/afd/readme.md b/examples/afd/readme.md
new file mode 100644
index 000000000000..3712d3fe94a9
--- /dev/null
+++ b/examples/afd/readme.md
@@ -0,0 +1,120 @@
+## AFD Demo Readme
+
+This demo shows how to decouple the Attention layers of a Transformer model from its FFN (MoE) layers and deploy them in separate processes, or even on separate machines, for disaggregated inference.
+
+---
+
+### Environment Setup
+
+#### 1. Clone the repository and check out the branch
+```bash
+git clone https://github.com/hsliuustc0106/vllm.git
+cd vllm
+git fetch origin pull/12/head:afd-demo
+git checkout afd-demo
+```
+
+#### 2. Install dependencies
+```bash
+pip install -r requirements.txt
+pip install -e .
+```
+
+### Launch Steps
+
+#### Step 1: Start the FFN service (MoE layers)
+
+Taking the 2A2F configuration as an example, run the following command to start the FFN service, which performs the MoE layer computation:
+
+```bash
+export NCCL_SOCKET_IFNAME=eno1  # NIC used by NCCL and GLOO; required when running across machines
+export GLOO_SOCKET_IFNAME=eno1
+
+export MASTER_IP=    # master node IP and port; required when running across machines
+export MASTER_PORT=
+
+export CUDA_VISIBLE_DEVICES=0,1
+python fserve.py --model="/home/models/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --additional-config='{"role":"ffn", "afd_size":"2A2F"}'
+```
+
+> Notes:
+- `role` specifies the role of this process.
+- `afd_size` gives the number of GPUs used by the attn and ffn sides, in the `xAyF` format (see the parsing sketch below).
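+
+The snippet below only illustrates how the demo interprets `afd_size`; it mirrors the regex used in `fserve.py` and `AFDWorker`, but `parse_afd_size` itself is not part of the demo code:
+
+```python
+import re
+
+def parse_afd_size(afd_size: str) -> tuple[int, int]:
+    """Parse 'xAyF' (e.g. '2a2f' or '2A2F') into (attn_size, ffn_size)."""
+    match = re.match(r"(\d+)\D+(\d+)", afd_size)
+    if match is None:
+        raise ValueError(f"afd_size must look like 'xAyF', got {afd_size!r}")
+    attn_size, ffn_size = map(int, match.groups())
+    return attn_size, ffn_size
+
+attn_size, ffn_size = parse_afd_size("2A2F")
+print(attn_size, ffn_size)  # 2 2
+# In the shared AFD process group, attn workers take ranks 0..attn_size-1
+# and ffn workers take ranks attn_size..attn_size+ffn_size-1.
+```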
+
+
+---
+
+#### Step 2: Start the Attention service (online_attn.sh)
+
+Start the online Attention service so that it can talk to the FFN service:
+```bash
+#!/bin/bash
+export NCCL_SOCKET_IFNAME=eno1  # NIC used by NCCL and GLOO; required when running across machines
+export GLOO_SOCKET_IFNAME=eno1
+
+export MASTER_IP=    # master node IP and port; required when running across machines
+export MASTER_PORT=
+
+export CUDA_VISIBLE_DEVICES=0,1
+vllm serve /data2/models/deepseek-v2-lite --enforce_eager --additional-config='{"role":"attn", "afd_size":"2A2F"}'
+
+```
+> Notes:
+- `role` specifies the role of this process.
+- This service sends the attention output to the FFN service through the `afd_connector` and receives the FFN result back.
+- Make sure `fserve.py` is already running.
+
+
+---
+
+### Flow Overview
+
+```text
+Input Prompt
+    ↓
+online_attn.sh (Attention service)
+    ↓
+Attention Layer Output
+    ↓
+AFD_CONNECTOR.send_attn_output()
+    ↓
+fserve.py (FFN service)
+    ↓
+MoE Layer Output
+    ↓
+AFD_CONNECTOR.recv_ffn_output()
+    ↓
+Final Output (online_attn.sh)
+```
+
+---
+
+### Verifying the Setup
+
+#### Check the logs
+The services have started successfully once the log shows:
+```plain
+(APIServer pid=73628) INFO:     Started server process [73628]
+(APIServer pid=73628) INFO:     Waiting for application startup.
+(APIServer pid=73628) INFO:     Application startup complete.
+
+```
+
+#### Send a test request (online mode)
+
+Query the server with curl (or a browser):
+```bash
+curl -v http://0.0.0.0:8000/v1/chat/completions \
+-H 'Content-Type: application/json' \
+-d \
+'{ "model": "/data2/models/deepseek-v2-lite",
+"messages": [
+{"role": "user", "content": "1 3 5 7 9"} ],
+"temperature": 0.6,
+"repetition_penalty": 1.0,
+"top_p": 0.95,
+"top_k": 40,
+"max_tokens": 20,
+"stream": false}'
+```
diff --git a/examples/afd/request.sh b/examples/afd/request.sh
new file mode 100644
index 000000000000..f0adf6fcff58
--- /dev/null
+++ b/examples/afd/request.sh
@@ -0,0 +1,12 @@
+curl -v http://0.0.0.0:8000/v1/chat/completions \
+-H 'Content-Type: application/json' \
+-d \
+'{ "model": "/data2/models/deepseek-v2-lite",
+"messages": [
+{"role": "user", "content": "1 3 5 7 9"} ],
+"temperature": 0.6,
+"repetition_penalty": 1.0,
+"top_p": 0.95,
+"top_k": 40,
+"max_tokens": 20,
+"stream": false}'
diff --git a/vllm/distributed/afd/afd_connector.py b/vllm/distributed/afd/afd_connector.py
new file mode 100644
index 000000000000..860b845c9e56
--- /dev/null
+++ b/vllm/distributed/afd/afd_connector.py
@@ -0,0 +1,89 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from typing import Optional
+
+import torch
+from torch.distributed import ProcessGroup
+
+from vllm.sequence import IntermediateTensors
+
+
+@dataclass
+class AFDConnectorMetadata:
+    layer_idx: int  # Layer index for computation
+    stage_idx: int  # Pipeline stage index
+    seq_lens: list[int]  # Sequence lengths for each request
+    dtype: torch.dtype  # Tensor data type
+    device: torch.device  # Compute device
+    request_id: Optional[str]  # Request identifier
+    timestamp: Optional[float]  # Timestamp for debugging
+    group: ProcessGroup  # Communication domain
+    topk_idx: Optional[torch.Tensor]  # Indices of the experts each token is routed to
+    topk_weights: Optional[torch.Tensor]  # Routing weights of the selected experts
+    moe_expert_num: Optional[int]  # Number of routed MoE experts
+    shared_expert_num: Optional[int]  # Number of shared experts
+    handle: Optional[
+        torch.Tensor
+    ]  # Communication handle returned by recv_attn_output
+
+
+class AFDConnectorBase(ABC):
+    def __init__(self, process_group) -> None:
+        super().__init__()
+        self.process_group = process_group
+
+    # -------------------------------------------------------------------
+    # attn -> ffn
+    # -------------------------------------------------------------------
+    @abstractmethod
+    def send_attn_output(
+        self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
+    ):
+        """
+        This method will be called by the ATTN side.
+
+
+        * To send the intermediate tensors generated by ATTN instances to FFN.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def recv_attn_output(self) -> torch.Tensor:
+        """
+        This method will be called by the FFN side.
+
+
+        * To receive the intermediate tensors from ATTN.
+        * And (Maybe) dispatch them from the receiver to other GPUs.
+        """
+        raise NotImplementedError
+
+    # -------------------------------------------------------------------------
+    # attn <- ffn
+    # -------------------------------------------------------------------------
+    @abstractmethod
+    def send_ffn_output(
+        self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
+    ):
+        """
+        This method will be called by the FFN side. 
+ + + * To send the intermediate tensors generated by FFN instances back to + the sender (this should be the same GPU as it comes from) + """ + raise NotImplementedError + + @abstractmethod + def recv_ffn_output(self) -> torch.Tensor: + """ + This method will be called by the ATTN side. + + + * To receive the MOE output intermediate tensors. + * And (Maybe) dispatch them from the receiver to other GPUs. + (this should be the same GPU as it comes from) + """ + raise NotImplementedError diff --git a/vllm/distributed/afd/p2p_connector.py b/vllm/distributed/afd/p2p_connector.py new file mode 100644 index 000000000000..60e4f6d4d915 --- /dev/null +++ b/vllm/distributed/afd/p2p_connector.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch + + +from vllm.distributed.afd.afd_connector import ( + AFDConnectorBase, + AFDConnectorMetadata, +) +from vllm.sequence import IntermediateTensors + + +class P2PConnector(AFDConnectorBase): + def __init__(self, process_group) -> None: + super().__init__(process_group) + self.process_group = process_group + + def send_attn_output( + self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata + ): + """ + This method will be called by the ATTN side. + + + * To send the intermediate tensors generated by ATTN instances to FFN. + """ + + intermediate_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) + try: + self.process_group.send_tensor_dict( + intermediate_tensors.tensors, + all_gather_group=None, + ) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_attn_output(self) -> IntermediateTensors: + """ + This method will be called by the FFN side. + + + * To receive the intermediate tensors from ATTN. + * And (Maybe) dispatch them from the receiver to other GPUs. + """ + intermediate_tensors = self.process_group.recv_tensor_dict( + all_gather_group=None, + ) + return intermediate_tensors["hidden_states"] + + # ------------------------------------------------------------------------- + # attn <- ffn + # ------------------------------------------------------------------------- + def send_ffn_output( + self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata + ): + """ + This method will be called by the FFN side. + + + * To send the intermediate tensors generated by FFN instances back to + the sender (this should be the same GPU as it comes from) + """ + intermediate_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) + self.process_group.send_tensor_dict( + intermediate_tensors.tensors, + ) + + def recv_ffn_output(self) -> torch.Tensor: + """ + This method will be called by the ATTN side. + + + * To receive the MOE output intermediate tensors. + * And (Maybe) dispatch them from the receiver to other GPUs. 
+ (this should be the same GPU as it comes from) + """ + intermediate_tensors = self.process_group.recv_tensor_dict( + all_gather_group=None, + ) + return intermediate_tensors["hidden_states"] diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index b89aee99c8d4..1dc9010198b7 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -29,6 +29,7 @@ from collections import namedtuple from contextlib import contextmanager, nullcontext from dataclasses import dataclass +from datetime import timedelta from multiprocessing import shared_memory from typing import Any, Callable, Optional, Union from unittest.mock import patch @@ -36,9 +37,14 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from torch.distributed.distributed_c10d import (PrefixStore, Store, + _new_process_group_helper, + _update_default_pg, _world, + default_pg_timeout, rendezvous) from typing_extensions import deprecated import vllm.envs as envs +from vllm.distributed.afd.afd_connector import AFDConnectorBase from vllm.distributed.device_communicators.base_device_communicator import ( DeviceCommunicatorBase) from vllm.distributed.utils import StatelessProcessGroup @@ -871,6 +877,68 @@ def init_world_group(ranks: list[int], local_rank: int, ) +def init_afd_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: Optional[Store] = None, + group_name: str = None, + pg_options: Optional[Any] = None, +): + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store.") + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + if store is None: + rendezvous_iterator = rendezvous(init_method, + rank, + world_size, + timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + store = PrefixStore(group_name, store) + + pg_options_param_name = ("backend_options" if str(torch.__version__) + >= "2.6" else "pg_options") + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg + + +class DefaultProcessGroupSwitcher: + + def __init__(self, default_group, new_default_group): + self.default_group = default_group + self.new_default_group = new_default_group + + def __enter__(self): + _update_default_pg(self.new_default_group) + + def __exit__(self, exc_type, exc_value, traceback): + _update_default_pg(self.default_group) + + def init_model_parallel_group( group_ranks: list[list[int]], local_rank: int, @@ -916,12 +984,19 @@ def get_dp_group() -> GroupCoordinator: _EP: Optional[GroupCoordinator] = None +_AFD_CONNECTOR: Optional[AFDConnectorBase] = None + def get_ep_group() -> GroupCoordinator: assert _EP is not None, ("expert parallel group is not initialized") return _EP +def get_afd_connector() -> AFDConnectorBase: + assert _AFD_CONNECTOR is not None, ("afd is not initialized") + return _AFD_CONNECTOR + + def get_pp_group() -> GroupCoordinator: assert _PP is not None, ( 
"pipeline model parallel group is not initialized") @@ -973,7 +1048,7 @@ def init_distributed_environment( local_rank: int = -1, backend: str = "nccl", ): - logger.debug( + logger.info( "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s", world_size, rank, local_rank, distributed_init_method, backend) diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 2b8e4427591c..9b3b1f0e3073 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -267,6 +267,6 @@ def load_weights(self, model: nn.Module, # that have loaded weights tracking currently. if model_config.quantization is None and loaded_weights is not None: weights_not_loaded = weights_to_load - loaded_weights - if weights_not_loaded: - raise ValueError("Following weights were not initialized from " - f"checkpoint: {weights_not_loaded}") + # if weights_not_loaded: + # raise ValueError("Following weights were not initialized from " + # f"checkpoint: {weights_not_loaded}") diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index f199da135ec7..7627dfbeeeab 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -35,8 +35,9 @@ from vllm.compilation.decorators import support_torch_compile from vllm.config import (CacheConfig, ModelConfig, VllmConfig, get_current_vllm_config) -from vllm.distributed import (get_ep_group, get_pp_group, +from vllm.distributed import (get_afd_connector, get_ep_group, get_pp_group, get_tensor_model_parallel_world_size) +from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -60,6 +61,8 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +#logger = init_logger(__name__) + class DeepseekV2MLP(nn.Module): @@ -536,21 +539,24 @@ def forward( class DeepseekV2DecoderLayer(nn.Module): - def __init__( - self, - config: Union[DeepseekV2Config, DeepseekV3Config], - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - enable_eplb: bool = False, - ) -> None: + def __init__(self, + config: Union[DeepseekV2Config, DeepseekV3Config], + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + enable_eplb: bool = False, + role: str = None) -> None: super().__init__() + #logger.info("*" * 50) + #logger.info("decoder init") + self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.role = role # DecoderLayers are created with `make_layers` which passes the prefix # with the layer's index. 
layer_idx = int(prefix.split(sep='.')[-1]) @@ -559,53 +565,72 @@ def __init__( attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV2MoE( + if self.role is None or role == "attn": + self.self_attn = attn_cls( config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank if hasattr( + config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.mlp", - enable_eplb=enable_eplb, - ) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=f"{prefix}.self_attn", ) + if self.role is None or role == "ffn": + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + enable_eplb=enable_eplb, + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.routed_scaling_factor = config.routed_scaling_factor + def forward_ffn(self): + assert self.role == "ffn" + #logger.info(f"ffn decoder layer {self.layer_idx} forwarding") + afd_connector = get_afd_connector() + hidden_states = afd_connector.recv_attn_output() + hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + + afd_connector.send_ffn_output(hidden_states, None) + def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: Optional[torch.Tensor], ) -> torch.Tensor: + # Self Attention if residual is None: residual = hidden_states @@ -613,34 +638,42 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - - if hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # We scale both hidden_states and residual before - # rmsnorm, and rmsnorm result would not affect by scale. - hidden_states *= 1. / self.routed_scaling_factor - if self.layer_idx == 0: - # The residual is shared by all layers, we only scale it on - # first layer. - residual *= 1. / self.routed_scaling_factor - - # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) - hidden_states = self.mlp(hidden_states) + if self.role is not None: # This statement should make sense no matter AE is on/off + if self.role == "attn": + #logger.info(f"attn decoder {self.layer_idx} forwarding") + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + afd_connector = get_afd_connector() + afd_connector.send_attn_output(hidden_states, None) + hidden_states = afd_connector.recv_ffn_output() - if isinstance(self.mlp, - DeepseekV2MLP) and hidden_states.dtype == torch.float16: - # Fix FP16 overflow - # Scaling the DeepseekV2MLP output, it is the input of - # input_layernorm of next decoder layer. - # The scaling of DeepseekV2MOE output would be done in the forward - # of DeepseekV2MOE - hidden_states *= 1. / self.routed_scaling_factor + else: + hidden_states = self.mlp(hidden_states) + #logger.info("ffn forwarding") + if isinstance(self.mlp, DeepseekV2MLP + ) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor return hidden_states, residual @@ -661,6 +694,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.vocab_size = config.vocab_size + self.role = vllm_config.additional_config.get("role", None) if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -673,14 +707,13 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer( - config, - prefix, - model_config=model_config, - cache_config=cache_config, - quant_config=quant_config, - enable_eplb=enable_eplb, - ), + lambda prefix: DeepseekV2DecoderLayer(config, + prefix, + model_config=model_config, + cache_config=cache_config, + quant_config=quant_config, + enable_eplb=enable_eplb, + role=self.role), prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: @@ -736,6 +769,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config + self.role = vllm_config.additional_config.get("role", None) # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to @@ -774,11 +808,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): continue assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): + if (self.role is None or self.role == "ffn") and isinstance( + layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. example_moe = layer.mlp self.moe_layers.append(layer.mlp.experts) + if self.role == "attn": + return + if example_moe is None: raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") @@ -826,6 +864,12 @@ def update_physical_experts_metadata( def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.get_input_embeddings(input_ids) + def forward_ffn(self): + assert self.role == "ffn" + #logger.info("forwarding ffn") + for layer in self.model.layers[:]: + layer.forward_ffn() + def forward( self, input_ids: torch.Tensor, @@ -858,12 +902,14 @@ def load_weights(self, weights: Iterable[tuple[str, # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + vllm_config = get_current_vllm_config() expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=vllm_config.parallel_config. 
+ num_redundant_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() @@ -871,6 +917,9 @@ def load_weights(self, weights: Iterable[tuple[str, if "rotary_emb.inv_freq" in name: continue + if self.role == "attn" and self.is_moe(name): + continue + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue # skip spec decode layers for main model @@ -918,6 +967,8 @@ def load_weights(self, weights: Iterable[tuple[str, # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True + if self.role is not None and self.role == "attn": + continue # Do not modify `name` since the loop may continue here # Instead, create a new variable @@ -942,6 +993,9 @@ def load_weights(self, weights: Iterable[tuple[str, name = name_mapped break else: + if self.role == "ffn" and not self.is_moe( + name) and not self.is_common(name): + continue if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank @@ -968,6 +1022,19 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params + def is_moe(self, name): + if "shared_experts" in name or "experts" in name or "gate" in name \ + or "up" in name or "down" in name: + return True + return False + + def is_common(self, name): + if "lm_head" in name or "model.norm.weight" in name or "embed_tokens" in name \ + or "input_layernorm" in name or "post_attention_layernorm" in name: + # or "model.layers.0.self_attn.o_proj.weight" in name:# for init kv cache + return True + return False + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8fb9641844fb..fe798bd2a036 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3149,3 +3149,27 @@ def _build_encoder_only_attn_metadata( group_metadata[layer_name] = (common_metadata, metadata) return group_metadata + + +class FFNModelRunner(GPUModelRunner): + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + + super().__init__(vllm_config=vllm_config, device=device) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + + def execute_model(self): + print('ffn forward begain') + # TODO: use event replace + with set_forward_context( + None, + self.vllm_config, + #num_tokens=num_input_tokens, + #num_tokens_across_dp=num_tokens_across_dp, + #skip_cuda_graphs=skip_cuda_graphs, + ): + while True: + layers_num = len(self.model.model.layers) + for i in range(layers_num): + self.model.model.layers[i].forward_ffn() diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 84f065f25f2e..dbdcafb8e549 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -4,20 +4,27 @@ import copy import gc import os +import re from contextlib import AbstractContextManager, nullcontext +from datetime import timedelta from typing import TYPE_CHECKING, Any, Optional import torch import torch.distributed import torch.nn as nn +from torch.distributed.distributed_c10d import _get_default_group +import vllm.distributed.parallel_state as ps import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, + init_afd_process_group, init_distributed_environment, - set_custom_all_reduce) + init_model_parallel_group, set_custom_all_reduce) +from vllm.distributed.afd.p2p_connector import P2PConnector 
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized -from vllm.distributed.parallel_state import get_pp_group, get_tp_group +from vllm.distributed.parallel_state import (DefaultProcessGroupSwitcher, + get_pp_group, get_tp_group) from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor import set_random_seed @@ -30,7 +37,7 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput from vllm.v1.utils import report_usage_stats -from vllm.v1.worker.gpu_model_runner import GPUModelRunner +from vllm.v1.worker.gpu_model_runner import FFNModelRunner, GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -56,7 +63,6 @@ def __init__( rank=rank, distributed_init_method=distributed_init_method, is_driver_worker=is_driver_worker) - if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -188,6 +194,7 @@ def init_device(self): else: raise RuntimeError( f"Not support device type: {self.device_config.device}") + # Initialize the distributed environment. init_worker_distributed_environment(self.vllm_config, self.rank, self.distributed_init_method, @@ -197,8 +204,13 @@ def init_device(self): set_random_seed(self.model_config.seed) # Construct the model runner - self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device) + self.model_runner: FFNModelRunner | GPUModelRunner + if self.vllm_config.additional_config.get("role") == "ffn": + self.model_runner = FFNModelRunner( + self.vllm_config, self.device) + else: + self.model_runner = GPUModelRunner( + self.vllm_config, self.device) if self.rank == 0: # If usage stat is enabled, collect relevant info. 
@@ -364,6 +376,9 @@ def execute_model(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))
 
+        logger.info("-" * 50)
+        logger.info(f"scheduler_output: {scheduler_output}")
+        logger.info(f"intermediate_tensors: {intermediate_tensors}")
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
 
@@ -597,6 +612,57 @@ def save_tensorized_model(
             tensorizer_config=tensorizer_config, )
 
 
+class AFDWorker(Worker):
+
+    def init_device(self):
+        super().init_device()
+
+        role = self.vllm_config.additional_config.get("role", None)
+        logger.info("AFD worker building")
+
+        afd_size = self.vllm_config.additional_config.get("afd_size")
+        attn_size, ffn_size = map(
+            int,
+            re.match(r"(\d+)\D+(\d+)", afd_size).groups())
+        # attn workers occupy global ranks [0, attn_size) and ffn workers
+        # occupy [attn_size, attn_size + ffn_size).
+        attn_ranks = [i for i in range(attn_size)]
+        ffn_ranks = [i for i in range(attn_size, attn_size + ffn_size)]
+        world_rank = self.rank if role == "attn" else self.rank + attn_size
+        logger.info(
+            f"world_size = {ffn_size + attn_size}, world_rank = {world_rank}")
+
+        # os.environ[...] raises KeyError when these are not exported, and the
+        # readme exports empty values for single-node runs, so fall back to
+        # localhost defaults here.
+        ip = os.environ.get("MASTER_IP") or "127.0.0.1"
+        port = os.environ.get("MASTER_PORT") or "29500"
+        afd_pg = init_afd_process_group(
+            backend="nccl",
+            init_method=f"tcp://{ip}:{port}",
+            world_size=ffn_size + attn_size,
+            rank=world_rank,
+            group_name="afd",
+            timeout=timedelta(minutes=2),
+        )
+
+        default_pg_switcher = DefaultProcessGroupSwitcher(
+            _get_default_group(), afd_pg)
+        with default_pg_switcher:
+            # Pair attn rank i with ffn rank i into a two-rank group that the
+            # P2P connector uses for the per-layer exchange.
+            sub_group_ranks = []
+            for i in range(len(ffn_ranks)):
+                ranks = [attn_ranks[i], ffn_ranks[i]]
+                sub_group_ranks.append(ranks)
+            ae_group = init_model_parallel_group(sub_group_ranks,
+                                                 self.rank,
+                                                 backend="nccl",
+                                                 group_name="ae")
+
+            ps._AFD_CONNECTOR = P2PConnector(ae_group)
+
+
 def init_worker_distributed_environment(
     vllm_config: VllmConfig,
     rank: int,
diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py
index f1c9a0ab001e..a0bfd1c83ee0
--- a/vllm/worker/worker_base.py
+++ b/vllm/worker/worker_base.py
@@ -555,6 +555,10 @@ def init_worker(self, all_kwargs: List[Dict[str, Any]]) -> None:
         if isinstance(self.vllm_config.parallel_config.worker_cls, str):
             worker_class = resolve_obj_by_qualname(
                 self.vllm_config.parallel_config.worker_cls)
+            # AFD demo: unconditionally use the AFD worker regardless of the
+            # configured worker_cls.
+            worker_class = resolve_obj_by_qualname(
+                "vllm.v1.worker.gpu_worker.AFDWorker")
         else:
             logger.warning(
                 "passing worker_cls as a class object is strongly deprecated,"
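
For reference, the per-layer attn ↔ ffn exchange wired up above reduces to the following loop on the FFN side. This is only a schematic sketch against `AFDConnectorBase`, mirroring `FFNModelRunner.execute_model()` and `DeepseekV2DecoderLayer.forward_ffn()`; `ffn_serving_loop` and its arguments are illustrative names, not code from this patch:

```python
import torch.nn as nn

from vllm.distributed.afd.afd_connector import AFDConnectorBase


def ffn_serving_loop(connector: AFDConnectorBase,
                     mlp_layers: list[nn.Module]) -> None:
    """Schematic FFN-side loop: each layer blocks until the paired attn rank
    sends its post-attention hidden states, runs the MoE/MLP, and sends the
    result straight back over the same two-rank group."""
    while True:
        for mlp in mlp_layers:
            hidden_states = connector.recv_attn_output()    # attn -> ffn
            hidden_states = mlp(hidden_states)               # expert/MLP compute
            connector.send_ffn_output(hidden_states, None)   # ffn -> attn
```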