From a86f24b96648fee6d9ac71cc80895a69e287a7cb Mon Sep 17 00:00:00 2001 From: hsliu Date: Tue, 30 Sep 2025 11:58:09 +0800 Subject: [PATCH] AFD: update DeepSeek support --- benchmarks/kernels/benchmark_moe.py | 2 +- .../afd/deepseek-v2-lite/readme.md | 16 + examples/online_serving/afd/step3/README.md | 29 ++ vllm/attention/layer.py | 43 +- vllm/config/__init__.py | 75 ++++ vllm/distributed/afd_transfer/__init__.py | 13 + .../afd_transfer/afd_connector/__init__.py | 13 + .../afd_transfer/afd_connector/base.py | 139 ++++++ .../afd_connector/dummy_connector.py | 214 +++++++++ .../afd_transfer/afd_connector/factory.py | 96 ++++ .../afd_transfer/afd_connector/metadata.py | 163 +++++++ .../afd_connector/p2p_connector.py | 205 +++++++++ .../afd_connector/stepmesh_connector.py | 420 ++++++++++++++++++ vllm/distributed/parallel_state.py | 53 +++ vllm/engine/arg_utils.py | 16 +- vllm/entrypoints/afd_ffn_server.py | 91 ++++ vllm/entrypoints/cli/fserver.py | 51 +++ vllm/entrypoints/cli/main.py | 2 + vllm/forward_context.py | 62 ++- ...vice_name=NVIDIA_GB200,dtype=fp8_w8a8.json | 147 ++++++ ...vice_name=NVIDIA_GB200,dtype=fp8_w8a8.json | 146 ++++++ ...vice_name=NVIDIA_GB200,dtype=fp8_w8a8.json | 146 ++++++ .../layers/fused_moe/fused_moe.py | 5 +- vllm/model_executor/models/deepseek_v2.py | 227 ++++++++-- vllm/model_executor/models/step3_text.py | 259 +++++++++-- vllm/model_executor/models/step3_vl.py | 8 + vllm/triton_utils/importing.py | 2 +- vllm/v1/attention/backends/flash_attn.py | 234 +++++++++- vllm/v1/attention/backends/utils.py | 17 +- vllm/v1/worker/gpu_ffn_model_runner.py | 399 +++++++++++++++++ vllm/v1/worker/gpu_model_runner.py | 277 ++++++++++-- vllm/v1/worker/gpu_worker.py | 69 ++- 32 files changed, 3509 insertions(+), 130 deletions(-) create mode 100644 examples/online_serving/afd/deepseek-v2-lite/readme.md create mode 100644 examples/online_serving/afd/step3/README.md create mode 100644 vllm/distributed/afd_transfer/__init__.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/__init__.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/base.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/dummy_connector.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/factory.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/metadata.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/p2p_connector.py create mode 100644 vllm/distributed/afd_transfer/afd_connector/stepmesh_connector.py create mode 100644 vllm/entrypoints/afd_ffn_server.py create mode 100644 vllm/entrypoints/cli/fserver.py create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json create mode 100644 vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json create mode 100644 vllm/v1/worker/gpu_ffn_model_runner.py diff --git a/benchmarks/kernels/benchmark_moe.py b/benchmarks/kernels/benchmark_moe.py index 94f3f1ae11f2..837b2b0c1044 100644 --- a/benchmarks/kernels/benchmark_moe.py +++ b/benchmarks/kernels/benchmark_moe.py @@ -560,7 +560,7 @@ def save_configs( filename = os.path.join(save_dir, filename) print(f"Writing best config to {filename}...") with open(filename, "w") as f: - json.dump(configs, f, indent=4) + json.dump({"triton_version": triton.__version__, **configs}, f, indent=4) f.write("\n") diff --git 
a/examples/online_serving/afd/deepseek-v2-lite/readme.md b/examples/online_serving/afd/deepseek-v2-lite/readme.md new file mode 100644 index 000000000000..3216670b01ee --- /dev/null +++ b/examples/online_serving/afd/deepseek-v2-lite/readme.md @@ -0,0 +1,16 @@ +# P2P Connector +The P2P connector is used for testing the AFD implementation with DeepSeek-V2-Lite models. It uses torch.distributed to send/recv intermediate tensors between the attention and FFN instances. + +1. Attn + +``` +vllm serve "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "afd_role": "attention", "afd_host":"127.0.0.1", "afd_port":"29500","num_afd_stages":"1","afd_extra_config":{"afd_size":"2A2F"}}' + +``` + +2. FFN + +``` +vllm fserver "/path/to/DeepSeek-V2-Lite" --tensor_parallel_size=2 --enable_expert_parallel --enforce_eager --afd-config '{"afd_connector":"p2pconnector", "num_afd_stages":"1", "afd_role": "ffn", "afd_host":"127.0.0.1", "afd_port":"29500", "afd_extra_config":{"afd_size":"2A2F"}}' +``` + diff --git a/examples/online_serving/afd/step3/README.md b/examples/online_serving/afd/step3/README.md new file mode 100644 index 000000000000..881171ec30cd --- /dev/null +++ b/examples/online_serving/afd/step3/README.md @@ -0,0 +1,29 @@ +# Dummy Connector +The dummy connector is used for testing basic functionality. The attention and FFN servers are never actually connected: the dummy connector simply returns locally generated placeholder tensors instead of communicating. + +1. Attn + +``` +vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' +``` + +2. FFN + +``` +vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "dummy", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' +``` + +# StepMesh Connector +The StepMesh connector is used for production deployments. Make sure [StepMesh](https://github.com/stepfun-ai/StepMesh) is installed and that `afd_host` and `afd_port` are set correctly. + +1. Attn + +``` +vllm fserver /path/step3v -dp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "attention", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' +``` + +2. 
FFN + +``` +vllm fserver /path/step3v -tp 8 --afd-config '{"afd_connector": "stepmesh", "afd_role": "ffn", "afd_host": "127.0.0.0"}' --max-num-batched-tokens 384 --max-num-seqs 384 --compilation-config '{"cudagraph_capture_sizes": [1, 8]}' +``` \ No newline at end of file diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 22dc6dcbc8d6..4cfbe55c0c32 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -294,6 +294,13 @@ def forward( attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] + afd_metadata = forward_context.afd_metadata + if afd_metadata is not None: + afd_stage_idx = afd_metadata.afd_stage_idx + if afd_stage_idx < len(attn_metadata): + attn_metadata = attn_metadata[afd_stage_idx] + else: + attn_metadata = None # padding self_kv_cache = self.kv_cache[forward_context.virtual_engine] self.impl.forward(self, query, @@ -312,6 +319,13 @@ def forward( attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[self.layer_name] + afd_metadata = forward_context.afd_metadata + if afd_metadata is not None: + afd_stage_idx = afd_metadata.afd_stage_idx + if afd_stage_idx < len(attn_metadata): + attn_metadata = attn_metadata[afd_stage_idx] + else: + attn_metadata = None # padding self_kv_cache = self.kv_cache[forward_context.virtual_engine] return self.impl.forward(self, query, key, value, self_kv_cache, attn_metadata) @@ -535,8 +549,15 @@ def maybe_save_kv_layer_to_connector( if attn_metadata is None: return assert isinstance(attn_metadata, dict) - connector.save_kv_layer(layer_name, kv_cache_layer, - attn_metadata[layer_name]) + if forward_context.afd_metadata: + afd_stage_idx = forward_context.afd_metadata.afd_stage_idx + if afd_stage_idx < len(attn_metadata[layer_name]): + attn_metadata_to_save = attn_metadata[layer_name][afd_stage_idx] + else: + attn_metadata_to_save = None # padding + else: + attn_metadata_to_save = attn_metadata[layer_name] + connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata_to_save) def unified_attention( @@ -551,6 +572,12 @@ attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[layer_name] + if forward_context.afd_metadata: + afd_stage_idx = forward_context.afd_metadata.afd_stage_idx + if afd_stage_idx < len(attn_metadata): + attn_metadata = attn_metadata[afd_stage_idx] + else: + attn_metadata = None # padding self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine] output = self.impl.forward(self, query, key, value, kv_cache, @@ -590,8 +617,14 @@ def unified_attention_with_output( wait_for_kv_layer_from_connector(layer_name) forward_context: ForwardContext = get_forward_context() attn_metadata = forward_context.attn_metadata if isinstance(attn_metadata, dict): attn_metadata = attn_metadata[layer_name] + if forward_context.afd_metadata: + afd_stage_idx = forward_context.afd_metadata.afd_stage_idx + if afd_stage_idx < len(attn_metadata): + attn_metadata = attn_metadata[afd_stage_idx] + else: + attn_metadata = None # padding self = forward_context.no_compile_layers[layer_name] kv_cache = self.kv_cache[forward_context.virtual_engine]
self.impl.forward(self, diff --git a/vllm/config/__init__.py b/vllm/config/__init__.py index 0847fba878aa..3bb534a0d92c 100644 --- a/vllm/config/__init__.py +++ b/vllm/config/__init__.py @@ -2368,6 +2368,75 @@ def __repr__(self) -> str: return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})" +@config +@dataclass +class AFDConfig: + """Configuration for AFD (Attention FFN Disaggregation) distributed + computation.""" + + afd_connector: str = "dummy" + """The AFD connector for vLLM to communicate between attention and FFN + nodes. Available connectors: 'dummy', 'stepmesh'""" + + afd_role: Literal["attention", "ffn"] = "attention" + """Role of this vLLM instance in AFD. 'attention' for attention workers, + 'ffn' for FFN servers.""" + + afd_port: int = 1239 + """Port number for stepmesh parameter server communication.""" + + afd_host: str = "127.0.0.1" + """Host address for stepmesh parameter server communication.""" + + num_afd_stages: int = 3 + """Number of pipeline stages for stage parallelism.""" + + num_attention_servers: int = 1 + """Number of attention servers.""" + + num_ffn_servers: int = 1 + """Number of FFN servers.""" + + afd_server_rank: int = 0 + """Rank of this AFD server.""" + + afd_extra_config: dict[str, Any] = field(default_factory=dict) + """Extra configuration for specific AFD connectors.""" + + def compute_hash(self) -> str: + """ + WARNING: Whenever a new field is added to this config, + ensure that it is included in the factors list if + it affects the computation graph. + + Provide a hash that uniquely identifies all the configs + that affect the structure of the computation + graph from input ids/embeddings to the final hidden states, + excluding anything before input ids/embeddings and after + the final hidden states. + """ + # AFD configuration affects the computation graph structure + # as it changes how FFN computation is performed + factors: list[Any] = [ + self.afd_connector, + self.afd_role, + self.num_afd_stages, + self.num_attention_servers, + self.num_ffn_servers, + ] + return hashlib.sha256(str(factors).encode()).hexdigest() + + @property + def is_attention_server(self) -> bool: + """Check if this instance is configured as an attention server.""" + return self.afd_role == "attention" + + @property + def is_ffn_server(self) -> bool: + """Check if this instance is configured as an FFN server.""" + return self.afd_role == "ffn" + + @config @dataclass class PoolerConfig: @@ -3015,6 +3084,8 @@ class VllmConfig: """The configurations for distributed KV cache transfer.""" kv_events_config: Optional[KVEventsConfig] = None """The configurations for event publishing.""" + afd_config: Optional[AFDConfig] = None + """AFD (Attention FFN Disaggregation) configuration.""" # some opaque config, only used to provide additional information # for the hash computation, mainly used for testing, debugging or out of # tree config registration. 
@@ -3099,6 +3170,10 @@ def compute_hash(self) -> str: vllm_factors.append(self.kv_transfer_config.compute_hash()) else: vllm_factors.append("None") + if self.afd_config: + vllm_factors.append(self.afd_config.compute_hash()) + else: + vllm_factors.append("None") if self.additional_config: if isinstance(additional_config := self.additional_config, dict): additional_config_hash = hashlib.md5( diff --git a/vllm/distributed/afd_transfer/__init__.py b/vllm/distributed/afd_transfer/__init__.py new file mode 100644 index 000000000000..282a2bce2682 --- /dev/null +++ b/vllm/distributed/afd_transfer/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD (Attention-FFN Disaggregation) transfer components. + +This module provides the distributed infrastructure for AFD, enabling +disaggregated FFN computation across different machines while keeping +attention computation local. +""" + +from .afd_connector import (AFDConnectorBase, AFDConnectorFactory, + AFDConnectorMetadata) + +__all__ = ["AFDConnectorBase", "AFDConnectorMetadata", "AFDConnectorFactory"] diff --git a/vllm/distributed/afd_transfer/afd_connector/__init__.py b/vllm/distributed/afd_transfer/afd_connector/__init__.py new file mode 100644 index 000000000000..56f6e91160af --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/__init__.py @@ -0,0 +1,13 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD connector implementations for different transport backends.""" + +from .base import AFDConnectorBase +from .factory import AFDConnectorFactory +from .metadata import AFDConnectorMetadata + +__all__ = [ + "AFDConnectorBase", + "AFDConnectorFactory", + "AFDConnectorMetadata", +] diff --git a/vllm/distributed/afd_transfer/afd_connector/base.py b/vllm/distributed/afd_transfer/afd_connector/base.py new file mode 100644 index 000000000000..422e61f3e569 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/base.py @@ -0,0 +1,139 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +AFDConnectorBase Class for Distributed AFD FFN computation + +The class provides the four core AFD communication interfaces: +1. send_attn_output(): Send attention output to FFN servers (Attention Worker) +2. recv_ffn_output(): Receive FFN computation result (Attention Worker) +3. recv_attn_output(): Receive attention output from workers (FFN Server) +4. send_ffn_output(): Send FFN computation result back (FFN Server) +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Any, Optional + +import torch + +if TYPE_CHECKING: + from vllm.config import VllmConfig + + from .metadata import AFDConnectorMetadata + + +class AFDConnectorBase(ABC): + """ + Abstract base class for AFD connectors. + + This provides the four core interfaces for AFD communication between + attention workers and FFN servers. + """ + + @abstractmethod + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize the AFD connector. 
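+ Heavier transport setup (process groups, parameter servers) is typically performed in init_afd_connector() rather than here.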
+ + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + """ + raise NotImplementedError + + @abstractmethod + def close(self) -> None: + """Close the connector and release resources.""" + raise NotImplementedError + + @abstractmethod + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + raise NotImplementedError + + @property + @abstractmethod + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. + """ + raise NotImplementedError + + def get_connector_rank(self) -> int: + """Get the rank of this connector.""" + return getattr(self, 'rank', 0) + + def get_connector_local_rank(self) -> int: + """Get the local rank of this connector.""" + return getattr(self, 'local_rank', 0) + + @abstractmethod + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> Any: + """Send attention output to FFN servers. + + Args: + hidden_states: Attention output tensor + metadata: AFD metadata containing layer_idx, stage_idx, seq_len info + + Returns: + Any: Handle for tracking this request (backend-specific) + """ + raise NotImplementedError + + @abstractmethod + def recv_ffn_output( + self, + handle: Any, + ) -> torch.Tensor: + """Wait for and receive FFN computation result. + + Args: + handle: Handle returned by send_attn_output() + + Returns: + torch.Tensor: FFN computation result + """ + raise NotImplementedError + + @abstractmethod + def recv_attn_output( + self, + timeout_ms: Optional[int] = None, + ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]: + """Receive attention output from attention workers. + + Args: + timeout_ms: Optional timeout in milliseconds + + Returns: + tuple: (hidden_states, metadata) + - hidden_states: Concatenated attention outputs + - metadata: Inferred AFD metadata containing + seq_lens and other info + """ + raise NotImplementedError + + @abstractmethod + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> None: + """Send FFN computation result back to attention workers. + + Args: + ffn_output: Computed FFN result + metadata: AFD metadata containing seq_lens + for splitting and routing info + """ + raise NotImplementedError diff --git a/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py b/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py new file mode 100644 index 000000000000..be1e35a8d385 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/dummy_connector.py @@ -0,0 +1,214 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Dummy AFD Connector for testing and local development. + +This connector provides a no-op AFDConnectorBase interface, +useful for testing and development scenarios where actual +distributed FFN computation is not needed. +""" + +import time +from collections import deque +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class DummyAFDConnector(AFDConnectorBase): + """Dummy AFD connector that returns zero tensors. + + This connector is useful for: + 1. Testing AFD infrastructure without actual remote computation + 2. 
Development scenarios where FFN computation should be disabled + 3. Fallback behavior when remote FFN servers are unavailable + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize the dummy AFD connector. + + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + """ + self.afd_config = config.afd_config + self.rank = rank + self.local_rank = local_rank + self._is_initialized = False + self.hidden_size = config.model_config.hf_config.hidden_size + self.num_stages = config.afd_config.num_afd_stages + + self.events: deque = deque(maxlen=self.num_stages) + + logger.info("DummyAFDConnector initialized for rank %s", rank) + + self.init_afd_connector() + + def init_afd_connector(self) -> None: + """Initialize the dummy connector. + + This is a no-op for the dummy connector. + """ + if self._is_initialized: + return + + logger.info("Initializing DummyAFDConnector (no-op)") + self._is_initialized = True + + def close(self) -> None: + """Close the dummy connector. + + This is a no-op for the dummy connector. + """ + if not self._is_initialized: + return + + logger.info("Closing DummyAFDConnector (no-op)") + self._is_initialized = False + + @property + def is_initialized(self) -> bool: + """Check if the connector is initialized. + + Returns: + bool: True if initialized, False otherwise + """ + return self._is_initialized + + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> Any: + """ + Send attention output to FFN servers (dummy implementation). + """ + logger.debug( + "DummyAFDConnector: send_attn_output layer=%s, stage=%s", + metadata.layer_idx, + metadata.stage_idx, + ) + + # Validate metadata consistency + if not metadata.validate_tensor_shape(hidden_states.shape): + raise ValueError( + "Tensor shape %s doesn't match metadata %s", + hidden_states.shape, + metadata, + ) + + if not metadata.is_single_sequence: + raise ValueError("Attention side should have single sequence") + + self.events.append((None, metadata)) + + return None + + def recv_ffn_output( + self, + timeout_ms: Optional[float] = None, + ) -> torch.Tensor: + """Receive FFN computation result (dummy implementation).""" + logger.debug("DummyAFDConnector: recv_ffn_output timeout_ms=%s", + timeout_ms) + + _, metadata = self.events.popleft() + seq_len = metadata.seq_lens[0] # Single sequence for attention side + return torch.zeros( + seq_len, + self.hidden_size, + dtype=metadata.dtype, + device=metadata.device, + ) + + def recv_attn_output( + self, + timeout_ms: Optional[int] = None, + ) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """ + Receive attention output from attention workers (dummy implementation). 
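+ Returns zero tensors with fixed synthetic seq_lens so the FFN server loop can be exercised without any attention workers attached.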
+ """ + logger.debug("DummyAFDConnector: recv_attn_output timeout_ms=%s", + timeout_ms) + + # Generate dummy data that simulates multiple attention workers + dummy_seq_lens = [ + 2, + 2, + 2, + ] # Variable sequence lengths from different workers + total_tokens = sum(dummy_seq_lens) + + dummy_tensor = torch.zeros(total_tokens, + self.hidden_size, + dtype=torch.bfloat16, + device="cuda") + + # Create dummy metadata + dummy_metadata = AFDConnectorMetadata.create_ffn_metadata( + layer_idx=0, # Dummy layer + stage_idx=0, # Dummy stage + dtype=torch.bfloat16, + device=torch.device("cuda"), + seq_lens=dummy_seq_lens, + request_id=f"dummy_ffn_batch_{time.time()}", + ) + + # Cache metadata for send_ffn_output + self._current_metadata = dummy_metadata + time.sleep(1) + + return dummy_tensor, dummy_metadata + + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """Send FFN computation result back (dummy implementation).""" + logger.debug( + "DummyAFDConnector: send_ffn_output layer=%s, stage=%s", + metadata.layer_idx, + metadata.stage_idx, + ) + + # Validate that ffn_output shape matches metadata + if not metadata.validate_tensor_shape(ffn_output.shape): + logger.warning( + "FFN output shape %s doesn't match metadata %s", + ffn_output.shape, + metadata, + ) + + # Log the splitting information for debugging + logger.debug( + "DummyAFDConnector: Split FFN output into %s parts with lengths %s", + metadata.num_sequences, + metadata.seq_lens, + ) + + # Simulate splitting (for logging purposes) + if metadata.get_split_indices(): + split_outputs = torch.split(ffn_output, metadata.seq_lens, dim=0) + logger.debug( + "DummyAFDConnector: Split shapes: %s", + [s.shape for s in split_outputs], + ) + + time.sleep(1) + # No-op for dummy connector - just log the operation diff --git a/vllm/distributed/afd_transfer/afd_connector/factory.py b/vllm/distributed/afd_transfer/afd_connector/factory.py new file mode 100644 index 000000000000..862aba71d1b1 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/factory.py @@ -0,0 +1,96 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Factory for creating AFD connectors based on configuration.""" + +import importlib +from typing import TYPE_CHECKING, Callable + +from vllm.logger import init_logger + +from .base import AFDConnectorBase + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class AFDConnectorFactory: + _registry: dict[str, Callable[[], type[AFDConnectorBase]]] = {} + + @classmethod + def register_connector(cls, name: str, module_path: str, + class_name: str) -> None: + """Register a connector with a lazy-loading module and class name.""" + if name in cls._registry: + raise ValueError(f"Connector '{name}' is already registered.") + + def loader() -> type[AFDConnectorBase]: + module = importlib.import_module(module_path) + return getattr(module, class_name) + + cls._registry[name] = loader + + @classmethod + def create_connector(cls, rank: int, local_rank: int, + config: "VllmConfig") -> AFDConnectorBase: + """Create an AFD connector based on the configuration. 
+ + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFDConfig + + Returns: + AFDConnectorBase: The created connector instance + + Raises: + ValueError: If the connector type is not supported + ImportError: If required dependencies are not available + """ + afd_config = config.afd_config + connector_name = afd_config.afd_connector + + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + connector_cls = cls._registry[connector_name]() + assert issubclass(connector_cls, AFDConnectorBase) + return connector_cls(rank, local_rank, config) + + @classmethod + def get_connector_class(cls, + connector_name: str) -> type[AFDConnectorBase]: + """Get the connector class for a given connector name. + + Args: + connector_name: The connector name + + Returns: + type[AFDConnectorBase]: The connector class + + Raises: + ValueError: If the connector name is not supported + """ + if connector_name not in cls._registry: + raise ValueError(f"Unsupported connector type: {connector_name}") + + return cls._registry[connector_name]() + + +# Register various connectors here. +# The registration should not be done in each individual file, as we want to +# only load the files corresponding to the current connector. +AFDConnectorFactory.register_connector( + "stepmesh", + "vllm.distributed.afd_transfer.afd_connector.stepmesh_connector", + "StepMeshAFDConnector") + +AFDConnectorFactory.register_connector( + "dummy", "vllm.distributed.afd_transfer.afd_connector.dummy_connector", + "DummyAFDConnector") + +AFDConnectorFactory.register_connector( + "p2pconnector", + "vllm.distributed.afd_transfer.afd_connector.p2p_connector", + "P2PAFDConnector") diff --git a/vllm/distributed/afd_transfer/afd_connector/metadata.py b/vllm/distributed/afd_transfer/afd_connector/metadata.py new file mode 100644 index 000000000000..a5e390784494 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/metadata.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""AFD metadata definitions for communication between attention and +FFN workers.""" + +import time +import typing +from dataclasses import dataclass +from typing import Optional + +import torch + +class FFNNeedForwardData: + + def __init__(self, + moe_comm_method: typing.Any, + num_input_tokens: int, + with_prefill: bool, + total_num_scheduled_tokens: Optional[int], + is_dummy_run: bool = False): + self.moe_comm_method = moe_comm_method + self.num_input_tokens = num_input_tokens + self.with_prefill = with_prefill + self.total_num_scheduled_tokens = total_num_scheduled_tokens + self.is_dummy_run = is_dummy_run + +@dataclass +class AFDConnectorMetadata: + """Lightweight AFD metadata containing core information needed for + communication.""" + layer_idx: int + stage_idx: int + seq_lens: list[ + int] # Length of each sequence, supports variable length and + # multiple sequences + dtype: torch.dtype + device: torch.device + + topk_idx: Optional[torch.Tensor] = None + # indices of the experts each token is routed to + topk_weights: Optional[torch.Tensor] = None + # the expert routing weights + moe_expert_num: Optional[int] = None + # number of MoE experts + shared_expert_num: Optional[int] = None + # number of shared experts + scale: Optional[torch.Tensor] = None + # quantization scale + expertTokenNumsOut: Optional[torch.Tensor] = None + # The number of tokens received by each expert, used as input for GMM + handle: Optional[torch.Tensor] = None + # the
communication handle given by the recv_attn_output function + + # Optional fields for debugging and extensibility + request_id: Optional[str] = None + timestamp: Optional[float] = None + """ffn need forward data""" + ffn_need_forward_data: Optional[FFNNeedForwardData] = None + + def __post_init__(self): + """Validate data consistency.""" + if not self.seq_lens: + raise ValueError("seq_lens cannot be empty") + if any(length <= 0 for length in self.seq_lens): + raise ValueError("All sequence lengths must be positive") + + @property + def total_tokens(self) -> int: + """Total number of tokens.""" + return sum(self.seq_lens) + + @property + def num_sequences(self) -> int: + """Number of sequences.""" + return len(self.seq_lens) + + @property + def is_single_sequence(self) -> bool: + """Whether this is a single sequence (attention side characteristic).""" + return len(self.seq_lens) == 1 + + @property + def is_multi_sequence(self) -> bool: + """Whether this is multiple sequences (FFN side characteristic).""" + return len(self.seq_lens) > 1 + + @classmethod + def create_attention_metadata( + cls, + layer_idx: int, + stage_idx: int, + seq_len: int, + dtype: torch.dtype, + device: torch.device, + request_id: Optional[str] = None, + ffn_need_forward_data:Optional[FFNNeedForwardData] = None) -> "AFDConnectorMetadata": + """Create metadata for attention side (single sequence).""" + return cls(layer_idx=layer_idx, + stage_idx=stage_idx, + seq_lens=[seq_len], + dtype=dtype, + device=device, + request_id=request_id, + ffn_need_forward_data = ffn_need_forward_data, + timestamp=time.time()) + + @classmethod + def create_ffn_metadata( + cls, + layer_idx: int, + stage_idx: int, + seq_lens: list[int], + dtype: torch.dtype, + device: torch.device, + request_id: Optional[str] = None) -> "AFDConnectorMetadata": + """Create metadata for FFN side (multiple sequences).""" + return cls( + layer_idx=layer_idx, + stage_idx=stage_idx, + seq_lens=seq_lens.copy(), # Prevent external modification + dtype=dtype, + device=device, + request_id=request_id, + timestamp=time.time()) + + def get_split_indices(self) -> list[int]: + """Get tensor split indices for FFN side output splitting.""" + if len(self.seq_lens) <= 1: + return [] + + indices = [] + cumsum = 0 + for length in self.seq_lens[:-1]: # Exclude the last one + cumsum += length + indices.append(cumsum) + return indices + + def validate_tensor_shape(self, tensor_shape: tuple[int, ...]) -> bool: + """Validate if tensor shape is consistent with metadata.""" + if len(tensor_shape) < 1: + return False + return tensor_shape[0] == self.total_tokens + + def to_dict(self) -> dict: + """Convert to dictionary format for serialization and debugging.""" + return { + "layer_idx": self.layer_idx, + "stage_idx": self.stage_idx, + "seq_lens": self.seq_lens, + "dtype": self.dtype, + "device": self.device, + "total_tokens": self.total_tokens, + "num_sequences": self.num_sequences, + "request_id": self.request_id, + "timestamp": self.timestamp, + } + + def __repr__(self) -> str: + """Friendly string representation.""" + return (f"AFDConnectorMetadata(layer={self.layer_idx}, " + f"stage={self.stage_idx}, seq_lens={self.seq_lens}, " + f"total_tokens={self.total_tokens}, dtype={self.dtype}, " + f"device={self.device}, request_id={self.request_id})") diff --git a/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py new file mode 100644 index 000000000000..0fa0adeb24aa --- /dev/null +++ 
b/vllm/distributed/afd_transfer/afd_connector/p2p_connector.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import re +from datetime import timedelta + +import torch +from torch.distributed.distributed_c10d import ( + _update_default_pg, + _get_default_group, +) +from typing import Any, Optional + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata +from vllm.distributed.parallel_state import ( + init_afd_process_group, + init_model_parallel_group, +) +from vllm.sequence import IntermediateTensors +from vllm.logger import init_logger +from vllm.config import VllmConfig +from vllm.platforms import current_platform + +logger = init_logger(__name__) + + +class DefaultProcessGroupSwitcher: + def __init__(self, default_group, new_default_group): + self.default_group = default_group + self.new_default_group = new_default_group + + def __enter__(self): + _update_default_pg(self.new_default_group) + + def __exit__(self, exc_type, exc_value, traceback): + _update_default_pg(self.default_group) + + +class P2PAFDConnector(AFDConnectorBase): + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ) -> None: + self.rank = rank + self.local_rank = local_rank + self._initialized = False + self.config = config + + def close(self) -> None: + """Close the connector and release resources.""" + # destroy process group + pass + + def init_afd_connector(self) -> None: + """Initialize the AFD connector.""" + afd_size = self.config.afd_config.afd_extra_config.get("afd_size") + role = self.config.afd_config.afd_role + host = self.config.afd_config.afd_host + port = self.config.afd_config.afd_port + attn_size, ffn_size = map( + int, re.match(r"(\d+)\D+(\d+)", afd_size).groups() + ) + world_rank = self.rank if role == "attention" else self.rank + attn_size + + logger.info( + "world_size = %d, world_rank = %d", ffn_size + attn_size, world_rank + ) + backend = current_platform.dist_backend + afd_pg = init_afd_process_group( + backend=backend, + init_method=f"tcp://{host}:{port}", + world_size=ffn_size + attn_size, + rank=world_rank, + group_name="afd", + timeout=timedelta(minutes=2), + ) + ffn_ranks = [i for i in range(ffn_size, ffn_size + attn_size)] + attn_ranks = [i for i in range(attn_size)] + + default_pg_switcher = DefaultProcessGroupSwitcher( + _get_default_group(), afd_pg + ) + with default_pg_switcher: + sub_group_ranks = [] + for i in range(len(ffn_ranks)): + ranks = list([attn_ranks[i], ffn_ranks[i]]) + sub_group_ranks.append(ranks) + self.process_group = init_model_parallel_group( + sub_group_ranks, self.rank, backend=backend, group_name="ae" + ) + + logger.info("p2p connector initialized") + + self._initialized = True + + def is_initialized(self) -> bool: + """Check if the connector is initialized and ready to use. + + Returns: + bool: True if the connector is initialized, False otherwise. + """ + return self._initialized + + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> Any: + """ + This method will be called by the ATTN side. + + + * To send the intermediate tensors generated by ATTN instances to FFN. 
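+ * The paired FFN rank in the two-rank sub-group receives them via recv_attn_output(), followed by the metadata object.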
+ """ + + intermediate_tensors = IntermediateTensors( + { + "hidden_states": hidden_states, + } + ) + try: + self.process_group.send_tensor_dict( + intermediate_tensors.tensors, + all_gather_group=None, + ) + dst = ( + self.process_group.rank_in_group + 1 + ) % self.process_group.world_size + self.process_group.send_object(metadata, dst) + except Exception as e: + raise RuntimeError(f"Communication error: {e}") + + def recv_attn_output( + self, + timeout_ms: Optional[int] = None, + ) -> tuple[torch.Tensor, "AFDConnectorMetadata"]: + """ + This method will be called by the FFN side. + + + * To receive the intermediate tensors from ATTN. + * And (Maybe) dispatch them from the receiver to other GPUs. + """ + intermediate_tensors = self.process_group.recv_tensor_dict( + all_gather_group=None, + ) + src = ( + self.process_group.rank_in_group - 1 + ) % self.process_group.world_size + metadata = self.process_group.recv_object(src) + return intermediate_tensors["hidden_states"], metadata + + # ------------------------------------------------------------------------- + # attn <- ffn + # ------------------------------------------------------------------------- + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: "AFDConnectorMetadata", + ) -> None: + """ + This method will be called by the FFN side. + + + * To send the intermediate tensors generated by FFN instances back to + the sender (this should be the same GPU as it comes from) + """ + intermediate_tensors = IntermediateTensors( + { + "hidden_states": ffn_output, + } + ) + self.process_group.send_tensor_dict( + intermediate_tensors.tensors, + ) + dst = ( + self.process_group.rank_in_group + 1 + ) % self.process_group.world_size + + self.process_group.send_object(metadata, dst) + + def recv_ffn_output( + self, + handle: Any, + ) -> torch.Tensor: + """ + This method will be called by the ATTN side. + + + * To receive the MOE output intermediate tensors. + * And (Maybe) dispatch them from the receiver to other GPUs. + (this should be the same GPU as it comes from) + """ + intermediate_tensors = self.process_group.recv_tensor_dict( + all_gather_group=None, + ) + src = ( + self.process_group.rank_in_group - 1 + ) % self.process_group.world_size + + self.process_group.recv_object(src) + return intermediate_tensors["hidden_states"] diff --git a/vllm/distributed/afd_transfer/afd_connector/stepmesh_connector.py b/vllm/distributed/afd_transfer/afd_connector/stepmesh_connector.py new file mode 100644 index 000000000000..c8a069dfcbb5 --- /dev/null +++ b/vllm/distributed/afd_transfer/afd_connector/stepmesh_connector.py @@ -0,0 +1,420 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""StepMesh-based AFD connector implementation.""" + +import os +import subprocess +import time +from collections import deque +from typing import TYPE_CHECKING, Any, Optional + +import torch + +from vllm.logger import init_logger + +from .base import AFDConnectorBase +from .metadata import AFDConnectorMetadata + +import fserver_lib as ps # isort: skip + +if TYPE_CHECKING: + from vllm.config import VllmConfig + +logger = init_logger(__name__) + + +class StepMeshAFDConnector(AFDConnectorBase): + """StepMesh-based implementation of AFD connector. + + This connector uses StepMesh parameter server for communication between + attention workers and FFN servers. + """ + + def __init__( + self, + rank: int, + local_rank: int, + config: "VllmConfig", + ): + """Initialize StepMesh AFD connector. 
+ + Args: + rank: Global rank of this process + local_rank: Local rank within the node + config: VllmConfig containing AFD configuration + """ + self.afd_config = config.afd_config + self.rank = rank + self.local_rank = local_rank + self.server_rank = self.afd_config.afd_server_rank + self.num_recv_times = (self.afd_config.num_ffn_servers + if self.afd_config.afd_role == "attention" else + self.afd_config.num_attention_servers) + parallel_config = config.parallel_config + self.world_size = (parallel_config.tensor_parallel_size * + parallel_config.pipeline_parallel_size * + parallel_config.data_parallel_size) + self._initialized = False + self.num_stages = self.afd_config.num_afd_stages + self.recv_counter = 0 + + # Metadata tracking for new interface + self._current_comm_handles = None + self._current_metadata = None + + if self.afd_config.afd_role == "attention": + self.events: deque = deque(maxlen=self.num_stages) + self.max_num_tokens = ( + config.scheduler_config.max_num_batched_tokens // + self.num_stages + + config.scheduler_config.max_num_batched_tokens % + self.num_stages) + self.recv_buffer: list[list[torch.Tensor]] = [[ + torch.empty( + ( + self.max_num_tokens, + config.model_config.hf_config.hidden_size, + ), + dtype=torch.bfloat16, + device=torch.device("cuda"), + ).contiguous() for _ in range(self.num_recv_times) + ] for _ in range(self.afd_config.num_afd_stages)] + self.send_buffer: list[torch.Tensor] = [ + torch.empty( + ( + self.max_num_tokens, + config.model_config.hf_config.hidden_size, + ), + dtype=torch.bfloat16, + device=torch.device("cuda"), + ).contiguous() for _ in range(self.afd_config.num_afd_stages) + ] + else: + self.max_num_tokens = ( + config.scheduler_config.max_num_batched_tokens // + self.num_stages + + config.scheduler_config.max_num_batched_tokens % + self.num_stages) * self.num_recv_times + self.ret_buffer = torch.empty( + [ + self.max_num_tokens, + config.model_config.hf_config.hidden_size, + ], + dtype=torch.bfloat16, + device=torch.device("cuda"), + ).contiguous() + + # StepMesh environment setup + self._setup_stepmesh_env() + + if (self.afd_config.afd_role == "ffn" + and self.afd_config.afd_server_rank == 0 + and self.local_rank == 0): + self._start_scheduler_process() + + def _setup_stepmesh_env(self) -> None: + """Setup StepMesh environment variables.""" + # Basic StepMesh configuration based on draft.diff + if self.afd_config.afd_role == "attention": + os.environ["DMLC_ROLE"] = "worker" + elif self.afd_config.afd_role == "ffn": + os.environ["DMLC_ROLE"] = "server" + else: + raise ValueError(f"Invalid AFD role: {self.afd_config.afd_role}") + + os.environ["DMLC_NUM_WORKER"] = str( + self.afd_config.num_attention_servers) + os.environ["DMLC_NUM_SERVER"] = str(self.afd_config.num_ffn_servers) + + os.environ["DMLC_ENABLE_RDMA"] = "ibverbs" + os.environ["DMLC_INTERFACE"] = "auto" + os.environ["STEPMESH_SPLIT_QP_LAG"] = os.environ.get( + "STEPMESH_SPLIT_QP_LAG", "0") + os.environ["STEPMESH_BIND_CPU_CORE"] = "1" + + os.environ["STEPMESH_GPU"] = os.environ.get("STEPMESH_GPU", + str(self.local_rank)) + + os.environ["DMLC_PS_ROOT_PORT"] = str(self.afd_config.afd_port) + os.environ["DMLC_PS_ROOT_URI"] = self.afd_config.afd_host + os.environ["DMLC_NODE_HOST"] = str(self.afd_config.afd_host) + os.environ["SCHEDULER_IP"] = str(self.afd_config.afd_host) + + os.environ["DMLC_NODE_RANK"] = str(self.afd_config.afd_server_rank) + os.environ["DMLC_GROUP_SIZE"] = str(self.world_size) + + os.environ["PS_VERBOSE"] = os.environ.get("PS_VERBOSE", "2") + + logger.info( + 
"StepMesh environment setup: role=%s, " + "num_worker=%s, " + "num_server=%s, " + "port=%s, " + "host=%s, " + "node_rank=%s, " + "gpu=%s, " + "group_size=%s", os.environ.get('DMLC_ROLE'), + os.environ.get('DMLC_NUM_WORKER'), + os.environ.get('DMLC_NUM_SERVER'), + os.environ.get('DMLC_PS_ROOT_PORT'), + os.environ.get('DMLC_PS_ROOT_URI'), + os.environ.get('DMLC_NODE_RANK'), os.environ.get('STEPMESH_GPU'), + os.environ.get('DMLC_GROUP_SIZE')) + + def _start_scheduler_process(self) -> None: + """Start scheduler process for FFN role. + + This method launches a separate subprocess to run the StepMesh scheduler + when the current process is in FFN role. + """ + try: + logger.info("Starting scheduler subprocess for FFN role") + # Use subprocess.Popen to start scheduler as a separate process + self.scheduler_process = subprocess.Popen( + [ + "python", + "-c", + "import torch; import fserver_lib as ps; import os; " + 'os.environ["DMLC_ROLE"] = "scheduler"; ' + 'os.environ["DMLC_INTERFACE"] = "brainpf_bond0"; ' + "ps.init(); ps.stop()", + ], + env=os.environ.copy(), + ) + logger.info("Scheduler subprocess started with PID: %s", + self.scheduler_process.pid) + except Exception as e: + logger.error("Failed to start scheduler subprocess: %s", e) + raise RuntimeError( + f"Failed to start scheduler subprocess: {e}") from e + + def init_afd_connector(self) -> None: + """Initialize StepMesh connector.""" + if self._initialized: + return + try: + logger.info("+++++Start init ps. %s", self.rank) + ps.init() + logger.info("----Finish init ps. %s", self.rank) + + self._initialized = True + logger.info("StepMesh connector initialized successfully as %s", + os.environ.get('DMLC_ROLE')) + + except ImportError as e: + raise ImportError( + f"StepMesh is not available. Please install StepMesh to use " + f"StepMesh AFD connector. Error: {e}") from e + except Exception as e: + raise RuntimeError( + f"Failed to initialize StepMesh connector: {e}") from e + + @property + def is_initialized(self) -> bool: + """Check if the connector is initialized.""" + return self._initialized + + # AFD Communication Methods Implementation + def send_attn_output( + self, + hidden_states: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> Any: + """Send attention output to FFN servers via StepMesh push_pull. 
+ + Args: + hidden_states: Attention output tensor + metadata: AFD metadata containing layer_idx, stage_idx, seq_len info + + Returns: + Any: Event handle for tracking this request + """ + if not self._initialized: + raise RuntimeError("StepMesh connector not initialized") + + # Validate metadata consistency + if not metadata.validate_tensor_shape(hidden_states.shape): + raise ValueError("Tensor shape %s doesn't match metadata %s", + hidden_states.shape, metadata) + + if not metadata.is_single_sequence: + raise ValueError("Attention side should have single sequence") + + seq_len = metadata.seq_lens[0] # Single sequence for attention side + stage_id = metadata.stage_idx + + # Create keys based on the pattern from your example + node_rank_offset = int(self.rank * 1e6) + recv_key = [stage_id + 1000] + recv_buff = [t[:seq_len] for t in self.recv_buffer[stage_id]] + send_buff = [self.send_buffer[stage_id][:seq_len]] + + if seq_len > self.max_num_tokens: + raise ValueError("AFD seq_len[%s] exceeds max_num_tokens[%s]", + seq_len, self.max_num_tokens) + + send_buff[0].copy_(hidden_states[:seq_len]) + send_key = [stage_id + node_rank_offset] + + event = ps.push_pull( + send_buff, + send_key, + recv_buff, + recv_key, + ) + self.events.append((event, metadata)) + + def recv_attn_output( + self, + timeout_ms: Optional[float] = None, + ) -> tuple[torch.Tensor, AFDConnectorMetadata]: + """Receive attention output from attention workers (FFN server side). + + Args: + timeout_ms: Optional timeout in milliseconds + + Returns: + tuple: (hidden_states, metadata) received from attention workers + """ + if not self._initialized: + raise RuntimeError("StepMesh connector not initialized") + + try: + # batches = self.signal.get_batch() # type: ignore + batches = ps.get_batch() # type: ignore + + # Extract tensors and build metadata + recv_tensors = [] + seq_lens = [] + comm_handles = [] + + for node_rank in range(self.num_recv_times): + tensor = batches[node_rank][1][0] + comm_id = batches[node_rank][0] + + recv_tensors.append(tensor) + seq_lens.append(tensor.shape[0]) + comm_handles.append(comm_id) + + # Merge tensors + merged_tensor = torch.cat(recv_tensors, dim=0) + + # Infer metadata from communication + # TODO: Extract layer_idx and stage_idx from comm_id encoding + inferred_metadata = AFDConnectorMetadata.create_ffn_metadata( + layer_idx=-1, # Extract from comm_id + stage_idx=-1, # Extract from comm_id + seq_lens=seq_lens, + dtype=merged_tensor.dtype, + device=merged_tensor.device, + request_id=f"ffn_batch_{time.time()}", + ) + + # Store handles for response + self._current_comm_handles = comm_handles # type: ignore + self._current_metadata = inferred_metadata # type: ignore + + return merged_tensor, inferred_metadata + + except Exception as e: + logger.error("Failed to receive attention output: %s", e) + raise RuntimeError(f"StepMesh recv_attn_output failed: {e}") from e + + def send_ffn_output( + self, + ffn_output: torch.Tensor, + metadata: AFDConnectorMetadata, + ) -> None: + """Send FFN computation result back to attention workers. 
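+ The result is staged in the pre-allocated return buffer and answered through the communication handles captured by the preceding recv_attn_output() call.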
+ + Args: + ffn_output: Computed FFN output + metadata: AFD metadata containing seq_lens for s + plitting and routing info + """ + if not self._initialized: + raise RuntimeError("StepMesh connector not initialized") + + self.ret_buffer[:ffn_output.shape[0]].copy_(ffn_output) + + try: + # Use metadata.seq_lens for splitting + split_indices = metadata.get_split_indices() + if split_indices: + split_outputs = torch.split(ffn_output, + metadata.seq_lens, + dim=0) + else: + split_outputs = [ffn_output] + + comm_handles = self._current_comm_handles + ps.respond_vec(self.ret_buffer, split_outputs, comm_handles) + + except Exception as e: + logger.error("Failed to send FFN output: %s", e) + raise RuntimeError(f"StepMesh send_ffn_output failed: {e}") from e + + def recv_ffn_output( + self, + timeout_ms: Optional[float] = None, + ) -> torch.Tensor: + """Wait for FFN computation result from FFN servers. + + Args: + handle: Event handle returned by send_attn_output + + Returns: + torch.Tensor: FFN computation result + """ + if not self._initialized: + raise RuntimeError("StepMesh connector not initialized") + + try: + if len(self.events) > 0: + event, metadata = self.events.popleft() + ps.wait(event, timeout_ms=50000) + # Get result from recv_buffer + if metadata: + stage_idx = metadata.stage_idx + seq_len = metadata.seq_lens[ + 0] # Single sequence for attention side + if len(self.recv_buffer[stage_idx]) == 1: + return self.recv_buffer[stage_idx][0][:seq_len] + else: + return torch.stack( + [t[:seq_len] for t in self.recv_buffer[stage_idx]], + dim=0, + ).sum(dim=0) + else: + raise ValueError("No metadata found for handle") + + except Exception as e: + logger.error("Failed to wait for FFN output: %s", e) + raise RuntimeError(f"StepMesh recv_ffn_output failed: {e}") from e + + def close(self) -> None: + """Close the StepMesh connector and release resources.""" + if self._initialized: + try: + ps.finalize() + self._initialized = False + logger.info("StepMesh connector closed successfully") + except Exception as e: + logger.error("Failed to close StepMesh connector: %s", e) + + # Clean up scheduler subprocess if it exists + if (hasattr(self, "scheduler_process") + and self.scheduler_process is not None): + try: + if (self.scheduler_process.poll() + is None): # Process is still running + logger.info("Terminating scheduler subprocess") + self.scheduler_process.terminate() + self.scheduler_process.wait(timeout=5) + logger.info("Scheduler subprocess terminated successfully") + except subprocess.TimeoutExpired: + logger.warning( + "Scheduler subprocess failed to terminate gracefully") + self.scheduler_process.kill() + except Exception as e: + logger.error("Failed to terminate scheduler subprocess: %s", e) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 12571afaa4c1..2e7331f0a70b 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -37,6 +37,10 @@ import torch import torch.distributed from torch.distributed import Backend, ProcessGroup +from torch.distributed.distributed_c10d import (PrefixStore, Store, + _new_process_group_helper, + _world, default_pg_timeout, + rendezvous) from typing_extensions import deprecated import vllm.envs as envs @@ -893,6 +897,55 @@ def combine(self, hidden_states) -> torch.Tensor: return hidden_states +def init_afd_process_group( + backend: Union[str, Backend] = None, + init_method: Optional[str] = None, + timeout: Optional[timedelta] = None, + world_size: int = -1, + rank: int = -1, + store: 
Optional[Store] = None, + group_name: str = None, + pg_options: Optional[Any] = None, +): + assert (store is None) or (init_method is None), ( + "Cannot specify both init_method and store.") + + if store is not None: + assert world_size > 0, "world_size must be positive if using store" + assert rank >= 0, "rank must be non-negative if using store" + elif init_method is None: + init_method = "env://" + + backend = Backend(backend) if backend else Backend("undefined") + + if timeout is None: + timeout = default_pg_timeout + + if store is None: + rendezvous_iterator = rendezvous(init_method, + rank, + world_size, + timeout=timeout) + store, rank, world_size = next(rendezvous_iterator) + store.set_timeout(timeout) + store = PrefixStore(group_name, store) + + pg_options_param_name = ("backend_options" if str(torch.__version__) + >= "2.6" else "pg_options") + pg, _ = _new_process_group_helper( + world_size, + rank, + [], + backend, + store, + group_name=group_name, + **{pg_options_param_name: pg_options}, + timeout=timeout, + ) + _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} + return pg + + _WORLD: Optional[GroupCoordinator] = None _NODE_COUNT: Optional[int] = None diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 595d318fbaaf..e23b1578f6ce 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -21,10 +21,10 @@ from typing_extensions import TypeIs, deprecated import vllm.envs as envs -from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig, - ConfigType, ConvertOption, DecodingConfig, - DetailedTraceModules, Device, DeviceConfig, - DistributedExecutorBackend, EPLBConfig, +from vllm.config import (AFDConfig, BlockSize, CacheConfig, CacheDType, + CompilationConfig, ConfigType, ConvertOption, + DecodingConfig, DetailedTraceModules, Device, + DeviceConfig, DistributedExecutorBackend, EPLBConfig, GuidedDecodingBackend, HfOverrides, KVEventsConfig, KVTransferConfig, LoadConfig, LogprobsMode, LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig, @@ -473,6 +473,9 @@ class EngineArgs: kv_sharing_fast_prefill: bool = \ CacheConfig.kv_sharing_fast_prefill + # AFD config + afd_config: Optional[AFDConfig] = None + def __post_init__(self): # support `EngineArgs(compilation_config={...})` # without having to manually construct a @@ -918,6 +921,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: **vllm_kwargs["kv_events_config"]) vllm_group.add_argument("--compilation-config", "-O", **vllm_kwargs["compilation_config"]) + vllm_group.add_argument("--afd-config", **vllm_kwargs["afd_config"]) vllm_group.add_argument("--additional-config", **vllm_kwargs["additional_config"]) @@ -933,7 +937,8 @@ def from_cli_args(cls, args: argparse.Namespace): # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
- engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + engine_args_dict = {attr: getattr(args, attr) for attr in attrs} + engine_args = cls(**engine_args_dict) return engine_args def create_model_config(self) -> ModelConfig: @@ -1436,6 +1441,7 @@ def create_engine_config( kv_transfer_config=self.kv_transfer_config, kv_events_config=self.kv_events_config, additional_config=self.additional_config, + afd_config=self.afd_config, ) return config diff --git a/vllm/entrypoints/afd_ffn_server.py b/vllm/entrypoints/afd_ffn_server.py new file mode 100644 index 000000000000..03f3d0fb7f8a --- /dev/null +++ b/vllm/entrypoints/afd_ffn_server.py @@ -0,0 +1,91 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM AFD FFN Server Entry Point + +This script provides a standalone entry point for running FFN servers in an AFD +(Attention-FFN Disaggregation) setup. FFN servers handle remote FFN computation +for attention workers. + +Usage: + python -m vllm.entrypoints.afd_ffn_server /path/to/model \ + --tensor-parallel-size 8 \ + --afd-config '{"afd_connector": "dummy", "afd_role": "ffn"}' \ +""" +import threading +from typing import Any + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.logger import init_logger +from vllm.utils import FlexibleArgumentParser + +logger = init_logger(__name__) + + +class AFDFFNServer: + """AFD FFN Server main class.""" + + def __init__(self, args: Any): + engine_args = AsyncEngineArgs.from_cli_args(args) + self.vllm_config = engine_args.create_engine_config() + logger.info("Start AFD FFN Server with vllm_config: %s", + self.vllm_config) + + def start(self) -> None: + """Start the AFD FFN server.""" + try: + # Import here to avoid circular imports + from vllm.v1.executor.abstract import Executor + + # Create configurations + executor_class = Executor.get_class(self.vllm_config) + self.model_executor = executor_class(vllm_config=self.vllm_config) + # Start the FFN server loop + self._run_server_loop() + + except Exception as e: + logger.error("Failed to start AFD FFN server: %s", e) + raise + + def _run_server_loop(self) -> None: + """Start FFN workers and wait for completion""" + logger.info("AFD FFN Server started, workers running...") + try: + # Tell workers to start FFN server loops (one-time call) + self.model_executor.collective_rpc("start_ffn_server_loop") + + # Main thread waits without busy polling + shutdown_event = threading.Event() + shutdown_event.wait() # Block until interrupted + + except KeyboardInterrupt: + logger.info("Server shutting down...") + self.model_executor.collective_rpc("stop_ffn_server_loop") + except Exception as e: + logger.error("Server error: %s", e) + raise + + +def main(args: Any) -> None: + """Main entry point for AFD FFN server.""" + try: + server = AFDFFNServer(args) + server.start() + except KeyboardInterrupt: + logger.info("Interrupted by user") + except Exception as e: + logger.error("Server error: %s", e) + raise + + +if __name__ == "__main__": + parser = FlexibleArgumentParser() + # Add model as positional argument (like vllm serve) + parser.add_argument("model", type=str, help="Model name or path") + parser = AsyncEngineArgs.add_cli_args(parser) + args = parser.parse_args() + + # Set the model from positional argument + args.model = args.model + + main(args) diff --git a/vllm/entrypoints/cli/fserver.py b/vllm/entrypoints/cli/fserver.py new file mode 100644 index 000000000000..8eb83434ec22 --- /dev/null +++ 
b/vllm/entrypoints/cli/fserver.py @@ -0,0 +1,51 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""vLLM AFD FFN Server CLI command.""" + +import argparse + +from vllm.engine.arg_utils import AsyncEngineArgs +from vllm.entrypoints.afd_ffn_server import main +from vllm.entrypoints.cli.types import CLISubcommand + + +class FServerCommand(CLISubcommand): + """Command for running vLLM AFD FFN Server.""" + + def __init__(self): + self.name = "fserver" + super().__init__() + + def subparser_init(self, subparsers): + """Initialize the fserver subparser.""" + parser = subparsers.add_parser( + self.name, + help="Start vLLM AFD FFN Server", + description= + "Start vLLM AFD FFN Server for Attention-FFN Disaggregation", + usage="vllm fserver MODEL --afd-config CONFIG [options]") + + # Add model as positional argument (like vllm serve) + parser.add_argument("model", type=str, help="Model name or path") + + # Use AsyncEngineArgs to add all vLLM engine arguments + parser = AsyncEngineArgs.add_cli_args(parser) + + return parser + + def validate(self, args: argparse.Namespace) -> None: + """Validate arguments for fserver command.""" + # Validate that afd_config is provided + if not hasattr(args, 'afd_config') or not args.afd_config: + raise ValueError("--afd-config is required for FFN server") + + @staticmethod + def cmd(args: argparse.Namespace) -> None: + """Run the fserver command.""" + # Call the main function from afd_ffn_server directly with parsed args + main(args) + + +def cmd_init() -> list[CLISubcommand]: + """Initialize fserver command.""" + return [FServerCommand()] diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index fed3ea650405..75ab594af48f 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -12,6 +12,7 @@ def main(): import vllm.entrypoints.cli.benchmark.main import vllm.entrypoints.cli.collect_env + import vllm.entrypoints.cli.fserver import vllm.entrypoints.cli.openai import vllm.entrypoints.cli.run_batch import vllm.entrypoints.cli.serve @@ -21,6 +22,7 @@ def main(): CMD_MODULES = [ vllm.entrypoints.cli.openai, vllm.entrypoints.cli.serve, + vllm.entrypoints.cli.fserver, vllm.entrypoints.cli.benchmark.main, vllm.entrypoints.cli.collect_env, vllm.entrypoints.cli.run_batch, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index b3ddd7b9a739..b5994da4de72 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata + from vllm.distributed.afd_transfer import AFDConnectorBase logger = init_logger(__name__) @@ -65,8 +66,10 @@ def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: list[int], @dataclass class DPMetadata: - max_tokens_across_dp_cpu: torch.Tensor - cu_tokens_across_dp_cpu: torch.Tensor + max_tokens_across_dp_cpu: torch.Tensor # 1D for normal, 1D for + # stage-wise max + cu_tokens_across_dp_cpu: torch.Tensor # 1D for normal, 2D + # [num_stages, dp_size] for stage-wise local_sizes: Optional[list[int]] = None @staticmethod @@ -97,6 +100,47 @@ def num_tokens_across_dp(num_tokens: int, dp_size: int, dist.all_reduce(num_tokens_tensor, group=group) return num_tokens_tensor.cpu() + @staticmethod + def num_stage_tokens_across_dp( + num_stage_tokens: list[int], dp_size: int, + dp_rank: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Gather the stage token counts across all DP ranks. 
+ Args: + num_stage_tokens: list of token counts per stage for current rank + dp_size: number of DP ranks + dp_rank: current DP rank + Returns: + stage_tokens_across_dp_cpu: [num_stages, dp_size] tensor + max_stage_tokens_across_dp_cpu: [num_stages] tensor with max + tokens per stage + """ + from vllm.distributed.parallel_state import get_dp_group + device = current_platform.device_type + group = get_dp_group().device_group + + if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION: + device = "cpu" + group = get_dp_group().cpu_group + + num_stages = len(num_stage_tokens) + stage_tokens_across_dp = torch.zeros((num_stages, dp_size), + device=device, + dtype=torch.int32) + stage_tokens_across_dp[:, dp_rank] = torch.tensor(num_stage_tokens, + device=device, + dtype=torch.int32) + + # AllReduce to gather from all ranks + dist.all_reduce(stage_tokens_across_dp, group=group) + stage_tokens_across_dp_cpu = stage_tokens_across_dp.cpu() + + # Compute max tokens per stage + max_stage_tokens_across_dp_cpu = torch.max(stage_tokens_across_dp_cpu, + dim=1)[0] + + return stage_tokens_across_dp_cpu, max_stage_tokens_across_dp_cpu + @staticmethod def make( parallel_config: ParallelConfig, @@ -171,6 +215,15 @@ def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]: return self.local_sizes +@dataclass +class AFDMetadata: + afd_tokens_start_loc: list[int] + afd_reqs_start_loc: list[int] + afd_stage_idx: int + afd_connector: "AFDConnectorBase" + afd_tokens_lens: list[int] # padded lengths for tensor slicing + + @dataclass class ForwardContext: # copy from vllm_config.compilation_config.static_forward_context @@ -186,6 +239,7 @@ class ForwardContext: virtual_engine: int # set dynamically for each forward pass # set dynamically for each forward pass dp_metadata: Optional[DPMetadata] = None + afd_metadata: Optional[AFDMetadata] = None # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE. # by default NONE, no cudagraph is used. cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE @@ -216,7 +270,8 @@ def set_forward_context( num_tokens: Optional[int] = None, num_tokens_across_dp: Optional[torch.Tensor] = None, cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE, - batch_descriptor: Optional[BatchDescriptor] = None): + batch_descriptor: Optional[BatchDescriptor] = None, + afd_metadata: Optional[AFDMetadata] = None): """A context manager that stores the current forward context, can be attention metadata, etc. Here we can inject common logic for every model forward pass. 
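# --- Illustrative sketch (editor's addition, not part of this patch) ----
# Minimal example of how AFDMetadata is meant to travel: the model runner
# attaches it when entering set_forward_context(), and any decoder layer can
# read it back through get_forward_context(). The connector, attn_metadata,
# vllm_config and the offsets below are assumed placeholder inputs.
from vllm.forward_context import (AFDMetadata, get_forward_context,
                                  set_forward_context)


def run_one_forward(model, input_ids, positions, attn_metadata, vllm_config,
                    connector):
    afd_metadata = AFDMetadata(
        afd_tokens_start_loc=[0, 128],  # one stage starting at token 0
        afd_reqs_start_loc=[0, 8],      # that stage holds requests 0..7
        afd_stage_idx=0,
        afd_connector=connector,
        afd_tokens_lens=[128],          # padded stage length used for slicing
    )
    with set_forward_context(attn_metadata, vllm_config,
                             afd_metadata=afd_metadata):
        return model(input_ids, positions)


def current_afd_connector():
    # Inside a layer's forward(), the same metadata is globally visible.
    ctx = get_forward_context()
    return ctx.afd_metadata.afd_connector if ctx.afd_metadata else None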
@@ -240,6 +295,7 @@ def set_forward_context( virtual_engine=virtual_engine, attn_metadata=attn_metadata, dp_metadata=dp_metadata, + afd_metadata=afd_metadata, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, ) diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..22e3d09676d0 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,147 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 5 + }, + "1536": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 3 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + } +} + diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..8bac7af0c2da --- /dev/null +++ 
b/vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 4 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json new file mode 100644 index 000000000000..5910027e17f9 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json @@ -0,0 +1,146 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 64, + "num_warps": 4, + "num_stages": 4 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 
128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "96": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "128": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "256": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "512": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 3 + }, + "1024": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 4 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 5 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 5 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 32, + "num_warps": 8, + "num_stages": 5 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 4, + "num_stages": 4 + } +} diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 06edfb0552e8..30e46ffa7b17 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -720,7 +720,10 @@ def get_moe_configs( logger.info("Using configuration from %s for MoE layer.", config_file_path) # If a configuration has been found, return it - return {int(key): val for key, val in json.load(f).items()} + tuned_config = json.load(f) + # Delete triton_version from tuned_config + tuned_config.pop("triton_version", None) + return {int(key): val for key, val in tuned_config.items()} # If no optimized configuration is available, we will use the default # configuration diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index e4a21febc5bd..9868952ba338 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -35,11 +35,14 @@ import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config import (CacheConfig, ParallelConfig, VllmConfig, + get_current_vllm_config) from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) +from vllm.forward_context import get_forward_context 
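# --- Illustrative sketch (editor's addition, not part of this patch) ----
# The tuned MoE config JSONs now carry a "triton_version" entry (written by
# benchmark_moe.py), and get_moe_configs() above drops it before converting
# the remaining batch-size keys to int. A standalone reader would do the
# same; the path argument is a placeholder.
import json


def load_tuned_moe_config(path: str) -> dict[int, dict]:
    with open(path) as f:
        tuned_config = json.load(f)
    tuned_config.pop("triton_version", None)  # metadata, not a batch size
    return {int(key): val for key, val in tuned_config.items()}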
+from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -66,6 +69,11 @@ make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.distributed.afd_transfer.afd_connector.metadata import ( + AFDConnectorMetadata,FFNNeedForwardData) + +logger = init_logger(__name__) + class DeepseekV2MLP(nn.Module): @@ -626,6 +634,8 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config + afd_config = vllm_config.afd_config + self.role = afd_config.afd_role self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -639,41 +649,44 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: attn_cls = DeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV2MoE( + if self.role is None or self.role == "attention": + self.self_attn = attn_cls( config=config, - parallel_config=parallel_config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank if hasattr( + config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, quant_config=quant_config, - prefix=f"{prefix}.mlp", + prefix=f"{prefix}.self_attn", ) + + if self.role is None or self.role == "ffn": + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -687,6 +700,10 @@ def forward( residual: Optional[torch.Tensor], ) -> torch.Tensor: # Self Attention + forward_ctx = get_forward_context() + afd_metadata = (forward_ctx.afd_metadata + if forward_ctx is not None else None) + afd_connector = afd_metadata.afd_connector if 
residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) @@ -711,6 +728,28 @@ def forward( # Fully Connected hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + # ---------ascend ffn need data + if forward_ctx.moe_comm_method_name is not None: + moe_comm_method = forward_ctx.moe_comm_method_name + num_tokens = hidden_states.shape[0] + with_prefill = forward_ctx.with_prefill + + ffn_need_forward_data = FFNNeedForwardData(moe_comm_method,num_tokens,with_prefill) + num_stages = 0 + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=self.layer_idx, + stage_idx=num_stages, + seq_len=hidden_states.shape[0], + dtype=hidden_states.dtype, + device=hidden_states.device, + ffn_need_forward_data=ffn_need_forward_data + ) + else: + metadata = None + if self.role == "attention": + afd_connector.send_attn_output(hidden_states, metadata) + hidden_states = afd_connector.recv_ffn_output(None) + return hidden_states, residual hidden_states = self.mlp(hidden_states) if isinstance(self.mlp, @@ -724,6 +763,52 @@ def forward( return hidden_states, residual + def compute_attn_output( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + ) -> torch.Tensor: # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + ) + + if hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # We scale both hidden_states and residual before + # rmsnorm, and rmsnorm result would not affect by scale. + hidden_states *= 1. / self.routed_scaling_factor + if self.layer_idx == 0: + # The residual is shared by all layers, we only scale it on + # first layer. + residual *= 1. / self.routed_scaling_factor + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + + return hidden_states, residual + + def compute_ffn_output(self, hidden_states): + assert self.role == "ffn" + hidden_states = self.mlp(hidden_states) + if isinstance(self.mlp, + DeepseekV2MLP) and hidden_states.dtype == torch.float16: + # Fix FP16 overflow + # Scaling the DeepseekV2MLP output, it is the input of + # input_layernorm of next decoder layer. + # The scaling of DeepseekV2MOE output would be done in the forward + # of DeepseekV2MOE + hidden_states *= 1. 
/ self.routed_scaling_factor + return hidden_states + @support_torch_compile class DeepseekV2Model(nn.Module): @@ -782,8 +867,29 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + forward_ctx = get_forward_context() + afd_metadata = (forward_ctx.afd_metadata + if forward_ctx is not None else None) + for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + if (afd_metadata is not None + and isinstance(afd_metadata.afd_tokens_start_loc, list) + and len(afd_metadata.afd_tokens_start_loc) - 1 > 1): + num_stages = len(afd_metadata.afd_tokens_start_loc) - 1 + stage_hidden_states: list[torch.Tensor] = [] + stage_residual: list[Optional[torch.Tensor]] = [] + stage_positions: list[torch.Tensor] = [] + for stage_idx in range(num_stages): + start = afd_metadata.afd_tokens_start_loc[stage_idx] + end = start + afd_metadata.afd_tokens_lens[stage_idx] + stage_hidden_states.append( + hidden_states[start:end].clone()) + stage_residual.append(residual[start:end].clone( + ) if residual is not None else None) + stage_positions.append(positions[start:end]) + else: + hidden_states, residual = layer(positions, hidden_states, + residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ @@ -794,6 +900,13 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states + def compute_ffn_output( + self, hidden_states, + layer_idx) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.layers[layer_idx].compute_ffn_output( + hidden_states) + return hidden_states + class DeepseekV2ForCausalLM(nn.Module, SupportsPP, MixtureOfExperts, SupportsLoRA): @@ -807,7 +920,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - + self.afd_config = vllm_config.afd_config # `packed_modules_mapping` needs to be modified before # initializing DeepseekV2Model, as it is passed inplace to # quantization config init and may be used to select the @@ -845,11 +958,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): continue assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): + if (self.afd_config.afd_role is None or self.afd_config.afd_role + == "ffn") and isinstance(layer.mlp, DeepseekV2MoE): # Pick last one layer since the first ones may be dense layers. 
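# --- Illustrative sketch (editor's addition, not part of this patch) ----
# How the stage bookkeeping above carves a flat token batch into per-stage
# slices: stage i covers tokens [start_loc[i], start_loc[i] + lens[i]).
# The sizes below are made up for a two-stage batch of 8 tokens per stage.
import torch

afd_tokens_start_loc = [0, 8, 16]   # per-stage start offsets
afd_tokens_lens = [8, 8]            # padded per-stage lengths

hidden_states = torch.randn(16, 4)  # [total tokens, hidden_size]
stage_hidden_states = []
for stage_idx in range(len(afd_tokens_start_loc) - 1):
    start = afd_tokens_start_loc[stage_idx]
    end = start + afd_tokens_lens[stage_idx]
    stage_hidden_states.append(hidden_states[start:end].clone())
# Each slice is handed to the FFN side independently, which is what allows
# attention for one stage to overlap with FFN compute for another.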
example_moe = layer.mlp self.moe_layers.append(layer.mlp.experts) + if self.afd_config.afd_role == "attention": + return if example_moe is None: raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") @@ -908,6 +1024,13 @@ def forward( inputs_embeds) return hidden_states + def compute_ffn_output( + self, current_layer_idx, + hidden_states) -> Union[torch.Tensor, IntermediateTensors]: + hidden_states = self.model.compute_ffn_output(hidden_states, + current_layer_idx) + return hidden_states + def compute_logits( self, hidden_states: torch.Tensor, @@ -929,19 +1052,27 @@ def load_weights(self, weights: Iterable[tuple[str, # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) + if self.afd_config.afd_role == "attention": + vllm_config = get_current_vllm_config() + num_redundant_experts = ( + vllm_config.parallel_config.eplb_config.num_redundant_experts) + else: + num_redundant_experts = self.num_redundant_experts expert_params_mapping = FusedMoE.make_expert_params_mapping( ckpt_gate_proj_name="gate_proj", ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.n_routed_experts, - num_redundant_experts=self.num_redundant_experts) + num_redundant_experts=num_redundant_experts) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue - + if self.afd_config.afd_role == "attention" and self.is_moe_weight( + name): + continue spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) if spec_layer is not None: continue # skip spec decode layers for main model @@ -989,7 +1120,9 @@ def load_weights(self, weights: Iterable[tuple[str, # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True - + if (self.afd_config.afd_role is not None + and self.afd_config.afd_role == "attention"): + continue # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) @@ -1013,6 +1146,12 @@ def load_weights(self, weights: Iterable[tuple[str, name = name_mapped break else: + if ( + self.afd_config.afd_role == "ffn" + and not self.is_moe_weight(name) + and not self.is_common_weight(name) + ): + continue if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank @@ -1039,6 +1178,18 @@ def load_weights(self, weights: Iterable[tuple[str, return loaded_params + def is_moe_weight(self, name): + return bool("shared_experts" in name or "experts" in name + or "gate" in name or "up" in name or "down" in name) + + def is_common_weight(self, name): + if ("lm_head" in name or "model.norm.weight" in name + or "embed_tokens" in name or "input_layernorm" in name + or "post_attention_layernorm" in name): + # or "model.layers.0.self_attn.o_proj.weight" in name:# for init kv cache + return True + return False + class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM): pass diff --git a/vllm/model_executor/models/step3_text.py b/vllm/model_executor/models/step3_text.py index 97611d3e140e..1f02c4cc34a8 100644 --- a/vllm/model_executor/models/step3_text.py +++ b/vllm/model_executor/models/step3_text.py @@ -10,10 +10,13 @@ from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.config import AFDConfig, CacheConfig, ModelConfig, VllmConfig 
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) +from vllm.distributed.afd_transfer.afd_connector.metadata import ( + AFDConnectorMetadata) +from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import FusedMoE @@ -206,11 +209,13 @@ def __init__(self, config: ModelConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + afd_config: Optional[AFDConfig] = None, prefix: str = "") -> None: super().__init__() config = config.hf_config self.hidden_size = config.hidden_size rope_scaling = getattr(config, "rope_scaling", None) + self.layer_idx = int(prefix.split("layers.")[1].split(".")[0]) self.self_attn = Step3TextAttention( hidden_size=self.hidden_size, @@ -226,17 +231,15 @@ def __init__(self, rope_scaling=rope_scaling, prefix=f"{prefix}.self_attn") - layer_idx = int(prefix.split("layers.")[1].split(".")[0]) moe_layers_enum = getattr(config, "moe_layers_enum", None) if moe_layers_enum is not None: moe_layers_idx = [ int(i) for i in moe_layers_enum.strip().split(',') ] else: - # Default to 1dense. moe_layers_idx = [i for i in range(1, config.num_hidden_layers)] - if layer_idx in moe_layers_idx: + if self.layer_idx in moe_layers_idx: self.moe = FusedMoEBlock(config=config, quant_config=quant_config, prefix=f"{prefix}.moe") @@ -259,33 +262,151 @@ def __init__(self, self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - def forward( - self, positions: torch.Tensor, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] - ) -> tuple[torch.Tensor, torch.Tensor]: + self.graph_capture_active: bool = False + self.should_capture_graph: bool = (afd_config + and afd_config.is_attention_server) + self.graph_attn_runners_by_stage: dict[int, dict[ + int, tuple[torch.cuda.CUDAGraph, torch.Tensor, torch.Tensor, + Optional[torch.Tensor], torch.Tensor, + Optional[torch.Tensor]]]] = {} + self.graph_capture_sizes: list[int] = [] + + def _capture_cuda_graph_for_size(self, *, stage_idx: int, num_tokens: int, + device: torch.device, + hs_dtype: torch.dtype, + pos_dtype: torch.dtype) -> None: + if not self.graph_capture_active: + return + stage_graphs = self.graph_attn_runners_by_stage.setdefault( + stage_idx, {}) + if num_tokens in stage_graphs: + return + + stream = torch.cuda.Stream() + stream.wait_stream(torch.cuda.current_stream()) + + graph = torch.cuda.CUDAGraph() + with torch.cuda.stream(stream): + static_positions = torch.zeros(num_tokens, + dtype=pos_dtype, + device=device) + static_hidden_states = torch.empty((num_tokens, self.hidden_size), + dtype=hs_dtype, + device=device) + static_residual = torch.empty( + (num_tokens, self.hidden_size), dtype=hs_dtype, + device=device) if self.layer_idx > 0 else None + + self._compute_attn_output(static_hidden_states, static_residual, + static_positions) + + with torch.cuda.graph(graph, stream=stream): + static_hs_out, static_residual_out = self._compute_attn_output( + static_hidden_states, static_residual, static_positions) + + torch.cuda.current_stream().wait_stream(stream) + stage_graphs[num_tokens] = (graph, static_positions, + static_hidden_states, static_residual, + static_hs_out, static_residual_out) + if num_tokens not in self.graph_capture_sizes: + self.graph_capture_sizes.append(num_tokens) + self.graph_capture_sizes.sort() + + def _ensure_graph_for_size(self, *, 
stage_idx: int, size: int, + device: torch.device, hs_dtype: torch.dtype, + pos_dtype: torch.dtype) -> None: + if not self.graph_capture_active: + return + stage_graphs = self.graph_attn_runners_by_stage.get(stage_idx) + if stage_graphs is None or size not in stage_graphs: + self._capture_cuda_graph_for_size(stage_idx=stage_idx, + num_tokens=size, + device=device, + hs_dtype=hs_dtype, + pos_dtype=pos_dtype) + + def compute_ffn_output(self, hidden_states: torch.Tensor) -> torch.Tensor: + if self.use_moe: + share_output = self.share_expert(hidden_states) + moe_output = self.moe(hidden_states) + return share_output + moe_output + return self.mlp(hidden_states) + + def _compute_attn_output( + self, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: if residual is None: - residual = hidden_states + residual = hidden_states.clone() hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - ) - + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states) hidden_states, residual = self.post_attention_layernorm( hidden_states, residual) + return hidden_states, residual - if self.use_moe: - share_output = self.share_expert(hidden_states) - moe_output = self.moe(hidden_states) - hidden_states = share_output + moe_output + def compute_attn_output( + self, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + positions: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if not self.should_capture_graph: + return self._compute_attn_output(hidden_states, residual, + positions) + + device = hidden_states.device + hs_dtype = hidden_states.dtype + pos_dtype = positions.dtype + num_tokens = hidden_states.shape[0] + afd_stage_idx = 0 + forward_ctx = get_forward_context() + if forward_ctx.afd_metadata is not None: + afd_stage_idx = forward_ctx.afd_metadata.afd_stage_idx + + self._ensure_graph_for_size(stage_idx=afd_stage_idx, + size=num_tokens, + device=device, + hs_dtype=hs_dtype, + pos_dtype=pos_dtype) + + stage_graphs = self.graph_attn_runners_by_stage.get(afd_stage_idx, {}) + chosen_size = None + for size in self.graph_capture_sizes: + if size >= num_tokens and size in stage_graphs: + chosen_size = size + break + + if chosen_size is None: + return self._compute_attn_output(hidden_states, residual, + positions) + + (graph, static_positions, static_hidden_states, static_residual, + static_hs_out, static_residual_out) = stage_graphs[chosen_size] + + static_positions[:num_tokens].copy_(positions) + static_hidden_states[:num_tokens].copy_(hidden_states) + if residual is not None and static_residual is not None: + static_residual[:num_tokens].copy_(residual) + graph.replay() + + out_hidden = static_hs_out[:num_tokens].clone() + if static_residual_out is not None: + out_residual = static_residual_out[:num_tokens].clone() else: - hidden_states = self.mlp(hidden_states) + out_residual = out_hidden.clone() + return out_hidden, out_residual - return hidden_states, residual + def forward( + self, positions: torch.Tensor, hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] + ) -> tuple[torch.Tensor, torch.Tensor]: + hidden_states, residual = self.compute_attn_output( + hidden_states, residual, positions) + ffn_output = self.compute_ffn_output(hidden_states) + return ffn_output, residual @support_torch_compile @@ -310,11 +431,12 @@ 
def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Step3TextDecoderLayer(config=vllm_config. - model_config, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix), + lambda prefix: Step3TextDecoderLayer( + config=vllm_config.model_config, + cache_config=cache_config, + quant_config=quant_config, + afd_config=vllm_config.afd_config, + prefix=prefix), prefix=f"{prefix}.layers", ) if get_pp_group().is_last_rank: @@ -326,6 +448,17 @@ def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: make_empty_intermediate_tensors_factory(["hidden_states"], config.hidden_size)) + def set_graph_capture_mode(self, enabled: bool) -> None: + for idx in range(self.start_layer, self.end_layer): + layer = self.layers[idx] + if hasattr(layer, "graph_capture_active"): + layer.graph_capture_active = enabled + + def compute_ffn_output(self, layer_idx: int, + hidden_states: torch.Tensor) -> torch.Tensor: + layer = self.layers[layer_idx] + return layer.compute_ffn_output(hidden_states) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -347,8 +480,80 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - for layer in islice(self.layers, self.start_layer, self.end_layer): - hidden_states, residual = layer(positions, hidden_states, residual) + forward_ctx = get_forward_context() + afd_metadata = (forward_ctx.afd_metadata + if forward_ctx is not None else None) + + if afd_metadata is not None: + assert residual is None, "PP is not supported with AFD" + num_stages = len(afd_metadata.afd_tokens_start_loc) - 1 + afd_connector = afd_metadata.afd_connector + + stage_hidden_states: list[torch.Tensor] = [] + stage_residual: list[Optional[torch.Tensor]] = [] + stage_positions: list[torch.Tensor] = [] + + for stage_idx in range(num_stages): + start = afd_metadata.afd_tokens_start_loc[stage_idx] + end = start + afd_metadata.afd_tokens_lens[stage_idx] + stage_hidden_states.append(hidden_states[start:end].clone()) + stage_residual.append(residual[start:end].clone( + ) if residual is not None else None) + stage_positions.append(positions[start:end]) + + for layer_idx in range(self.start_layer, self.end_layer): + layer = self.layers[layer_idx] + + for stage_idx in range(num_stages): + afd_metadata.afd_stage_idx = stage_idx + + if layer_idx > 0: + stage_hidden_states[stage_idx].copy_( + afd_connector.recv_ffn_output()) + + current_hidden = stage_hidden_states[stage_idx] + current_residual = stage_residual[stage_idx] + current_positions = stage_positions[stage_idx] + + current_hidden, current_residual = \ + layer.compute_attn_output( + current_hidden, current_residual, + current_positions) + + metadata = AFDConnectorMetadata.create_attention_metadata( + layer_idx=layer_idx, + stage_idx=stage_idx, + seq_len=current_hidden.shape[0], + dtype=current_hidden.dtype, + device=current_hidden.device, + ) + afd_connector.send_attn_output(current_hidden, metadata) + stage_residual[stage_idx] = current_residual + + for stage_idx in range(num_stages): + recv_hidden = afd_connector.recv_ffn_output() + stage_hidden_states[stage_idx].copy_(recv_hidden) + + hidden_states = torch.cat([ + stage_hidden_states[i][:afd_metadata.afd_tokens_lens[i]] + for i in range(num_stages) + ], + dim=0) + + if stage_residual[0] is not None: + residual = torch.cat([ + 
stage_residual[i][:afd_metadata.afd_tokens_lens[i]] + if stage_residual[i] is not None else + stage_hidden_states[i][:afd_metadata.afd_tokens_lens[i]] + for i in range(num_stages) + ], + dim=0) + else: + residual = None + else: + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, + residual) if not get_pp_group().is_last_rank: return IntermediateTensors({ diff --git a/vllm/model_executor/models/step3_vl.py b/vllm/model_executor/models/step3_vl.py index 2ba5f94ea3b8..512d3fd79a29 100644 --- a/vllm/model_executor/models/step3_vl.py +++ b/vllm/model_executor/models/step3_vl.py @@ -1036,6 +1036,14 @@ def get_input_embeddings( self.config.image_token_id) return inputs_embeds + def set_graph_capture_mode(self, enabled: bool) -> None: + self.language_model.model.set_graph_capture_mode(enabled) + + def compute_ffn_output(self, layer_idx: int, + hidden_states: torch.Tensor) -> torch.Tensor: + return self.language_model.model.compute_ffn_output( + layer_idx, hidden_states) + def forward( self, input_ids: torch.Tensor, diff --git a/vllm/triton_utils/importing.py b/vllm/triton_utils/importing.py index 372200027bf9..2a06a9b7d11e 100644 --- a/vllm/triton_utils/importing.py +++ b/vllm/triton_utils/importing.py @@ -68,7 +68,7 @@ class TritonPlaceholder(types.ModuleType): def __init__(self): super().__init__("triton") - self.__version__ = "3.3.0" + self.__version__ = "3.4.0" self.jit = self._dummy_decorator("jit") self.autotune = self._dummy_decorator("autotune") self.heuristics = self._dummy_decorator("heuristics") diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 20f1904b3be6..d76378dc0998 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project """Attention layer with FlashAttention.""" from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING, Optional import numpy as np import torch @@ -31,6 +31,9 @@ get_kv_cache_layout) from vllm.v1.kv_cache_interface import AttentionSpec +if TYPE_CHECKING: + from vllm.forward_context import AFDMetadata + logger = init_logger(__name__) # NOTE(woosuk): This is an arbitrary number. Tune it if needed. @@ -182,6 +185,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], self.parallel_config = vllm_config.parallel_config self.cache_config = vllm_config.cache_config self.compilation_config = vllm_config.compilation_config + self.afd_config = vllm_config.afd_config self.num_heads_q = self.model_config.get_num_attention_heads( self.parallel_config) @@ -221,10 +225,25 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], # populated on first build() call. 
self.aot_sliding_window: Optional[tuple[int, int]] = None + # Initialize stage buffers for AFD + self._stage_buffers: dict[int, dict[str, torch.Tensor]] = {} + self._init_stage_buffers(vllm_config, self.aot_schedule) + def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, + afd_metadata: Optional["AFDMetadata"] = None, fast_build: bool = False) -> FlashAttentionMetadata: + if afd_metadata is not None: + return self._build_with_afd(common_attn_metadata, afd_metadata) + else: + return self._build(common_prefix_len, common_attn_metadata, + fast_build) + + def _build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashAttentionMetadata: """ fast_build disables AOT scheduling, used when there will be few iterations i.e. spec-decode @@ -360,9 +379,222 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens, causal=causal) return attn_metadata + def _init_stage_buffers(self, vllm_config: VllmConfig, + aot_schedule: bool) -> dict[str, torch.Tensor]: + if vllm_config.afd_config: + num_stages = vllm_config.afd_config.num_afd_stages + max_num_seqs = vllm_config.scheduler_config.max_num_seqs + max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens + stage_max_num_reqs = max_num_seqs // num_stages + stage_max_num_tokens = max_num_tokens // num_stages + max_model_len = vllm_config.model_config.max_model_len + num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0) + dcp_world_size = ( + vllm_config.parallel_config.decode_context_parallel_size) + stage_max_num_blocks = max( + cdiv(max_model_len, self.block_size * dcp_world_size), + 1 + num_speculative_tokens) + + for stage_idx in range(num_stages): + self._stage_buffers[stage_idx] = { + 'query_start_loc': + torch.zeros(stage_max_num_reqs + 1, + dtype=torch.int32, + device=self.device), + 'seq_lens': + torch.zeros(stage_max_num_reqs, + dtype=torch.int32, + device=self.device), + 'block_table': + torch.zeros(stage_max_num_reqs, + stage_max_num_blocks, + dtype=torch.int32, + device=self.device), + 'slot_mapping': + torch.zeros(stage_max_num_tokens, + dtype=torch.long, + device=self.device), + } + if aot_schedule: + self._stage_buffers[stage_idx][ + 'scheduler_metadata'] = torch.zeros( + stage_max_num_reqs + 1, + dtype=torch.int32, + device=self.device) + def use_cascade_attention(self, *args, **kwargs) -> bool: return use_cascade_attention(*args, **kwargs) + def _build_with_afd( + self, + common_attn_metadata: CommonAttentionMetadata, + afd_metadata: "AFDMetadata", + ) -> list[Optional[FlashAttentionMetadata]]: + """Split the metadata per AFD stage.""" + + query_start_loc = common_attn_metadata.query_start_loc + seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu + block_table_tensor = common_attn_metadata.block_table_tensor + slot_mapping = common_attn_metadata.slot_mapping + max_query_len = common_attn_metadata.max_query_len + causal = common_attn_metadata.causal + + afd_reqs_start_loc = afd_metadata.afd_reqs_start_loc + afd_tokens_start_loc = afd_metadata.afd_tokens_start_loc + afd_tokens_lens = afd_metadata.afd_tokens_lens + num_stages = len(afd_reqs_start_loc) - 1 + + aot_schedule = self.aot_schedule + + if self.aot_sliding_window is None: + self.aot_sliding_window = (-1, -1) + # For the AOT scheduler we need the sliding window value to be + # constant for all layers to. 
We have to populate this on the first + # build() call so the layers are constructed (cannot populate) + # in __init__. + if aot_schedule: + sliding_window_configs = _get_sliding_window_configs( + self.vllm_config) + if len(sliding_window_configs) == 1: + sliding_window_config = sliding_window_configs.pop() + if sliding_window_config is not None: + self.aot_sliding_window = sliding_window_config + elif len(sliding_window_configs) > 1: + self.aot_schedule = False + aot_schedule = False + + def schedule(batch_size, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + cache_dtype = self.cache_config.cache_dtype + if cache_dtype.startswith("fp8"): + qkv_dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn( + cache_dtype) + else: + qkv_dtype = self.kv_cache_dtype + if aot_schedule: + return get_scheduler_metadata( + batch_size=batch_size, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads_q, + num_heads_kv=self.num_heads_kv, + headdim=self.headdim, + cache_seqlens=seqlens, + qkv_dtype=qkv_dtype, + cu_seqlens_q=cu_query_lens, + page_size=self.block_size, + causal=causal, + window_size=self.aot_sliding_window, + num_splits=self.max_num_splits, + ) + return None + + stage_metadatas: list[Optional[FlashAttentionMetadata]] = [] + + for stage_idx in range(num_stages): + stage_start_req = afd_reqs_start_loc[stage_idx] + stage_end_req = afd_reqs_start_loc[stage_idx + 1] + stage_start_token = afd_tokens_start_loc[stage_idx] + stage_end_token = afd_tokens_start_loc[ + stage_idx] + afd_tokens_lens[stage_idx] + + stage_num_reqs = stage_end_req - stage_start_req + stage_num_actual_tokens = stage_end_token - stage_start_token + stage_max_seq_len = int( + seq_lens_cpu[stage_start_req:stage_end_req].max()) + + stage_max_query_len = min(max_query_len, stage_num_actual_tokens) + + if stage_num_actual_tokens == 0 or stage_num_reqs == 0: + stage_metadatas.append(None) + continue + + if self.use_full_cuda_graph: + stage_buffer = self._stage_buffers[stage_idx] + + stage_query_start_loc_data = query_start_loc[ + stage_start_req:stage_end_req + + 1] - query_start_loc[stage_start_req] + stage_buffer['query_start_loc'][:stage_num_reqs + 1].copy_( + stage_query_start_loc_data) + stage_buffer['query_start_loc'][stage_num_reqs + 1:].fill_( + stage_query_start_loc_data[stage_num_reqs].item()) + stage_query_start_loc = stage_buffer[ + 'query_start_loc'][:stage_num_reqs + 1] + + stage_seq_lens_data = seq_lens[stage_start_req:stage_end_req] + stage_buffer['seq_lens'][:stage_num_reqs].copy_( + stage_seq_lens_data) + stage_buffer['seq_lens'][stage_num_reqs:].fill_(0) + stage_seq_lens = stage_buffer['seq_lens'][:stage_num_reqs] + + stage_block_table_data = block_table_tensor[ + stage_start_req:stage_end_req] + stage_buffer['block_table'][:stage_num_reqs].copy_( + stage_block_table_data) + stage_block_table_tensor = stage_buffer[ + 'block_table'][:stage_num_reqs] + + stage_slot_mapping_data = slot_mapping[ + stage_start_token:stage_end_token] + stage_buffer['slot_mapping'][:stage_num_actual_tokens].copy_( + stage_slot_mapping_data) + stage_slot_mapping = stage_buffer[ + 'slot_mapping'][:stage_num_actual_tokens] + else: + stage_query_start_loc = query_start_loc[ + stage_start_req:stage_end_req + + 1] - query_start_loc[stage_start_req] + stage_seq_lens = seq_lens[stage_start_req:stage_end_req] + stage_block_table_tensor = block_table_tensor[ + stage_start_req:stage_end_req] + stage_slot_mapping = slot_mapping[ + stage_start_token:stage_end_token] + + if aot_schedule: + 
stage_scheduler_metadata = schedule( + batch_size=stage_num_reqs, + cu_query_lens=stage_query_start_loc, + max_query_len=stage_max_query_len, + seqlens=stage_seq_lens, + max_seq_len=stage_max_seq_len, + causal=causal) + max_num_splits = 0 + if self.use_full_cuda_graph: + stage_buffer = self._stage_buffers[stage_idx] + n = stage_scheduler_metadata.shape[0] + stage_buffer['scheduler_metadata'][:n].copy_( + stage_scheduler_metadata) + stage_buffer['scheduler_metadata'][n:].fill_(0) + stage_scheduler_metadata = stage_buffer[ + 'scheduler_metadata'][:n] + if stage_num_actual_tokens <= self.max_cudagraph_size: + max_num_splits = self.max_num_splits + + stage_attn_metadata = FlashAttentionMetadata( + num_actual_tokens=stage_num_actual_tokens, + max_query_len=stage_max_query_len, + query_start_loc=stage_query_start_loc, + max_seq_len=stage_max_seq_len, + seq_lens=stage_seq_lens, + block_table=stage_block_table_tensor, + slot_mapping=stage_slot_mapping, + use_cascade=False, + common_prefix_len=0, + scheduler_metadata=stage_scheduler_metadata, + cu_prefix_query_lens=None, + prefix_kv_lens=None, + suffix_kv_lens=None, + prefix_scheduler_metadata=None, + max_num_splits=max_num_splits, + causal=causal) + + stage_metadatas.append(stage_attn_metadata) + return stage_metadatas + class FlashAttentionImpl(AttentionImpl): diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index ead70c910a8f..3fcd31807c3a 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionImpl + from vllm.forward_context import AFDMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.worker.gpu_input_batch import InputBatch @@ -209,7 +210,8 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> M: + fast_build: bool = False, + **kwargs) -> M: """ Central method that builds attention metadata. Some builders (MLA) require reorder_batch to be called prior to build. @@ -241,14 +243,17 @@ def reorder_batch(self, input_batch: "InputBatch", raise NotImplementedError def build_for_cudagraph_capture( - self, common_attn_metadata: CommonAttentionMetadata) -> M: + self, + common_attn_metadata: CommonAttentionMetadata, + afd_metadata: Optional["AFDMetadata"] = None) -> M: """ Build attention metadata for CUDA graph capture. Uses build by default. Subclasses that override this method should call self.build or super().build_for_cudagraph_capture. 
""" return self.build(common_prefix_len=0, - common_attn_metadata=common_attn_metadata) + common_attn_metadata=common_attn_metadata, + afd_metadata=afd_metadata) def build_for_drafting( self, @@ -803,11 +808,13 @@ class FastPrefillAttentionBuilder(underlying_builder): # type: ignore def build(self, common_prefix_len: int, common_attn_metadata: CommonAttentionMetadata, - fast_build: bool = False) -> AttentionMetadata: + fast_build: bool = False, + **kwargs) -> AttentionMetadata: new_common_attn_metadata =\ make_kv_sharing_fast_prefill_common_attn_metadata(common_attn_metadata) metadata = super().build(common_prefix_len, - new_common_attn_metadata, fast_build) + new_common_attn_metadata, fast_build, + **kwargs) class KVSharingFastPrefillAttentionMetadata( metadata.__class__, # type: ignore diff --git a/vllm/v1/worker/gpu_ffn_model_runner.py b/vllm/v1/worker/gpu_ffn_model_runner.py new file mode 100644 index 000000000000..edb3fc68b00b --- /dev/null +++ b/vllm/v1/worker/gpu_ffn_model_runner.py @@ -0,0 +1,399 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import time +from typing import TYPE_CHECKING, Any, Optional + +import torch +import torch.nn as nn + +from vllm.config import VllmConfig +from vllm.distributed.afd_transfer.afd_connector.factory import ( + AFDConnectorFactory) +from vllm.distributed.communication_op import tensor_model_parallel_all_gather +from vllm.distributed.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_world_group, graph_capture) +from vllm.forward_context import set_forward_context +from vllm.logger import init_logger +from vllm.model_executor.model_loader import get_model_loader +from vllm.utils import DeviceMemoryProfiler, GiB_bytes +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin + +if TYPE_CHECKING: + from vllm.model_executor.model_loader.tensorizer import TensorizerConfig + from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec + +logger = init_logger(__name__) + + +class GPUFFNModelRunner(LoRAModelRunnerMixin): + + def __init__(self, vllm_config: VllmConfig, device: torch.device): + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.dtype = self.model_config.dtype + self.load_config = vllm_config.load_config + + self.afd_config = vllm_config.afd_config + if not self.afd_config or not self.afd_config.is_ffn_server: + raise ValueError( + "AFD config must be provided with afd_role='ffn' for FFN server" + ) + + self._counter = 0 + + # Initialize CUDA graph support + self.use_cuda_graph = not self.model_config.enforce_eager + + # self.cudagraph_batch_sizes sorts in ascending order. + # The batch sizes in the config are in descending order. 
+ self.cudagraph_batch_sizes = list( + reversed( + self.vllm_config.compilation_config.cudagraph_capture_sizes)) + + # Storage for captured graphs + self._cuda_graphs: dict[tuple[int, int], torch.cuda.CUDAGraph] = { + } # {(layer_idx, num_tokens): CUDAGraph} + self._graph_memory_pool = None + + assert self.afd_config.is_ffn_server + self.connector = AFDConnectorFactory.create_connector( + get_world_group().rank, + get_world_group().local_rank, self.vllm_config) + + if getattr(self.model_config.hf_config, "text_config", + None) is not None: + self.num_layers = ( + self.model_config.hf_config.text_config.num_hidden_layers) + else: + self.num_layers = self.model_config.hf_config.num_hidden_layers + + def get_model(self) -> nn.Module: + return self.model + + def initialize_afd_connector(self) -> None: + self.connector.init_afd_connector() + + def load_model(self, **kwargs) -> None: + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: # noqa: SIM117 + time_before_load = time.perf_counter() + model_loader = get_model_loader(self.load_config) + + if not hasattr(self, "model"): + logger.info("Loading model from scratch...") + self.model = model_loader.load_model( + vllm_config=self.vllm_config, + model_config=self.model_config) + else: + logger.info( + "Model was already initialized. Loading weights inplace..." + ) + model_loader.load_weights(self.model, + model_config=self.model_config) + time_after_load = time.perf_counter() + self.model_memory_usage = m.consumed_memory + logger.info("Model loading took %.4f GiB and %.6f seconds", + self.model_memory_usage / GiB_bytes, + time_after_load - time_before_load) + + logger.info("AFD FFN Model loaded successfully") + + def _get_current_layer_idx(self) -> int: + return (self._counter // + self.afd_config.num_afd_stages) % self.num_layers + + @torch.inference_mode() + def execute_model(self, scheduler_output=None, intermediate_tensors=None): + """Execute FFN computation for a single request""" + # scheduler_output and intermediate_tensors are unused in FFN server + # mode + current_layer_idx = self._get_current_layer_idx() + try: + hidden_states, metadata = self.connector.recv_attn_output() + num_tokens = hidden_states.shape[0] + + # Try to use CUDA graph if available + cuda_graph_info = self._find_cuda_graph(current_layer_idx, + num_tokens) + if cuda_graph_info is not None: + # Use captured CUDA graph for computation + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config): + rank_ffn_output = self._execute_with_cuda_graph( + hidden_states, cuda_graph_info) + else: + # Fallback to eager mode + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config): + rank_ffn_output = self._execute_eager_mode( + hidden_states, current_layer_idx) + + self.connector.send_ffn_output(rank_ffn_output, metadata) + except Exception as e: + raise ValueError( + f"Error computing FFN for layer {current_layer_idx}: {e}" + ) from e + finally: + self._counter += 1 + if (self._counter == self.num_layers * + self.afd_config.num_afd_stages): + self._counter = 0 + return None # FFN server doesn't return ModelRunnerOutput + + def _execute_with_cuda_graph(self, hidden_states: torch.Tensor, + cuda_graph_info: dict): + """Execute FFN computation using captured CUDA graph.""" + graph = cuda_graph_info['graph'] + input_tensor = cuda_graph_info['input_hidden_states'] + output_tensor = cuda_graph_info['output'] + + # Copy input data to graph's input tensor + # Handle padding if necessary + 
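# --- Illustrative sketch (editor's addition, not part of this patch) ----
# The FFN server keeps no request state; it only counts incoming tensors and
# derives the decoder layer from that counter, exactly as
# _get_current_layer_idx() does above. Worked example with made-up sizes:
num_afd_stages = 2
num_layers = 4
for counter in range(num_layers * num_afd_stages):
    layer_idx = (counter // num_afd_stages) % num_layers
    # counters 0,1 -> layer 0; 2,3 -> layer 1; ...; 6,7 -> layer 3,
    # after which the counter wraps back to 0 for the next forward pass.
    print(counter, layer_idx)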
actual_tokens = hidden_states.shape[0] + graph_tokens = input_tensor.shape[0] + + if actual_tokens <= graph_tokens: + # Copy actual data and pad with zeros if needed + input_tensor[:actual_tokens].copy_(hidden_states) + if actual_tokens < graph_tokens: + input_tensor[actual_tokens:].zero_() + else: + raise ValueError( + f"Input size {actual_tokens} exceeds graph capacity " + f"{graph_tokens}") + + # Replay the captured graph + graph.replay() + + # Return only the actual output (without padding) + return output_tensor[:actual_tokens].clone() + + def _execute_eager_mode(self, hidden_states: torch.Tensor, + current_layer_idx: int): + """Execute FFN computation in eager mode (fallback).""" + # Handle TP case: all-gather tensors from all TP ranks + tp_world_size = get_tensor_model_parallel_world_size() + if tp_world_size > 1: + # All-gather hidden states from all TP ranks + gathered_hidden_states = tensor_model_parallel_all_gather( + hidden_states, dim=0) + ffn_output = self.model.compute_ffn_output(current_layer_idx, + gathered_hidden_states) + # Extract the output corresponding to current rank + start_idx = hidden_states.shape[ + 0] * get_tensor_model_parallel_rank() + end_idx = start_idx + hidden_states.shape[0] + rank_ffn_output = ffn_output[start_idx:end_idx, :] + else: + # Single TP case + rank_ffn_output = self.model.compute_ffn_output( + current_layer_idx, hidden_states) + + return rank_ffn_output + + # Methods required for interface compatibility with GPUModelRunner + def profile_run(self) -> None: + """FFN servers don't need profiling.""" + pass + + def get_kv_cache_spec(self) -> dict[str, "KVCacheSpec"]: + """FFN servers don't use KV cache.""" + return {} + + def initialize_kv_cache(self, kv_cache_config: "KVCacheConfig") -> None: + """FFN servers don't use KV cache.""" + pass + + def _dummy_run(self, num_tokens: int = 1, **kwargs) -> torch.Tensor: + """FFN servers don't need dummy runs.""" + # Return a dummy tensor for interface compatibility + return torch.zeros(num_tokens, + self.model_config.hf_config.hidden_size, + dtype=self.dtype, + device=self.device) + + def capture_model(self) -> int: + """Capture CUDA graphs for FFN operations.""" + if not self.use_cuda_graph: + logger.warning("Skipping CUDA graph capture.") + return 0 + + logger.info("Starting CUDA graph capture for FFN operations...") + start_time = time.perf_counter() + start_free_gpu_memory = torch.cuda.mem_get_info()[0] + + # Create memory pool for graphs + if self._graph_memory_pool is None: + self._graph_memory_pool = torch.cuda.graph_pool_handle() + + # Capture graphs for each layer and different batch sizes + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. 
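+        # One graph is recorded per (layer_idx, num_tokens) pair, so the total
+        # number of captured graphs is num_layers * len(cudagraph_batch_sizes).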
+ with graph_capture(device=self.device): + for layer_idx in range(self.num_layers): + for num_tokens in reversed(self.cudagraph_batch_sizes): + with set_forward_context(attn_metadata=None, + vllm_config=self.vllm_config): + self._capture_graph_for_layer_and_size( + layer_idx, num_tokens) + + end_time = time.perf_counter() + end_free_gpu_memory = torch.cuda.mem_get_info()[0] + elapsed_time = end_time - start_time + cuda_graph_size = start_free_gpu_memory - end_free_gpu_memory + + logger.info( + "FFN CUDA graph capturing finished in %.0f secs, took %.2f GiB", + elapsed_time, cuda_graph_size / (1 << 30)) + return cuda_graph_size + + def _capture_graph_for_layer_and_size(self, layer_idx: int, + num_tokens: int): + """Capture CUDA graph for specific layer and number of tokens.""" + # Create dummy hidden states + dummy_hidden_states = torch.randn( + num_tokens, + self.model_config.hf_config.hidden_size, + dtype=self.dtype, + device=self.device) + + # Warm up the operations for this specific layer + for _ in range( + self.vllm_config.compilation_config.cudagraph_num_of_warmups): + self._run_ffn_computation(dummy_hidden_states, + layer_idx=layer_idx, + capture_mode=True) + + # Create and capture the graph + graph = torch.cuda.CUDAGraph() + + # Start graph capture + with torch.cuda.graph(graph, pool=self._graph_memory_pool): + output = self._run_ffn_computation(dummy_hidden_states, + layer_idx=layer_idx, + capture_mode=True) + + # Store the captured graph with layer and token count as key + self._cuda_graphs[(layer_idx, num_tokens)] = { + 'graph': graph, + 'input_hidden_states': dummy_hidden_states, + 'output': output + } + + logger.debug("Captured CUDA graph for layer %s with %s tokens", + layer_idx, num_tokens) + + def _run_ffn_computation(self, + hidden_states: torch.Tensor, + layer_idx: Optional[int] = None, + capture_mode: bool = False): + """Run FFN computation for graph capture or replay.""" + if layer_idx is None: + current_layer_idx = self._get_current_layer_idx( + ) if not capture_mode else 0 + else: + current_layer_idx = layer_idx + + tp_world_size = get_tensor_model_parallel_world_size() + if tp_world_size > 1: + # Handle TP case: all-gather tensors from all TP ranks + gathered_hidden_states = tensor_model_parallel_all_gather( + hidden_states, dim=0) + ffn_output = self.model.compute_ffn_output(current_layer_idx, + gathered_hidden_states) + + # Extract the output corresponding to current rank + start_idx = hidden_states.shape[ + 0] * get_tensor_model_parallel_rank() + end_idx = start_idx + hidden_states.shape[0] + rank_ffn_output = ffn_output[start_idx:end_idx, :] + else: + # Single TP case + rank_ffn_output = self.model.compute_ffn_output( + current_layer_idx, hidden_states) + + return rank_ffn_output + + def _find_cuda_graph(self, layer_idx: int, num_tokens: int): + """Find the smallest graph that can handle the given layer and + number of tokens.""" + if not self.use_cuda_graph: + return None + + # Find the minimum capture size that can handle num_tokens for this + # layer + for capture_size in self.cudagraph_batch_sizes: + if num_tokens <= capture_size: + return self._cuda_graphs.get((layer_idx, capture_size)) + return None + + def _dummy_sampler_run(self, hidden_states: torch.Tensor) -> None: + """FFN servers don't use samplers.""" + pass + + def update_config(self, overrides: dict[str, Any]) -> None: + """Update configuration for FFN model runner.""" + allowed_config_names = {"load_config", "model_config"} + for config_name, config_overrides in overrides.items(): + assert 
config_name in allowed_config_names, \ + f"Config `{config_name}` not supported. " \ + f"Allowed configs: {allowed_config_names}" + config = getattr(self, config_name) + from vllm.config import update_config + new_config = update_config(config, config_overrides) + setattr(self, config_name, new_config) + + def reload_weights(self) -> None: + """Reload model weights for FFN model runner.""" + assert getattr(self, "model", None) is not None, \ + "Cannot reload weights before model is loaded." + model_loader = get_model_loader(self.load_config) + logger.info("Reloading weights inplace...") + model = self.get_model() + model_loader.load_weights(model, model_config=self.model_config) + + @property + def lora_config(self): + """FFN servers don't support LoRA.""" + return None + + @property + def is_pooling_model(self) -> bool: + """FFN servers are not pooling models.""" + return False + + def _dummy_pooler_run(self, hidden_states: torch.Tensor): + """FFN servers don't have poolers.""" + pass + + def get_supported_tasks(self): + """Get supported tasks for FFN model runner.""" + return [] + + def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: + """Get number of input tokens for FFN model runner.""" + return num_scheduled_tokens + + def take_draft_token_ids(self, **kwargs): + """FFN servers don't support draft tokens.""" + pass + + @property + def eplb_state(self): + """FFN servers don't have EPLB state.""" + return None + + def ensure_kv_transfer_shutdown(self): + """FFN servers don't need KV transfer shutdown.""" + pass + + def save_tensorized_model( + self, + tensorizer_config: "TensorizerConfig", + ) -> None: + """FFN servers don't support tensorized model saving.""" + raise NotImplementedError( + "FFN servers don't support tensorized model saving") diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d4d1f814afc0..5562cbc9a34b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -25,14 +25,15 @@ from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import (CompilationLevel, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config, update_config) +from vllm.distributed.afd_transfer import AFDConnectorFactory from vllm.distributed.eplb.eplb_state import EplbState from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) from vllm.distributed.kv_transfer.kv_connector.utils import copy_kv_blocks from vllm.distributed.parallel_state import ( - get_pp_group, get_tp_group, graph_capture, is_global_first_rank, - prepare_communication_buffer_for_model) -from vllm.forward_context import (BatchDescriptor, DPMetadata, + get_pp_group, get_tp_group, get_world_group, graph_capture, + is_global_first_rank, prepare_communication_buffer_for_model) +from vllm.forward_context import (AFDMetadata, BatchDescriptor, DPMetadata, set_forward_context) from vllm.logger import init_logger from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase @@ -376,6 +377,16 @@ def __init__( # means this layer will perform attention using the keys and values # from the KV cache of `shared_kv_cache_layers[layer_name]`. 
self.shared_kv_cache_layers: dict[str, str] = {} + + # init AFD config + self.afd_config = vllm_config.afd_config + if self.afd_config and self.afd_config.afd_role == "attention": + self.afd_connector = AFDConnectorFactory.create_connector( + get_world_group().rank, + get_world_group().local_rank, vllm_config) + self.afd_connector.init_afd_connector() + self.num_stages = self.afd_config.num_afd_stages + self.kv_sharing_fast_prefill_eligible_layers: set[str] = set() self.kv_sharing_fast_prefill_logits_indices = None @@ -874,8 +885,9 @@ def _get_encoder_seq_lens( def _prepare_inputs( self, scheduler_output: "SchedulerOutput", - ) -> tuple[dict[str, Any], torch.Tensor, Optional[SpecDecodeMetadata], - np.ndarray, Optional[CommonAttentionMetadata], int]: + ) -> tuple[dict[ + str, Any], torch.Tensor, Optional[SpecDecodeMetadata], np.ndarray, + Optional[CommonAttentionMetadata], int, Optional[AFDMetadata]]: """ :return: tuple[ attn_metadata: layer-to-attention_metadata mapping, @@ -1068,6 +1080,40 @@ def _prepare_inputs( encoder_seq_lens=encoder_seq_lens, ) + if self.afd_config and self.num_stages > 1: + if num_reqs >= self.num_stages: + num_reqs_per_stage = num_reqs // self.num_stages + afd_reqs_start_loc = [ + num_reqs_per_stage * i + for i in range(self.num_stages + 1) + ] + afd_reqs_start_loc[-1] = num_reqs + else: + afd_reqs_start_loc = [i for i in range(num_reqs + 1)] + + # For prefill, compute tokens per stage based on actual token + # counts + afd_tokens_start_loc = [0] + afd_tokens_lens = [] + for stage_idx in range(len(afd_reqs_start_loc) - 1): + stage_start_req = afd_reqs_start_loc[stage_idx] + stage_end_req = afd_reqs_start_loc[stage_idx + 1] + stage_tokens = int(query_start_loc[stage_end_req] - + query_start_loc[stage_start_req]) + afd_tokens_lens.append(stage_tokens) + afd_tokens_start_loc.append(afd_tokens_start_loc[-1] + + stage_tokens) + + afd_metadata = AFDMetadata( + afd_tokens_start_loc=afd_tokens_start_loc, + afd_reqs_start_loc=afd_reqs_start_loc, + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=afd_tokens_lens, + ) + else: + afd_metadata = None + if self.speculative_config and \ spec_decode_common_attn_metadata is None: spec_decode_common_attn_metadata = common_attn_metadata @@ -1096,6 +1142,7 @@ def _prepare_inputs( attn_metadata_i = builder.build( common_prefix_len=common_prefix_len, common_attn_metadata=common_attn_metadata, + #afd_metadata=afd_metadata, **extra_attn_metadata_args) for layer_name in attn_group.layer_names: @@ -1107,7 +1154,7 @@ def _prepare_inputs( return (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens, spec_decode_common_attn_metadata, - max_num_scheduled_tokens) + max_num_scheduled_tokens, afd_metadata) def _compute_cascade_attn_prefix_len( self, @@ -1698,6 +1745,55 @@ def get_dp_padding(self, dtype=torch.int32) return max_tokens_across_dp_cpu - num_tokens, num_tokens_after_padding + def get_afd_padding( + self, afd_tokens_start_loc: list[int], + afd_tokens_lens: list[int]) -> tuple[int, list[int], list[int]]: + + afd_tokens_start_loc = list(afd_tokens_start_loc) + afd_tokens_lens = list(afd_tokens_lens) + original_max_end_loc = afd_tokens_start_loc[-1] + + # 1. Stage count padding: pad to reach required num_stages by adding + # dummy stages. + if len(afd_tokens_start_loc) - 1 < self.num_stages: + missing = self.num_stages - (len(afd_tokens_start_loc) - 1) + for _ in range(missing): + afd_tokens_lens.append(0) + + # 2. Stage-wise DP padding: pad each stage to max tokens across DP + # ranks. 
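+        # Every rank adopts the per-stage maximum across DP ranks so that all
+        # ranks present the FFN server with the same per-stage token counts.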
+ if self.vllm_config.parallel_config.data_parallel_size > 1: + dp_size = self.vllm_config.parallel_config.data_parallel_size + dp_rank = self.vllm_config.parallel_config.data_parallel_rank + _, max_tokens_cpu = DPMetadata.num_stage_tokens_across_dp( + afd_tokens_lens, dp_size, dp_rank) + afd_tokens_lens = max_tokens_cpu.tolist() + + # 3. If using CUDA graphs on attention server, pad each stage length + # up to the next configured cudagraph capture size so that each stage + # matches a captured graph size. + if (self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE + and self.afd_config and self.afd_config.is_attention_server): + + def pad_to_capture_size(n: int) -> int: + for s in self.cudagraph_batch_sizes: + if n <= s // self.num_stages: + return s // self.num_stages + return n + + afd_tokens_lens = [pad_to_capture_size(n) for n in afd_tokens_lens] + + # Recompute start locations from lengths to ensure consistency after + # padding. + new_start_loc = [afd_tokens_start_loc[0]] + running = afd_tokens_start_loc[0] + for length in afd_tokens_lens: + running += length + new_start_loc.append(running) + + num_pad = new_start_loc[-1] - original_max_end_loc + return num_pad, new_start_loc, afd_tokens_lens + def _pool( self, hidden_states: torch.Tensor, @@ -2027,7 +2123,8 @@ def execute_model( # Prepare the decoder inputs. (attn_metadata, logits_indices, spec_decode_metadata, num_scheduled_tokens_np, spec_decode_common_attn_metadata, - max_query_len) = self._prepare_inputs(scheduler_output) + max_query_len, + afd_metadata) = self._prepare_inputs(scheduler_output) finally: if self.prepare_inputs_event is not None: @@ -2044,15 +2141,32 @@ def execute_model( model_kwargs, ) = self._preprocess(scheduler_output, intermediate_tensors) + if afd_metadata: + # Padding for AFD + num_input_tokens = num_scheduled_tokens + (num_pad_afd, afd_tokens_start_loc, + afd_tokens_lens) = self.get_afd_padding( + afd_metadata.afd_tokens_start_loc, + afd_metadata.afd_tokens_lens) + afd_metadata.afd_tokens_start_loc = afd_tokens_start_loc + afd_metadata.afd_tokens_lens = afd_tokens_lens + num_input_tokens += num_pad_afd + num_tokens_across_dp = None + uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( num_scheduled_tokens == self.input_batch.num_reqs * max_query_len) batch_descriptor = BatchDescriptor(num_tokens=num_input_tokens, uniform_decode=uniform_decode) - cudagraph_runtime_mode, batch_descriptor = \ - self.cudagraph_dispatcher.dispatch(batch_descriptor) + if self.afd_config: + cudagraph_runtime_mode = CUDAGraphMode.NONE + else: + cudagraph_runtime_mode, batch_descriptor = \ + self.cudagraph_dispatcher.dispatch(batch_descriptor) + if afd_metadata is None: + afd_metadata = AFDMetadata(0, 0, 0, self.afd_connector, 0) # Run the model. # Use persistent buffers for CUDA graphs. 
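+        # The AFD metadata travels with the forward context so the model can
+        # hand activations over to the FFN server one stage at a time.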
with (set_forward_context( @@ -2062,6 +2176,7 @@ def execute_model( num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, batch_descriptor=batch_descriptor, + afd_metadata=afd_metadata, ), record_function_or_nullcontext("Forward"), self.maybe_get_kv_connector_output(scheduler_output) as kv_connector_output): @@ -2671,9 +2786,48 @@ def _dummy_run( CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL } - # Padding for DP - num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) - num_tokens += num_pad + # AFD padding (stage-level alignment) before DP padding + if self.vllm_config.afd_config: + if num_tokens > self.vllm_config.afd_config.num_afd_stages: + num_tokens_per_stage = ( + num_tokens // self.vllm_config.afd_config.num_afd_stages) + max_num_reqs = self.scheduler_config.max_num_seqs + num_reqs = min(num_tokens, max_num_reqs) + num_reqs_per_stage = ( + num_reqs // self.vllm_config.afd_config.num_afd_stages) + afd_tokens_start_loc = [ + i * num_tokens_per_stage + for i in range(self.vllm_config.afd_config.num_afd_stages + + 1) + ] + afd_reqs_start_loc = [ + i * num_reqs_per_stage + for i in range(self.vllm_config.afd_config.num_afd_stages + + 1) + ] + afd_tokens_lens = [ + num_tokens_per_stage + for _ in range(self.vllm_config.afd_config.num_afd_stages) + ] + afd_tokens_lens[-1] += num_tokens % num_tokens_per_stage + afd_tokens_start_loc[-1] = num_tokens + afd_metadata = AFDMetadata( + afd_tokens_start_loc=afd_tokens_start_loc, + afd_reqs_start_loc=afd_reqs_start_loc, + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=afd_tokens_lens, + ) + else: + afd_metadata = AFDMetadata( + afd_tokens_start_loc=list(range(num_tokens + 1)), + afd_reqs_start_loc=list(range(num_tokens + 1)), + afd_stage_idx=0, + afd_connector=self.afd_connector, + afd_tokens_lens=[1] * num_tokens, + ) + else: + afd_metadata = None # If cudagraph_mode.decode_mode() == FULL and # cudagraph_mode.separate_routine(). This means that we are using @@ -2768,11 +2922,25 @@ def _dummy_run( causal=True) for attn_group in self.attn_groups[kv_cache_group_id]: - attn_metadata_i = attn_group.metadata_builder\ - .build_for_cudagraph_capture(common_attn_metadata) + attn_metadata_i = (attn_group.metadata_builder. + build_for_cudagraph_capture( + common_attn_metadata, afd_metadata)) for layer_name in kv_cache_group_spec.layer_names: attn_metadata[layer_name] = attn_metadata_i + if afd_metadata: + (num_afd_pad, afd_tokens_start_loc, + afd_tokens_lens) = self.get_afd_padding( + afd_metadata.afd_tokens_start_loc, + afd_metadata.afd_tokens_lens) + afd_metadata.afd_tokens_start_loc = afd_tokens_start_loc + afd_metadata.afd_tokens_lens = afd_tokens_lens + num_tokens += num_afd_pad + num_tokens_across_dp = None + + num_pad, num_tokens_across_dp = self.get_dp_padding(num_tokens) + num_tokens += num_pad + with self.maybe_dummy_run_with_lora(self.lora_config, num_scheduled_tokens, remove_lora): model_kwargs = self._init_model_kwargs(num_tokens) @@ -2824,7 +2992,8 @@ def _dummy_run( num_tokens=num_tokens, num_tokens_across_dp=num_tokens_across_dp, cudagraph_runtime_mode=cudagraph_runtime_mode, - batch_descriptor=batch_descriptor): + batch_descriptor=batch_descriptor, + afd_metadata=afd_metadata): outputs = self.model( input_ids=input_ids, positions=positions, @@ -3094,33 +3263,52 @@ def freeze_gc(): # Capture the large shapes first so that the smaller shapes # can reuse the memory pool allocated for the large shapes. 
set_cudagraph_capturing_enabled(True) - with freeze_gc(), graph_capture(device=self.device): - cudagraph_mode = self.compilation_config.cudagraph_mode - if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - cudagraph_runtime_mode = cudagraph_mode.mixed_mode() - - compilation_cases = list(reversed(self.cudagraph_batch_sizes)) - self._capture_cudagraphs( - compilation_cases, - cudagraph_runtime_mode=cudagraph_runtime_mode, - uniform_decode=False) - - # Capture full cudagraph for uniform decode batches if we have - # dont already have full mixed prefill-decode cudagraphs - if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ - cudagraph_mode.separate_routine(): - max_num_tokens = self.scheduler_config.max_num_seqs * \ - self.uniform_decode_query_len - decode_cudagraph_batch_sizes = [ - x for x in self.cudagraph_batch_sizes if - x <= max_num_tokens and x >= self.uniform_decode_query_len - ] - compilation_cases_decode = list( - reversed(decode_cudagraph_batch_sizes)) - self._capture_cudagraphs( - compilation_cases=compilation_cases_decode, - cudagraph_runtime_mode=CUDAGraphMode.FULL, - uniform_decode=True) + # FIXME: hack for afd layerwise cg capture + if self.afd_config and self.afd_config.is_attention_server: + with freeze_gc(): + if hasattr(self.model, "set_graph_capture_mode"): + self.model.set_graph_capture_mode(True) + for num_tokens in tqdm(reversed(self.cudagraph_batch_sizes)): + self._dummy_run(num_tokens, + cudagraph_runtime_mode=CUDAGraphMode.NONE, + force_attention=True, + uniform_decode=False, + skip_eplb=True, + remove_lora=True) + # Disable capture mode after prewarm + if hasattr(self.model, "set_graph_capture_mode"): + self.model.set_graph_capture_mode(False) + else: + with freeze_gc(), graph_capture(device=self.device): + cudagraph_mode = self.compilation_config.cudagraph_mode + + if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: + cudagraph_runtime_mode = cudagraph_mode.mixed_mode() + + compilation_cases = list( + reversed(self.cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases, + cudagraph_runtime_mode=cudagraph_runtime_mode, + uniform_decode=False) + + # Capture full cudagraph for uniform decode batches if we have + # dont already have full mixed prefill-decode cudagraphs + if cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and \ + cudagraph_mode.separate_routine(): + max_num_tokens = self.scheduler_config.max_num_seqs * \ + self.uniform_decode_query_len + decode_cudagraph_batch_sizes = [ + x for x in self.cudagraph_batch_sizes + if x <= max_num_tokens + and x >= self.uniform_decode_query_len + ] + compilation_cases_decode = list( + reversed(decode_cudagraph_batch_sizes)) + self._capture_cudagraphs( + compilation_cases=compilation_cases_decode, + cudagraph_runtime_mode=CUDAGraphMode.FULL, + uniform_decode=True) # Disable cudagraph capturing globally, so any unexpected cudagraph # capturing will be detected and raise an error after here. @@ -3620,6 +3808,11 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: f"{layer.impl.__class__.__name__} " "does not return the softmax lse for decode.") + def initialize_afd_connector(self) -> None: + """Initialize AFD connector if available.""" + if hasattr(self, 'afd_connector') and self.afd_connector: + self.afd_connector.init_afd_connector() + def may_add_encoder_only_layers_to_kv_cache_config(self) -> None: """ Add encoder-only layers to the KV cache config. 
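For readers following the stage bookkeeping in `_prepare_inputs` and `get_afd_padding` above, here is a minimal standalone sketch of how a batch of requests is split into contiguous AFD stages. The helper name `split_into_afd_stages` and the list-based `query_start_loc` are illustrative only; the model runner operates on the tensors shown in the hunk above.

```
from typing import List, Tuple


def split_into_afd_stages(
    query_start_loc: List[int],  # cumulative token offsets, len == num_reqs + 1
    num_stages: int,
) -> Tuple[List[int], List[int], List[int]]:
    """Split a batch of requests into contiguous AFD stages."""
    num_reqs = len(query_start_loc) - 1
    if num_reqs >= num_stages:
        per_stage = num_reqs // num_stages
        reqs_start_loc = [per_stage * i for i in range(num_stages + 1)]
        reqs_start_loc[-1] = num_reqs  # last stage absorbs the remainder
    else:
        # Fewer requests than stages: one request per stage.
        reqs_start_loc = list(range(num_reqs + 1))

    tokens_start_loc, tokens_lens = [0], []
    for i in range(len(reqs_start_loc) - 1):
        start_req, end_req = reqs_start_loc[i], reqs_start_loc[i + 1]
        stage_tokens = query_start_loc[end_req] - query_start_loc[start_req]
        tokens_lens.append(stage_tokens)
        tokens_start_loc.append(tokens_start_loc[-1] + stage_tokens)
    return reqs_start_loc, tokens_start_loc, tokens_lens


# Example: 5 requests with 3, 1, 4, 2, 2 tokens split across 2 stages.
print(split_into_afd_stages([0, 3, 4, 8, 10, 12], num_stages=2))
# -> ([0, 2, 5], [0, 4, 12], [4, 8])
```

The per-stage lengths produced here are what `get_afd_padding` subsequently pads for missing stages, DP alignment, and cudagraph capture sizes.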
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 6855526583f0..f55d20d3b128 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -31,6 +31,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput) from vllm.v1.utils import report_usage_stats +from vllm.v1.worker.gpu_ffn_model_runner import GPUFFNModelRunner from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.utils import is_residual_scattered_for_sp from vllm.v1.worker.worker_base import WorkerBase @@ -199,8 +200,13 @@ def init_device(self): set_random_seed(self.model_config.seed) # Construct the model runner - self.model_runner: GPUModelRunner = GPUModelRunner( - self.vllm_config, self.device) + self.model_runner: GPUModelRunner | GPUFFNModelRunner + if (self.vllm_config.afd_config + and self.vllm_config.afd_config.is_ffn_server): + self.model_runner = GPUFFNModelRunner(self.vllm_config, + self.device) + else: + self.model_runner = GPUModelRunner(self.vllm_config, self.device) if self.rank == 0: # If usage stat is enabled, collect relevant info. @@ -425,8 +431,18 @@ def get_supported_tasks(self) -> tuple[SupportedTask, ...]: @torch.inference_mode() def execute_model( self, - scheduler_output: "SchedulerOutput", + scheduler_output: Optional["SchedulerOutput"] = None, ) -> Optional[Union[ModelRunnerOutput, AsyncModelRunnerOutput]]: + # FFN server mode: direct execution without pipeline parallelism + if (self.vllm_config.afd_config + and self.vllm_config.afd_config.is_ffn_server): + return self.model_runner.execute_model(scheduler_output) + + if scheduler_output is None: + raise ValueError( + "scheduler_output is required in normal inference mode") + + # Normal inference mode intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens @@ -505,6 +521,53 @@ def check_health(self) -> None: # worker will always be healthy as long as it's running. 
return + def start_ffn_server_loop(self) -> None: + """Start FFN server loop for AFD FFN workers""" + if not (self.vllm_config.afd_config + and self.vllm_config.afd_config.is_ffn_server): + return + + self.model_runner.capture_model() + self.model_runner.initialize_afd_connector() + + if self.profiler: + self.profiler.start() + for _ in range(1000): # FIXME: hardcoded profiler iterations + self.model_runner.execute_model(scheduler_output=None) + torch.cuda.synchronize() # Ensure GPU operations complete + self.profiler.stop() + print(self.profiler.key_averages().table( + sort_by="self_cuda_time_total")) + + import threading + self._ffn_shutdown_event = threading.Event() + + def ffn_worker_loop(): + # Set CUDA device for this thread (thread-local context) + torch.cuda.set_device(self.device) + logger.info("FFN worker loop started") + + try: + while not self._ffn_shutdown_event.is_set(): + # Execute FFN computation + self.model_runner.execute_model(scheduler_output=None) + except Exception as e: + logger.error("FFN worker loop error: %s", e) + raise + + self._ffn_thread = threading.Thread(target=ffn_worker_loop, + daemon=True) + self._ffn_thread.start() + logger.info("FFN server loop started in worker") + + def stop_ffn_server_loop(self) -> None: + """Stop FFN server loop""" + if hasattr(self, '_ffn_shutdown_event'): + self._ffn_shutdown_event.set() + if hasattr(self, '_ffn_thread'): + self._ffn_thread.join(timeout=5) + logger.info("FFN server loop stopped") + def _eplb_before_scale_down(self, old_ep_size: int, new_ep_size: int) -> None: from vllm.distributed.parallel_state import get_ep_group
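As a companion to `start_ffn_server_loop` and `GPUFFNModelRunner.execute_model`, the sketch below mirrors the per-call layer bookkeeping on the FFN side: with `num_afd_stages` stages, each layer is served for that many consecutive recv/compute/send iterations before the counter wraps. The function name is illustrative; the real runner advances `self._counter` inside `execute_model`.

```
def layer_schedule(num_layers: int, num_afd_stages: int,
                   num_calls: int) -> list[int]:
    """Layer index visited on each execute_model call of the FFN server."""
    counter = 0
    schedule = []
    for _ in range(num_calls):
        schedule.append((counter // num_afd_stages) % num_layers)
        counter += 1
        if counter == num_layers * num_afd_stages:
            counter = 0
    return schedule


# With 3 layers and 2 stages, consecutive recv/compute/send calls serve:
print(layer_schedule(num_layers=3, num_afd_stages=2, num_calls=8))
# -> [0, 0, 1, 1, 2, 2, 0, 0]
```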