diff --git a/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py b/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py index 7e60801f637f..727066993dd6 100644 --- a/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py +++ b/vllm/distributed/ec_transfer/ec_lookup_buffer/mooncake_store.py @@ -22,7 +22,9 @@ from vllm.config import VllmConfig from vllm.distributed.ec_transfer.utils.tensor_memory_pool import ( TensorMemoryPool) +from vllm.distributed.ec_transfer.utils.transfer_engine import get_global_te from vllm.logger import init_logger +from vllm.utils import get_ip DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB @@ -148,15 +150,30 @@ def __init__(self, vllm_config: "VllmConfig"): logger.info(" fast_transfer_buffer_size: %s", self.config.fast_transfer_buffer_size) - self.store.setup( - self.config.local_hostname, - self.config.metadata_server, - self.config.global_segment_size, - self.config.local_buffer_size, - self.config.protocol, - self.config.device_name, - self.config.master_server_address, - ) + if self.config.protocol == "ascend": + # if ascend direct transport is on, + # global transfer engine for an instance is required + local_hostname = get_ip() + transfer_engine = get_global_te(local_hostname, + device_name=None) + self.local_seg = local_hostname + ":" + str( + transfer_engine.get_rpc_port()) + self.store.setup(self.local_seg, "P2PHANDSHAKE", + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address, + transfer_engine.get_engine()) + else: + self.store.setup( + self.config.local_hostname, + self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, + self.config.device_name, + self.config.master_server_address, + ) except ValueError as e: logger.error("Configuration loading failed: %s", e) @@ -194,7 +211,8 @@ def close(self): if self.config.fast_transfer: self.store.unregister_buffer(self.tensor_pool.base_address, self.config.fast_transfer_buffer_size) - self.tensor_pool.cleanup() + with self.pool_lock: + self.tensor_pool.cleanup() self.put_loop.call_soon_threadsafe(self.put_loop.stop) self.put_thread.join() @@ -276,6 +294,7 @@ def _zero_copy_batch_get(self, keys: list[str], with self.pool_lock: self.tensor_pool.batch_free(buffer_addrs) logger.error("batch_get_into failed: %s", str(e)) + return results # NOTE: should I delay free buffer for id, addr, dtype, shape, read_byte in zip(exist_ids, buffer_addrs, @@ -283,8 +302,9 @@ def _zero_copy_batch_get(self, keys: list[str], buffer_shapes, read_bytes): if read_byte > 0: - results[id] = self.tensor_pool.load_tensor( - addr, dtype, shape, device) + with self.pool_lock: + results[id] = self.tensor_pool.load_tensor( + addr, dtype, shape, device) with self.pool_lock: self.tensor_pool.batch_free(buffer_addrs) @@ -398,6 +418,7 @@ async def _zero_copy_batch_put(self, keys: list[str], ",".join(keys), str(e), ) + raise try: # Zero-copy put @@ -417,6 +438,7 @@ async def _zero_copy_batch_put(self, keys: list[str], ",".join(keys), str(e), ) + raise finally: if buffer_addrs: with self.pool_lock: diff --git a/vllm/distributed/ec_transfer/utils/transfer_engine.py b/vllm/distributed/ec_transfer/utils/transfer_engine.py new file mode 100644 index 000000000000..4c69e09200e1 --- /dev/null +++ b/vllm/distributed/ec_transfer/utils/transfer_engine.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import ipaddress +import threading +from typing import Optional + +from mooncake.engine import TransferEngine # type: ignore + +_global_te = None +_global_te_lock = threading.Lock() + + +def get_global_te(hostname: str, device_name: Optional[str]): + try: + ip = ipaddress.ip_address(hostname) + if isinstance(ip, ipaddress.IPv6Address): + raise RuntimeError( + "The backend of mooncake's Ascend Direct Xfer Library " + "currently does not support IPv6.") + except ValueError: + pass + + global _global_te + if _global_te is None: + with _global_te_lock: + # Double-Checked Locking + if _global_te is None: + if TransferEngine is None: + raise RuntimeError("mooncake is not available") + transfer_engine = TransferEngine() + device_name = device_name if device_name is not None else "" + ret_value = transfer_engine.initialize(hostname, + "P2PHANDSHAKE", + "ascend", device_name) + if ret_value != 0: + raise RuntimeError( + f"TransferEngine initialization failed with " + f"ret_value: {ret_value}") + _global_te = transfer_engine + return _global_te