4 changes: 4 additions & 0 deletions examples/afd/ffn.sh
@@ -0,0 +1,4 @@
#export NCCL_SOCKET_IFNAME=eno1
#export GLOO_SOCKET_IFNAME=eno1

python fserve.py --model="/data2/models/deepseek-v2-lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --additional-config='{"role":"ffn", "afd_size":"2a2f"}'
63 changes: 63 additions & 0 deletions examples/afd/fserve.py
@@ -0,0 +1,63 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""

import re

import torch.multiprocessing as mp

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
from vllm.entrypoints.utils import cli_env_setup
from vllm.utils import (
FlexibleArgumentParser,
get_distributed_init_method,
get_ip,
get_open_port,
)
from vllm.v1.worker.gpu_worker import AFDWorker


def create_worker(
vllm_config,
rank,
distributed_init_method,
is_driver_worker: bool = True,
):
worker = AFDWorker(
vllm_config=vllm_config,
local_rank=rank,
rank=rank,
distributed_init_method=distributed_init_method,
is_driver_worker=is_driver_worker,
)

    worker.init_device()
    worker.load_model()
    print("ffn worker instantiated")
    # Run the FFN side of the model (serves MoE-layer computation for the attention instances).
    worker.model_runner.execute_model()


if __name__ == "__main__":
cli_env_setup()
mp.set_start_method("spawn")
parser = FlexibleArgumentParser(description="vLLM AFD FFN server.")
parser = make_arg_parser(parser)
args = parser.parse_args()
validate_parsed_serve_args(args)
engine_args = EngineArgs.from_cli_args(args)
vllm_config = engine_args.create_engine_config()
afd_size = vllm_config.additional_config.get("afd_size")
if afd_size is None or afd_size == "":
raise ValueError("Afd size must be specified in additional_config")

    # afd_size follows the xAyF format, e.g. "2a2f" -> 2 attn GPUs and 2 ffn GPUs
    attn_size, ffn_size = map(int, re.match(r"(\d+)\D+(\d+)", afd_size).groups())
distributed_init_method = get_distributed_init_method(get_ip(), get_open_port())

processes = []
    for rank in range(ffn_size):  # spawn one FFN worker process per rank
p = mp.Process(
target=create_worker, args=(vllm_config, rank, distributed_init_method)
)
processes.append(p)
p.start()
21 changes: 21 additions & 0 deletions examples/afd/offline_attn.py
@@ -0,0 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm import LLM, SamplingParams
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--temperature", type=float, default=0.8)
parser.add_argument("--top_p", type=float, default=0.95)
args = parser.parse_args()
sampling_params = SamplingParams(temperature=args.temperature, top_p=args.top_p)


prompts = [
"1 3 5 7 9",
]

llm = LLM(
model="/data2/models/deepseek-v2-lite",
enforce_eager=True,
additional_config={"role": "attn"},
)

outputs = llm.generate(prompts, sampling_params)

for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"prompt{prompt!r}, generated text: {generated_text!r}")
4 changes: 4 additions & 0 deletions examples/afd/online_attn.sh
@@ -0,0 +1,4 @@
#export NCCL_SOCKET_IFNAME=eno1
#export GLOO_SOCKET_IFNAME=eno1

vllm serve /data2/models/deepseek-v2-lite --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --additional-config='{"role":"attn", "afd_size":"2a2f"}'
120 changes: 120 additions & 0 deletions examples/afd/readme.md
@@ -0,0 +1,120 @@
## AFD Demo Readme

This demo shows how to decouple the Attention layers from the FFN (MoE) layers of a Transformer model and deploy them in separate processes, or even on separate machines, for distributed inference.

---

### Environment Setup

#### 1. Clone the repository and check out the branch
```bash
git clone https://github.com/hsliuustc0106/vllm.git
cd vllm
git fetch origin pull/12/head:afd-demo
git checkout afd-demo
```

#### 2. Install dependencies
```bash
pip install -r requirements.txt
pip install -e .
```

### Launch Steps

#### Step 1: Start the FFN service (MoE layers)

Using the 2A2F configuration as an example, run the following command to start the FFN service (which handles the MoE-layer computation):

```bash
export NCCL_SOCKET_IFNAME=eno1  # when running across machines, set the NIC that NCCL and GLOO should use
export GLOO_SOCKET_IFNAME=eno1

export MASTER_IP=<master_ip>  # when running across machines, set the master node's IP and port
export MASTER_PORT=<master_port>

export CUDA_VISIBLE_DEVICES=0,1
python fserve.py --model="/home/models/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --additional-config='{"role":"ffn", "afd_size":"2A2F"}'
```

> Notes:
>
> - `role` selects the role of this process.
> - `afd_size` gives the number of GPUs used by the attn side and the ffn side respectively, in the `xAyF` format (see the parsing sketch below).
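
For reference, `fserve.py` splits this `xAyF` string with a regular expression to decide how many FFN worker processes to spawn. A minimal standalone sketch of that parsing (the helper name `parse_afd_size` is only for illustration and is not part of the code in this PR):

```python
import re


def parse_afd_size(afd_size: str) -> tuple[int, int]:
    """Split an 'xAyF' string such as '2A2F' into (attn_gpus, ffn_gpus)."""
    match = re.match(r"(\d+)\D+(\d+)", afd_size)
    if match is None:
        raise ValueError(f"invalid afd_size {afd_size!r}, expected e.g. '2A2F'")
    attn_gpus, ffn_gpus = map(int, match.groups())
    return attn_gpus, ffn_gpus


assert parse_afd_size("2A2F") == (2, 2)
assert parse_afd_size("2a2f") == (2, 2)  # lower case, as used in ffn.sh, also works
```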



---

#### Step 2: Start the Attention service (online_attn.sh)

To communicate with the FFN service, start the online Attention service:
```bash
#!/bin/bash
export NCCL_SOCKET_IFNAME=eno1  # when running across machines, set the NIC that NCCL and GLOO should use
export GLOO_SOCKET_IFNAME=eno1

export MASTER_IP=<master_ip>  # when running across machines, set the master node's IP and port
export MASTER_PORT=<master_port>

export CUDA_VISIBLE_DEVICES=0,1
vllm serve /data2/models/deepseek-v2-lite --enforce_eager --additional-config='{"role":"attn", "afd_size":"2A2F"}'

```
> Notes:
>
> - `role` selects the role of this process.
> - This service sends the Attention output to the FFN service through the `afd_connector` and receives the result back (see the sketch below).
> - Make sure `fserve.py` is already running.
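
The connector interface added under `vllm/distributed/afd/afd_connector.py` is abstract; on the attention side the per-layer hand-off is a send followed by a blocking receive. A hedged sketch of that call order (the helper function and its wiring are illustrative only, not code from this PR):

```python
import torch

from vllm.distributed.afd.afd_connector import AFDConnectorBase, AFDConnectorMetadata


def attn_layer_handoff(
    connector: AFDConnectorBase,
    hidden_states: torch.Tensor,
    metadata: AFDConnectorMetadata,
) -> torch.Tensor:
    """One decoder layer's FFN hand-off on the attention side (illustrative)."""
    # Ship this layer's post-attention hidden states to the FFN service ...
    connector.send_attn_output(hidden_states, metadata)
    # ... and wait for the corresponding MoE output to come back.
    return connector.recv_ffn_output()
```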


---

### Pipeline Overview

```text
Input Prompt
      ↓
online_attn.sh (Attention service)
      ↓
Attention Layer Output
      ↓
AFD_CONNECTOR.send_attn_output()
      ↓
fserve.py (FFN service)
      ↓
MoE Layer Output
      ↓
AFD_CONNECTOR.recv_ffn_output()
      ↓
Final Output (online_attn.sh)
```
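
On the FFN side, the loop entered by `fserve.py` via `worker.model_runner.execute_model()` is the mirror image: receive, run the MoE layer, send back. A hedged sketch of one such iteration (`moe_layer` and the metadata plumbing are assumptions for illustration, not the actual runner code):

```python
from vllm.distributed.afd.afd_connector import AFDConnectorBase, AFDConnectorMetadata


def ffn_serve_step(
    connector: AFDConnectorBase,
    moe_layer,  # callable that computes the MoE/FFN output for one layer (assumed)
    metadata: AFDConnectorMetadata,
) -> None:
    """One receive-compute-send round on the FFN side (illustrative)."""
    # Block until the attention side pushes its hidden states for this layer.
    hidden_states = connector.recv_attn_output()
    # Run the expert computation locally on the FFN GPUs.
    ffn_output = moe_layer(hidden_states)
    # Return the result to the GPU that sent the hidden states.
    connector.send_ffn_output(ffn_output, metadata)
```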

---

### Verifying the Setup

#### Check the log output
The following lines in the log indicate that the services started successfully:
```plain
(APIServer pid=73628) INFO: Started server process [73628]
(APIServer pid=73628) INFO: Waiting for application startup.
(APIServer pid=73628) INFO: Application startup complete.

```

#### Test request (online mode)

Send a request with curl (or from a browser):
```bash
curl -v http://0.0.0.0:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d \
'{ "model": "/data2/models/deepseek-v2-lite",
"messages": [
{"role": "user", "content": "1 3 5 7 9"} ],
"temperature": 0.6,
"repetition_penalty": 1.0,
"top_p": 0.95,
"top_k": 40,
"max_tokens": 20,
"stream": false}'
```
12 changes: 12 additions & 0 deletions examples/afd/request.sh
@@ -0,0 +1,12 @@
curl -v http://0.0.0.0:8000/v1/chat/completions \
-H 'Content-Type: application/json' \
-d \
'{ "model": "/data2/models/deepseek-v2-lite",
"messages": [
{"role": "user", "content": "1 3 5 7 9"} ],
"temperature": 0.6,
"repetition_penalty": 1.0,
"top_p": 0.95,
"top_k": 40,
"max_tokens": 20,
"stream": false}'
89 changes: 89 additions & 0 deletions vllm/distributed/afd/afd_connector.py
@@ -0,0 +1,89 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Optional

import torch
from torch.distributed import ProcessGroup

from vllm.sequence import IntermediateTensors


@dataclass
class AFDConnectorMetadata:
layer_idx: int # Layer index for computation
stage_idx: int # Pipeline stage index
seq_lens: list[int] # Sequence lengths for each request
dtype: torch.dtype # Tensor data type
device: torch.device # Compute device
request_id: Optional[str] # Request identifier
timestamp: Optional[float] # Timestamp for debugging
group: ProcessGroup # communication domain
    topk_idx: Optional[torch.Tensor]  # per-token indices of the experts each token is sent to
    topk_weights: Optional[torch.Tensor]  # routing weights of the selected experts
    moe_expert_num: Optional[int]  # number of MoE experts
    shared_expert_num: Optional[int]  # number of shared experts
handle: Optional[
torch.Tensor
] # the communication handle given by the recv_attn_output function


class AFDConnectorBase(ABC):
    def __init__(self, process_group: ProcessGroup) -> None:
super().__init__()
self.process_group = process_group

# -------------------------------------------------------------------
# attn -> ffn
# -------------------------------------------------------------------
@abstractmethod
def send_attn_output(
self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
):
"""
This method will be called by the ATTN side.


* To send the intermediate tensors generated by ATTN instances to FFN.
"""
raise NotImplementedError

@abstractmethod
def recv_attn_output(self) -> torch.Tensor:
"""
This method will be called by the FFN side.


* To receive the intermediate tensors from ATTN.
* And (Maybe) dispatch them from the receiver to other GPUs.
"""
raise NotImplementedError

# -------------------------------------------------------------------------
# attn <- ffn
# -------------------------------------------------------------------------
@abstractmethod
def send_ffn_output(
self, hidden_states: torch.Tensor, metadata: AFDConnectorMetadata
):
"""
This method will be called by the FFN side.


* To send the intermediate tensors generated by FFN instances back to
the sender (this should be the same GPU as it comes from)
"""
raise NotImplementedError

@abstractmethod
def recv_ffn_output(self) -> torch.Tensor:
"""
This method will be called by the ATTN side.


* To receive the MOE output intermediate tensors.
* And (Maybe) dispatch them from the receiver to other GPUs.
(this should be the same GPU as it comes from)
"""
raise NotImplementedError