diff --git a/EPD_README.md b/EPD_README.md
new file mode 100644
index 000000000000..4cc2840c95fb
--- /dev/null
+++ b/EPD_README.md
@@ -0,0 +1,29 @@
+1. Deploy vLLM and vLLM Ascend
+1.1 Download the vLLM and vLLM Ascend source packages
+wget https://github.com/hsliuustc0106/vllm/archive/refs/tags/v0.9.1-EPD.tar.gz
+wget -O v0.9.1-ascend-EPD.tar.gz https://github.com/hsliuustc0106/vllm-ascend/archive/refs/tags/v0.9.1-EPD.tar.gz
+
+1.2 Extract the packages
+tar -zxvf v0.9.1-EPD.tar.gz
+tar -zxvf v0.9.1-ascend-EPD.tar.gz
+
+1.3 Deploy vLLM
+cd vllm-0.9.1-EPD
+pip install -r requirements/build.txt
+
+SETUPTOOLS_SCM_PRETEND_VERSION=0.9.1 VLLM_TARGET_DEVICE=empty pip install -e .
+
+1.4 Deploy vLLM Ascend
+cd ../vllm-ascend-0.9.1-EPD
+pip install -e .
+
+2. Launch E + PD workers and the inference service (Python API + ZMQ)
+cd ../vllm-0.9.1-EPD/examples/offline_inference/epd
+Add the model patch and adjust the parameters in run.sh
+Edit the inference parameters in vllm-0.9.1-EPD/examples/offline_inference/epd/chat_with_image.py
+Run bash run.sh to start inference
+
+3. Launch E + PD workers and the inference service (ZMQ + HTTP)
+cd ../vllm-0.9.1-EPD/examples/offline_inference/epd
+Add the model patch and adjust the parameters in run_zmq_http.sh
+Run bash run_zmq_http.sh to start the inference service
diff --git a/examples/offline_inference/epd/run_zmq_http.sh b/examples/offline_inference/epd/run_zmq_http.sh
new file mode 100644
index 000000000000..3f059092dc4f
--- /dev/null
+++ b/examples/offline_inference/epd/run_zmq_http.sh
@@ -0,0 +1,182 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+CURRENT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
+
+MAX_NUM_SEQS_ENCODER="${MAX_NUM_SEQS_ENCODER:-1}"
+MAX_NUM_SEQS_PD="${MAX_NUM_SEQS_PD:-128}"
+ENCODER_ADDR_PREFIX="${ENCODER_ADDR_PREFIX:-/tmp/encoder}"
+PD_ADDR_PREFIX="${PD_ADDR_PREFIX:-/tmp/prefill_decode}"
+PROXY_ADDR="${PROXY_ADDR:-/tmp/proxy}"
+PID_FILE="${PID_FILE:-${CURRENT_DIR}/pid.txt}"
+
+MODEL=""
+SHARED_STORAGE_PATH="/dev/shm/epd"
+GPU_UTILIZATION_ENCODER=0.0
+GPU_UTILIZATION_PD=0.95
+ENCODER_DEVICE_ID_BASE=0
+ENCODER_NUMBER=1
+PD_DEVICE_ID_BASE=1
+PD_NUMBER=2
+LOG_PATH="${CURRENT_DIR}/logs"
+IMAGE_FILE_PATH=""
+
+function start_encoder() {
+    local dev_id=$1
+    local address=$2
+    local proxy_address=$3
+    local log_file=$4
+
+    VLLM_USE_V1=1 ASCEND_RT_VISIBLE_DEVICES=$dev_id python -m vllm.entrypoints.disaggregated.worker \
+        --proxy-addr $proxy_address \
+        --worker-addr $address \
+        --model $MODEL \
+        --gpu-memory-utilization $GPU_UTILIZATION_ENCODER \
+        --max-num-seqs $MAX_NUM_SEQS_ENCODER \
+        --enforce-eager \
+        --no-enable-prefix-caching \
+        --ec-transfer-config '{
+            "ec_connector": "ECSharedStorageConnector",
+            "ec_role": "ec_producer",
+            "ec_connector_extra_config": {
+                "shared_storage_path": "'"$SHARED_STORAGE_PATH"'"
+            }
+        }' \
+        >"$log_file" 2>&1 &
+    echo $! >> "$PID_FILE"
+}
+
+function start_pd() {
+    local dev_id=$1
+    local address=$2
+    local proxy_address=$3
+    local log_file=$4
+
+    VLLM_USE_V1=1 ASCEND_RT_VISIBLE_DEVICES=$dev_id python -m vllm.entrypoints.disaggregated.worker \
+        --proxy-addr $proxy_address \
+        --worker-addr $address \
+        --model $MODEL \
+        --gpu-memory-utilization $GPU_UTILIZATION_PD \
+        --max-num-seqs $MAX_NUM_SEQS_PD \
+        --enforce-eager \
+        --ec-transfer-config '{
+            "ec_connector": "ECSharedStorageConnector",
+            "ec_role": "ec_consumer",
+            "ec_connector_extra_config": {
+                "shared_storage_path": "'"$SHARED_STORAGE_PATH"'"
+            }
+        }' \
+        >"$log_file" 2>&1 &
+    echo $! >> "$PID_FILE"
+}
+
+function start_all() {
+    mkdir -p "$LOG_PATH"
+    if [ -f "$PID_FILE" ]; then
+        rm "$PID_FILE"
+    fi
+
+    if [ -d "$SHARED_STORAGE_PATH" ]; then
+        rm -rf "$SHARED_STORAGE_PATH"
+    fi
+    mkdir -p "$SHARED_STORAGE_PATH"
+
+    echo "Starting encoder workers..."
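+    # One encoder worker is started per encoder device and one prefill/decode
+    # (PD) worker per PD device; each worker's IPC path is the corresponding
+    # address prefix plus its index, and the same paths are passed to
+    # zmq_http_proxy.py via --encode-addrs / --pd-addrs at the end of this script.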
+    for ((i=0; i<ENCODER_NUMBER; i++)); do
+        start_encoder "$((ENCODER_DEVICE_ID_BASE + i))" \
+            "${ENCODER_ADDR_PREFIX}_$i" \
+            "$PROXY_ADDR" \
+            "$LOG_PATH/encoder_$i.log"
+    done
+
+    echo "Starting prefill/decode workers..."
+    for ((i=0; i<PD_NUMBER; i++)); do
+        start_pd "$((PD_DEVICE_ID_BASE + i))" \
+            "${PD_ADDR_PREFIX}_$i" \
+            "$PROXY_ADDR" \
+            "$LOG_PATH/pd_$i.log"
+    done
+}
+
+function stop_all() {
+    if [ -f "$PID_FILE" ]; then
+        while read -r pid; do
+            if kill -0 "$pid" > /dev/null 2>&1; then
+                echo "Stopping process $pid"
+                kill "$pid"
+                for i in {1..5}; do
+                    sleep 1
+                    if ! kill -0 "$pid" > /dev/null 2>&1; then
+                        break
+                    fi
+                done
+                if kill -0 "$pid" > /dev/null 2>&1; then
+                    echo "Process $pid did not exit, killing with -9"
+                    kill -9 "$pid"
+                fi
+            fi
+        done < "$PID_FILE"
+        rm "$PID_FILE"
+    else
+        echo "No PID file found. Are the workers running?"
+    fi
+
+    if [ -d "$SHARED_STORAGE_PATH" ]; then
+        rm -rf "$SHARED_STORAGE_PATH"
+        echo "Removed shared storage at $SHARED_STORAGE_PATH"
+    fi
+}
+
+function print_help() {
+    echo "Usage: $0 [--model MODEL] [--shared-storage-path PATH]
+        [--gpu-utilization-encoder FLOAT] [--gpu-utilization-pd FLOAT]
+        [--encoder-device-id-base INT] [--encoder-number INT]
+        [--pd-device-id-base INT] [--pd-number INT]
+        [--image-file-path PATH] [--log-path PATH]
+        [--stop] [--help]"
+}
+
+while [[ "$#" -gt 0 ]]; do
+    case $1 in
+        --model) MODEL="$2"; shift ;;
+        --shared-storage-path) SHARED_STORAGE_PATH="$2"; shift ;;
+        --gpu-utilization-encoder) GPU_UTILIZATION_ENCODER="$2"; shift ;;
+        --gpu-utilization-pd) GPU_UTILIZATION_PD="$2"; shift ;;
+        --encoder-device-id-base) ENCODER_DEVICE_ID_BASE="$2"; shift ;;
+        --encoder-number) ENCODER_NUMBER="$2"; shift ;;
+        --pd-device-id-base) PD_DEVICE_ID_BASE="$2"; shift ;;
+        --pd-number) PD_NUMBER="$2"; shift ;;
+        --log-path) LOG_PATH="$2"; shift ;;
+        --image-file-path) IMAGE_FILE_PATH="$2"; shift ;;
+        --stop) stop_all; exit 0 ;;
+        --help) print_help; exit 0 ;;
+        *) echo "Unknown parameter passed: $1"; exit 1 ;;
+    esac
+    shift
+done
+
+if [ -z "$MODEL" ]; then
+    echo "Error: --model is required."
+    exit 1
+fi
+
+if [ -z "$IMAGE_FILE_PATH" ]; then
+    echo "Error: --image-file-path is required."
+    exit 1
+fi
+
+start_all
+
+python zmq_http_proxy.py \
+    --host "0.0.0.0" \
+    --port "8000" \
+    --proxy-addr $PROXY_ADDR \
+    --encode-addrs $(for ((i=0; i<ENCODER_NUMBER; i++)); do echo -n "${ENCODER_ADDR_PREFIX}_$i "; done) \
+    --pd-addrs $(for ((i=0; i<PD_NUMBER; i++)); do echo -n "${PD_ADDR_PREFIX}_$i "; done) \
+    --model-name "$MODEL" \
+    >"$LOG_PATH/proxy.log" 2>&1 &
\ No newline at end of file
diff --git a/vllm/disaggregated/disagg_worker.py b/vllm/disaggregated/disagg_worker.py
index 1cf18b0aa7ab..70de978b3bbe 100644
--- a/vllm/disaggregated/disagg_worker.py
+++ b/vllm/disaggregated/disagg_worker.py
@@ -123,4 +123,6 @@ def _decode_mm_data(mm_data: dict[str, any]) -> dict[str, any]:
         decoded_img = np.frombuffer(bytes(
             img["data"]), dtype=img["dtype"]).reshape(img["shape"])
         decoded_images.append(decoded_img)
+    if len(decoded_images) == 1:
+        decoded_images = decoded_images[0]
     return {"image": decoded_images}
diff --git a/vllm/disaggregated/protocol.py b/vllm/disaggregated/protocol.py
index 511c9375d563..d273fc39e097 100644
--- a/vllm/disaggregated/protocol.py
+++ b/vllm/disaggregated/protocol.py
@@ -32,6 +32,7 @@ class ResponseType:
 class GenerationResponse(msgspec.Struct):
     request_id: str
     text: str
+    prompt_token_ids: list[int]
     token_ids: list[int]
     finish_reason: Optional[str] = None
     stop_reason: Optional[str] = None
@@ -46,6 +47,7 @@ def from_request_output(
         return GenerationResponse(
             request_id=request_output.request_id,
             text=out.text,
+            prompt_token_ids=request_output.prompt_token_ids,
             token_ids=out.token_ids,
             finish_reason=out.finish_reason,
             stop_reason=out.stop_reason,
diff --git a/vllm/disaggregated/zmq_http_proxy.py b/vllm/disaggregated/zmq_http_proxy.py
new file mode 100644
index 000000000000..b0e8b6bf657a
--- /dev/null
+++ b/vllm/disaggregated/zmq_http_proxy.py
@@ -0,0 +1,639 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import argparse
+import asyncio
+import os
+import uuid
+from collections.abc
import AsyncGenerator, AsyncIterator, Mapping +from typing import Optional, Union + +import msgspec +import numpy as np +import random +import uvicorn +import zmq +import zmq.asyncio +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse, StreamingResponse + +from vllm.config import DecodingConfig, ModelConfig, VllmConfig +from vllm.core.scheduler import SchedulerOutputs +from vllm.disaggregated.protocol import (FailureResponse, GenerationRequest, + GenerationResponse, RequestType, + ResponseType) +from vllm.engine.protocol import EngineClient +from vllm.inputs.data import PromptType +from vllm.inputs.preprocess import InputPreprocessor +from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.outputs import CompletionOutput, PoolingRequestOutput, RequestOutput +from vllm.pooling_params import PoolingParams +from vllm.prompt_adapter.request import PromptAdapterRequest +from vllm.sampling_params import SamplingParams +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from vllm.utils import Device +from vllm.v1.outputs import SamplerOutput + +logger = init_logger(__name__) + +app = FastAPI() + + +class Proxy(EngineClient): + """ + Proxy + """ + + def __init__( + self, + proxy_addr: str, + encode_addr_list: list[str], + pd_addr_list: list[str], + model_name: str, + ): + self.queues: dict[str, asyncio.Queue] = {} + + self.encoder = msgspec.msgpack.Encoder() + + self.ctx = zmq.asyncio.Context() + self.proxy_addr = f"ipc://{proxy_addr}" + self.encode_addr_list = [f"ipc://{addr}" for addr in encode_addr_list] + self.pd_addr_list = [f"ipc://{addr}" for addr in pd_addr_list] + self.to_encode_sockets = [] + for addr in self.encode_addr_list: + socket = self.ctx.socket(zmq.constants.PUSH) + socket.connect(addr) + self.to_encode_sockets.append(socket) + self.to_pd_sockets = [] + for addr in self.pd_addr_list: + socket = self.ctx.socket(zmq.constants.PUSH) + socket.connect(addr) + self.to_pd_sockets.append(socket) + + self.output_handler: Optional[asyncio.Task] = None + + # Dummy: needed for EngineClient Protocol. + self.model_config = ModelConfig( + model=model_name, + tokenizer=model_name, + tokenizer_mode="auto", + trust_remote_code=False, + dtype="auto", + task="generate", + seed=42, + ) + + # Dummy: needed for EngineClient Protocol. + # TODO: refactor OAI Server to avoid needing this. + self.tokenizer = TokenizerGroup(**dict( + tokenizer_id=self.model_config.tokenizer, + enable_lora=False, + max_num_seqs=1024, + max_loras=0, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision, + truncation_side=self.model_config.truncation_side, + )) + + def shutdown(self): + self.ctx.destroy() + if (task := self.output_handler) is not None: + task.cancel() + + socket_path = self.proxy_addr.replace("ipc://", "") + if os.path.exists(socket_path): + os.remove(socket_path) + + async def _run_encode( + self, + request: GenerationRequest, + q: asyncio.Queue[Union[Exception, GenerationResponse]], + ) -> None: + """ + Send the encode request to one encoder worker. + The encoder worker is selected based on hashing the request ID. 
+ """ + if not self.to_encode_sockets: + raise RuntimeError( + "No encode workers configured: encode_addr_list is empty.") + + try: + payload = self.encoder.encode(request) + except Exception as e: + raise RuntimeError("Failed to serialize GenerationRequest") from e + + msg = (RequestType.ENCODE, payload) + idx = random.randint(0, len(self.to_encode_sockets) - 1) + socket = self.to_encode_sockets[idx] + await socket.send_multipart(msg, copy=False) + + response = await q.get() + logger.info("Encode response: %s", response) + if isinstance(response, Exception): + raise response + + async def _run_pd( + self, + request: GenerationRequest, + q: asyncio.Queue[Union[Exception, GenerationResponse]], + ): + """ + Send the generation request to a PD worker and yield its response. + The PD worker is selected based on hashing the request ID. + """ + if not self.to_pd_sockets: + raise RuntimeError( + "No PD workers configured: pd_addr_list is empty.") + + try: + payload = self.encoder.encode(request) + except Exception as e: + raise RuntimeError("Failed to serialize GenerationRequest") from e + + msg = (RequestType.GENERATION, payload) + idx = random.randint(0, len(self.to_pd_sockets) - 1) + socket = self.to_pd_sockets[idx] + await socket.send_multipart(msg, copy=False) + + finished = False + while not finished: + response = await q.get() + if isinstance(response, Exception): + raise response + finished = response.finish_reason is not None + yield response + + def _to_request_output(self, resp: GenerationResponse) -> RequestOutput: + """Convert a PD/Generate response to vLLM RequestOutput. + + This creates a single CompletionOutput. If the response includes + text/token_ids attributes, they are used; otherwise defaults are used. + """ + text = getattr(resp, "text", "") + token_ids = getattr(resp, "token_ids", []) + + completion = CompletionOutput( + index=0, + text=text, + token_ids=token_ids, + cumulative_logprob=None, + logprobs=None, + finish_reason=resp.finish_reason, + stop_reason=resp.stop_reason, + ) + + return RequestOutput( + request_id=resp.request_id, + prompt=None, + prompt_token_ids=resp.prompt_token_ids, + prompt_logprobs=None, + outputs=[completion], + finished=resp.finish_reason is not None, + ) + + async def generate( + self, + prompt: PromptType, + sampling_params: SamplingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + priority: int = 0, + data_parallel_rank: Optional[int] = None, + ): + if self.output_handler is None: + self.output_handler = asyncio.create_task( + self._run_output_handler()) + if not request_id: + request_id = uuid.uuid4().hex + + q = asyncio.Queue() + self.queues[request_id] = q + + # Support both raw string prompts and dict prompts with multimodal data + prompt_text = prompt["prompt"] if isinstance(prompt, dict) else prompt + + request = GenerationRequest( + request_id=request_id, + prompt=prompt_text, + sampling_params=sampling_params, + ) + + if _has_mm_data(prompt): + request.multi_modal_data = _encode_mm_data( + prompt["multi_modal_data"]) + await self._run_encode(request, q) + + # TODO: support pd separation + async for pd_response in self._run_pd(request, q): + yield self._to_request_output(pd_response) + + async def _run_output_handler(self) -> None: + """Background task to pull responses and dispatch to request queues. 
+ + Binds a PULL socket on proxy_addr and receives multipart messages of + the form (response_type, payload). Decodes payload into a + GenerationResponse and enqueues it into the corresponding request queue + keyed by request_id. + """ + socket: Optional[zmq.asyncio.Socket] = None + decoder = msgspec.msgpack.Decoder(GenerationResponse) + failure_decoder = msgspec.msgpack.Decoder(FailureResponse) + + try: + socket = self.ctx.socket(zmq.constants.PULL) + socket.bind(self.proxy_addr) + + while True: + resp_type, payload = await socket.recv_multipart() + if (resp_type == ResponseType.GENERATION + or resp_type == ResponseType.ENCODE): + resp = decoder.decode(payload) + self.queues[resp.request_id].put_nowait(resp) + elif resp_type == ResponseType.FAILURE: + resp = failure_decoder.decode(payload) + raise RuntimeError(f"Worker error: {resp.error_message}") + else: + raise RuntimeError( + f"Unknown response type from worker: {resp_type}") + except Exception as e: + # TODO: maybe there is a more fine-grained way to handle errors. + # For now, if there is any error, we terminate all requests. + for q in self.queues.values(): + q.put_nowait(e) + finally: + if socket is not None: + socket.close(linger=0) + + def encode( + self, + prompt: PromptType, + pooling_params: PoolingParams, + request_id: str, + lora_request: Optional[LoRARequest] = None, + trace_headers: Optional[Mapping[str, str]] = None, + priority: int = 0, + ) -> AsyncGenerator[PoolingRequestOutput, None]: + raise NotImplementedError + + async def abort(self, request_id: str) -> None: + raise NotImplementedError + + async def get_model_config(self) -> ModelConfig: + return self.model_config + + async def get_decoding_config(self) -> DecodingConfig: + raise NotImplementedError + + async def get_input_preprocessor(self) -> InputPreprocessor: + raise NotImplementedError + + async def get_tokenizer( + self, + lora_request: Optional[LoRARequest] = None, + ) -> AnyTokenizer: + if lora_request is not None: + raise NotImplementedError("LoRA is not yet supported.") + return self.tokenizer.get_lora_tokenizer(None) + + async def is_tracing_enabled(self) -> bool: + return False + + async def do_log_stats( + self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[list[SamplerOutput]] = None, + ) -> None: + pass + + async def check_health(self) -> None: + pass + + async def start_profile(self) -> None: + raise NotImplementedError + + async def stop_profile(self) -> None: + raise NotImplementedError + + async def reset_prefix_cache(self, + device: Optional[Device] = None) -> None: + raise NotImplementedError + + async def sleep(self, level: int = 1) -> None: + raise NotImplementedError + + async def wake_up(self) -> None: + raise NotImplementedError + + async def is_sleeping(self) -> bool: + return False + + async def add_lora(self, lora_request: LoRARequest) -> None: + raise NotImplementedError + + @property + def errored(self) -> bool: + return False + + def dead_error(self) -> Exception: + return Exception("PDController has failed.") + + def is_running(self) -> bool: + return True + + def is_stopped(self) -> bool: + return False + + async def get_vllm_config(self) -> VllmConfig: + raise NotImplementedError + + async def reset_mm_cache(self) -> None: + raise NotImplementedError + + +# Helper functions +def _has_mm_data(prompt: PromptType) -> bool: + if isinstance(prompt, dict): + return "multi_modal_data" in prompt + return False + +def _encode_mm_data(mm_data: dict[str, any]) -> dict[str, any]: + images = 
mm_data.get("image", []) + if not isinstance(images, list): + images = [images] + encoded_images = [] + for img in images: + if isinstance(img, np.ndarray): + encoded_img = { + "type": "ndarray", + "data": img.tobytes(), + "shape": img.shape, + "dtype": str(img.dtype), + } + encoded_images.append(encoded_img) + return {"image": encoded_images} + + +# FastAPI event handlers +@app.on_event("startup") +async def startup_event(): + # Initialize the proxy instance in the app state + if not hasattr(app.state, "proxy"): + # Use default values for testing, will be overridden by command line args + app.state.proxy = Proxy( + proxy_addr="/tmp/vllm_proxy.ipc", + encode_addr_list=["/tmp/vllm_encode_0.ipc"], + pd_addr_list=["/tmp/vllm_pd_0.ipc"], + model_name="unknown-model", + ) + + +@app.on_event("shutdown") +async def shutdown_event(): + if hasattr(app.state, "proxy"): + app.state.proxy.shutdown() + + +# FastAPI routes +@app.post("/v1/chat/completions") +async def chat_completions(request: Request): + try: + request_data = await request.json() + request_id = request.headers.get("x-request-id", str(uuid.uuid4())) + is_streaming = request_data.get("stream", False) + + # Extract parameters from request + prompt = request_data.get("messages", []) + # For simplicity, we'll use the last message content as the prompt + if prompt and isinstance(prompt, list): + prompt_text = prompt[-1].get("content", "") + else: + prompt_text = "" + + # Create sampling params + sampling_params = SamplingParams( + temperature=request_data.get("temperature", 0.7), + top_p=request_data.get("top_p", 1.0), + max_tokens=request_data.get("max_tokens", 100), + stop=request_data.get("stop", None), + seed=request_data.get("seed", 77), + repetition_penalty=request_data.get("repetition_penalty", 1.0), + stop_token_ids=request_data.get("stop_token_ids", None), + ) + + if is_streaming: + async def stream_generator(): + async for output in app.state.proxy.generate( + prompt=prompt_text, + sampling_params=sampling_params, + request_id=request_id, + ): + prompt_tokens = len(output.prompt_token_ids) + completion_tokens = len(output.outputs[0].token_ids) + total_tokens = prompt_tokens + completion_tokens + # Format according to OpenAI's streaming format + chunk = { + "id": request_id, + "object": "chat.completion.chunk", + "created": int(asyncio.get_event_loop().time()), + "model": app.state.proxy.model_config.model, + "choices": [ + { + "index": 0, + "delta": {"content": output.outputs[0].text}, + "finish_reason": output.outputs[0].finish_reason + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens + } + } + yield f"data: {msgspec.json.encode(chunk).decode()}\n\n" + # End of stream + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_generator(), media_type="text/event-stream") + else: + # For non-streaming, collect all outputs + final_output = None + async for output in app.state.proxy.generate( + prompt=prompt_text, + sampling_params=sampling_params, + request_id=request_id, + ): + final_output = output + + if final_output: + prompt_tokens = len(final_output.prompt_token_ids) + completion_tokens = len(final_output.outputs[0].token_ids) + total_tokens = prompt_tokens + completion_tokens + response = { + "id": request_id, + "object": "chat.completion", + "created": int(asyncio.get_event_loop().time()), + "model": app.state.proxy.model_config.model, + "choices": [ + { + "index": 0, + "message": {"role": "assistant", "content": final_output.outputs[0].text}, + 
"finish_reason": final_output.outputs[0].finish_reason + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens + } + } + return JSONResponse(content=response) + else: + raise HTTPException(status_code=500, detail="No response from proxy") + except Exception as e: + logger.error("Error processing chat completion request: %s", e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.post("/v1/completions") +async def completions(request: Request): + try: + request_data = await request.json() + request_id = request.headers.get("x-request-id", str(uuid.uuid4())) + is_streaming = request_data.get("stream", False) + + # Extract parameters from request + prompt = request_data.get("prompt", "") + + # Create sampling params + sampling_params = SamplingParams( + temperature=request_data.get("temperature", 0.7), + top_p=request_data.get("top_p", 1.0), + max_tokens=request_data.get("max_tokens", 100), + stop=request_data.get("stop", None), + seed=request_data.get("seed", 77), + repetition_penalty=request_data.get("repetition_penalty", 1.0), + stop_token_ids=request_data.get("stop_token_ids", None), + ) + + if is_streaming: + async def stream_generator(): + async for output in app.state.proxy.generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + ): + prompt_tokens = len(output.prompt_token_ids) + completion_tokens = len(output.outputs[0].token_ids) + total_tokens = prompt_tokens + completion_tokens + # Format according to OpenAI's streaming format + chunk = { + "id": request_id, + "object": "text_completion.chunk", + "created": int(asyncio.get_event_loop().time()), + "model": app.state.proxy.model_config.model, + "choices": [ + { + "index": 0, + "text": output.outputs[0].text, + "finish_reason": output.outputs[0].finish_reason + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens + } + } + yield f"data: {msgspec.json.encode(chunk).decode()}\n\n" + # End of stream + yield "data: [DONE]\n\n" + + return StreamingResponse(stream_generator(), media_type="text/event-stream") + else: + # For non-streaming, collect all outputs + final_output = None + async for output in app.state.proxy.generate( + prompt=prompt, + sampling_params=sampling_params, + request_id=request_id, + ): + final_output = output + + if final_output: + prompt_tokens = len(final_output.prompt_token_ids) + completion_tokens = len(final_output.outputs[0].token_ids) + total_tokens = prompt_tokens + completion_tokens + response = { + "id": request_id, + "object": "text_completion", + "created": int(asyncio.get_event_loop().time()), + "model": app.state.proxy.model_config.model, + "choices": [ + { + "index": 0, + "text": final_output.outputs[0].text, + "finish_reason": final_output.outputs[0].finish_reason + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": total_tokens + } + } + return JSONResponse(content=response) + else: + raise HTTPException(status_code=500, detail="No response from proxy") + except Exception as e: + logger.error("Error processing completion request: %s", e) + raise HTTPException(status_code=500, detail=str(e)) from e + + +@app.get("/health") +async def health_check(): + try: + if hasattr(app.state, "proxy"): + return JSONResponse(content={"status": "healthy"}) + else: + return JSONResponse(content={"status": "unhealthy", "reason": "Proxy not initialized"}, 
                            status_code=503)
+    except Exception as e:
+        logger.error("Health check failed: %s", e)
+        return JSONResponse(content={"status": "unhealthy", "reason": str(e)},
+                            status_code=503)
+
+
+# Main entry point
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="vLLM Disaggregated Proxy")
+    parser.add_argument("--host", type=str, default="0.0.0.0", help="Proxy host")
+    parser.add_argument("--port", type=int, default=8000, help="Proxy port")
+    parser.add_argument("--proxy-addr", type=str, default="/tmp/vllm_proxy.ipc", help="Proxy IPC address")
+    parser.add_argument("--encode-addrs", nargs='+', type=str, required=True, help="Space-separated list of encode worker IPC addresses")
+    parser.add_argument("--pd-addrs", nargs='+', type=str, required=True, help="Space-separated list of PD worker IPC addresses")
+    parser.add_argument("--model-name", type=str, required=True, help="Model name")
+
+    args = parser.parse_args()
+
+    # Initialize the proxy with the provided arguments
+    app.state.proxy = Proxy(
+        proxy_addr=args.proxy_addr,
+        encode_addr_list=args.encode_addrs,
+        pd_addr_list=args.pd_addrs,
+        model_name=args.model_name,
+    )
+
+    logger.info("Starting vLLM Disaggregated Proxy on %s:%d", args.host, args.port)
+
+    # Run the server with uvicorn
+    uvicorn.run(
+        app,
+        host=args.host,
+        port=args.port,
+        log_level="info",
+        access_log=False,
+        loop="asyncio",
+    )
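As a quick smoke test once run_zmq_http.sh has brought up the workers and the proxy, a minimal hypothetical client sketch could look like the following. It assumes the proxy is reachable at http://localhost:8000 (the host/port hard-coded in run_zmq_http.sh) and that the deployed model accepts a plain-text chat prompt; nothing here is part of the change itself.

import json
import urllib.request

# Hypothetical smoke-test client; assumes the proxy started by run_zmq_http.sh
# is listening on the default 0.0.0.0:8000.
payload = {
    "messages": [{"role": "user", "content": "Describe this system in one sentence."}],
    "max_tokens": 64,
    "temperature": 0.7,
    "stream": False,
}
req = urllib.request.Request(
    "http://localhost:8000/v1/chat/completions",
    data=json.dumps(payload).encode("utf-8"),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    body = json.load(resp)

# The proxy returns an OpenAI-style chat completion object.
print(body["choices"][0]["message"]["content"])
print(body["usage"])

The /v1/completions route accepts the same fields with a top-level "prompt" string instead of "messages", and both routes stream SSE chunks when "stream": true is set.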