diff --git a/examples/ray/ray_flux_example.py b/examples/ray/ray_flux_example.py new file mode 100644 index 00000000..1292d048 --- /dev/null +++ b/examples/ray/ray_flux_example.py @@ -0,0 +1,64 @@ +import time +import os +import torch +import torch.distributed +from transformers import T5EncoderModel +from xfuser import xFuserArgs +from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline +from xfuser.config import FlexibleArgumentParser +from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline + +def main(): + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + engine_config, input_config = engine_args.create_config() + engine_config.runtime_config.dtype = torch.bfloat16 + model_name = engine_config.model_config.model.split("/")[-1] + PipelineClass = xFuserFluxPipeline + text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16) + if args.use_fp8_t5_encoder: + from optimum.quanto import freeze, qfloat8, quantize + quantize(text_encoder_2, weights=qfloat8) + freeze(text_encoder_2) + + pipe = RayDiffusionPipeline.from_pretrained( + PipelineClass=PipelineClass, + pretrained_model_name_or_path=engine_config.model_config.model, + engine_config=engine_config, + torch_dtype=torch.bfloat16, + text_encoder_2=text_encoder_2, + ) + pipe.prepare_run(input_config) + + start_time = time.time() + output = pipe( + height=input_config.height, + width=input_config.width, + prompt=input_config.prompt, + num_inference_steps=input_config.num_inference_steps, + output_type=input_config.output_type, + max_sequence_length=256, + guidance_scale=0.0, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + ) + end_time = time.time() + elapsed_time = end_time - start_time + + print(f"elapsed time:{elapsed_time}") + if not os.path.exists("results"): + os.mkdir("results") + # output is a list of results from each worker, we take the last one + for i, image in enumerate(output[-1].images): + image.save( + f"./results/{model_name}_result_{i}.png" + ) + print( + f"image {i} saved to ./results/{model_name}_result_{i}.png" + ) + + +if __name__ == "__main__": + main() diff --git a/examples/ray/ray_run.sh b/examples/ray/ray_run.sh new file mode 100644 index 00000000..c32f6c39 --- /dev/null +++ b/examples/ray/ray_run.sh @@ -0,0 +1,68 @@ +set -x +# If using a Ray cluster across multiple machines, you need to manually start a Ray cluster like this: +# ray start --head --port=6379 for master node +# ray start --address='192.168.1.1:6379' for worker node +# otherwise, it is not necessary. 
(for a single node).
+
+export PYTHONPATH=$PWD:$PYTHONPATH
+
+# Select the model type
+export MODEL_TYPE="Flux"
+# Configuration for different model types
+# script, model_id, inference_step
+declare -A MODEL_CONFIGS=(
+    ["Sd3"]="ray_sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
+    ["Flux"]="ray_flux_example.py /cfs/dit/FLUX.1-dev 28"
+)
+
+if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
+    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
+    export SCRIPT MODEL_ID INFERENCE_STEP
+else
+    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
+    exit 1
+fi
+
+mkdir -p ./results
+
+# task args
+TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
+
+
+N_GPUS=2
+PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
+
+# CFG_ARGS="--use_cfg_parallel"
+
+# By default, num_pipeline_patch = pipefusion_degree; tune this parameter to achieve optimal performance.
+# PIPEFUSION_ARGS="--num_pipeline_patch 8 "
+
+# For high-resolution images, we use the latent output type to avoid running the VAE module. Used for measuring speed.
+# OUTPUT_ARGS="--output_type latent"
+
+# PARALLEL_VAE="--use_parallel_vae"
+
+# Another compile option is `--use_onediff`, which uses onediff's compiler.
+# COMPILE_FLAG="--use_torch_compile"
+
+
+# Use this flag to quantize the T5 text encoder, which can reduce memory usage with little impact on output quality.
+# QUANTIZE_FLAG="--use_fp8_t5_encoder"
+
+export CUDA_VISIBLE_DEVICES=0,1
+
+python ./examples/ray/$SCRIPT \
+--model $MODEL_ID \
+$PARALLEL_ARGS \
+$TASK_ARGS \
+$PIPEFUSION_ARGS \
+$OUTPUT_ARGS \
+--num_inference_steps $INFERENCE_STEP \
+--warmup_steps 1 \
+--prompt "brown dog laying on the ground with a metal bowl in front of him." \
+--use_ray \
+--ray_world_size $N_GPUS \
+$CFG_ARGS \
+$PARALLEL_VAE \
+$COMPILE_FLAG \
+$QUANTIZE_FLAG
diff --git a/examples/ray/ray_sd3_example.py b/examples/ray/ray_sd3_example.py
new file mode 100644
index 00000000..79bbc215
--- /dev/null
+++ b/examples/ray/ray_sd3_example.py
@@ -0,0 +1,63 @@
+import time
+import os
+import torch
+import torch.distributed
+from transformers import T5EncoderModel
+from xfuser import xFuserArgs
+from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
+from xfuser.config import FlexibleArgumentParser
+from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
+
+
+def main():
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+    parser = FlexibleArgumentParser(description="xFuser Arguments")
+    args = xFuserArgs.add_cli_args(parser).parse_args()
+    engine_args = xFuserArgs.from_cli_args(args)
+    engine_config, input_config = engine_args.create_config()
+    model_name = engine_config.model_config.model.split("/")[-1]
+    PipelineClass = xFuserStableDiffusion3Pipeline
+    text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
+    if args.use_fp8_t5_encoder:
+        from
optimum.quanto import freeze, qfloat8, quantize
+        print("quantizing text_encoder_3 to fp8")
+        quantize(text_encoder_3, weights=qfloat8)
+        freeze(text_encoder_3)
+
+    pipe = RayDiffusionPipeline.from_pretrained(
+        PipelineClass=PipelineClass,
+        pretrained_model_name_or_path=engine_config.model_config.model,
+        engine_config=engine_config,
+        torch_dtype=torch.float16,
+        text_encoder_3=text_encoder_3,
+    )
+    pipe.prepare_run(input_config)
+
+    torch.cuda.reset_peak_memory_stats()
+    start_time = time.time()
+    output = pipe(
+        height=input_config.height,
+        width=input_config.width,
+        prompt=input_config.prompt,
+        num_inference_steps=input_config.num_inference_steps,
+        output_type=input_config.output_type,
+        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
+    )
+    end_time = time.time()
+    elapsed_time = end_time - start_time
+    print(f"elapsed time:{elapsed_time}")
+    if not os.path.exists("results"):
+        os.mkdir("results")
+    # output is a list of results from each worker, we take the last one
+    for i, image in enumerate(output[-1].images):
+        image.save(
+            f"./results/{model_name}_result_{i}.png"
+        )
+        print(
+            f"image {i} saved to ./results/{model_name}_result_{i}.png"
+        )
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4c8ab883..87a5d7fe 100644
--- a/setup.py
+++ b/setup.py
@@ -39,7 +39,8 @@ def get_cuda_version():
         "imageio",
         "imageio-ffmpeg",
         "optimum-quanto",
-        "flash_attn>=2.6.3" # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops
+        "flash_attn>=2.6.3",
+        "ray"
     ],
     extras_require={
         "diffusers": [
diff --git a/xfuser/config/args.py b/xfuser/config/args.py
index f187225b..c0d8d7bc 100644
--- a/xfuser/config/args.py
+++ b/xfuser/config/args.py
@@ -79,6 +79,9 @@ class xFuserArgs:
     # tensor parallel
     tensor_parallel_degree: int = 1
     split_scheme: Optional[str] = "row"
+    # ray arguments
+    use_ray: bool = False
+    ray_world_size: int = 1
     # pipefusion parallel
     pipefusion_parallel_degree: int = 1
     num_pipeline_patch: Optional[int] = None
@@ -151,6 +154,17 @@ def add_cli_args(parser: FlexibleArgumentParser):

         # Parallel arguments
         parallel_group = parser.add_argument_group("Parallel Processing Options")
+        runtime_group.add_argument(
+            "--use_ray",
+            action="store_true",
+            help="Enable Ray to run inference across multiple GPUs",
+        )
+        parallel_group.add_argument(
+            "--ray_world_size",
+            type=int,
+            default=1,
+            help="Number of Ray workers (the world size when using Ray)",
+        )
         parallel_group.add_argument(
             "--use_cfg_parallel",
             action="store_true",
@@ -322,11 +336,15 @@ def from_cli_args(cls, args: argparse.Namespace):
     def create_config(
         self,
     ) -> Tuple[EngineConfig, InputConfig]:
-        if not torch.distributed.is_initialized():
+        if not self.use_ray and not torch.distributed.is_initialized():
             logger.warning(
                 "Distributed environment is not initialized. " "Initializing..."
) init_distributed_environment() + if self.use_ray: + self.world_size = self.ray_world_size + else: + self.world_size = torch.distributed.get_world_size() model_config = ModelConfig( model=self.model, @@ -348,20 +366,25 @@ def create_config( dp_config=DataParallelConfig( dp_degree=self.data_parallel_degree, use_cfg_parallel=self.use_cfg_parallel, + world_size=self.world_size, ), sp_config=SequenceParallelConfig( ulysses_degree=self.ulysses_degree, ring_degree=self.ring_degree, + world_size=self.world_size, ), tp_config=TensorParallelConfig( tp_degree=self.tensor_parallel_degree, split_scheme=self.split_scheme, + world_size=self.world_size, ), pp_config=PipeFusionParallelConfig( pp_degree=self.pipefusion_parallel_degree, num_pipeline_patch=self.num_pipeline_patch, attn_layer_num_for_pp=self.attn_layer_num_for_pp, + world_size=self.world_size, ), + world_size=self.world_size, ) fast_attn_config = FastAttnConfig( diff --git a/xfuser/config/config.py b/xfuser/config/config.py index b599d244..eeb22143 100644 --- a/xfuser/config/config.py +++ b/xfuser/config/config.py @@ -86,6 +86,7 @@ def __post_init__(self): class DataParallelConfig: dp_degree: int = 1 use_cfg_parallel: bool = False + world_size: int = 1 def __post_init__(self): assert self.dp_degree >= 1, "dp_degree must greater than or equal to 1" @@ -95,12 +96,12 @@ def __post_init__(self): self.cfg_degree = 2 else: self.cfg_degree = 1 - assert self.dp_degree * self.cfg_degree <= dist.get_world_size(), ( + assert self.dp_degree * self.cfg_degree <= self.world_size, ( "dp_degree * cfg_degree must be less than or equal to " "world_size because of classifier free guidance" ) assert ( - dist.get_world_size() % (self.dp_degree * self.cfg_degree) == 0 + self.world_size % (self.dp_degree * self.cfg_degree) == 0 ), "world_size must be divisible by dp_degree * cfg_degree" @@ -108,6 +109,7 @@ def __post_init__(self): class SequenceParallelConfig: ulysses_degree: Optional[int] = None ring_degree: Optional[int] = None + world_size: int = 1 def __post_init__(self): if self.ulysses_degree is None: @@ -138,11 +140,12 @@ def __post_init__(self): class TensorParallelConfig: tp_degree: int = 1 split_scheme: Optional[str] = "row" + world_size: int = 1 def __post_init__(self): assert self.tp_degree >= 1, "tp_degree must greater than 1" assert ( - self.tp_degree <= dist.get_world_size() + self.tp_degree <= self.world_size ), "tp_degree must be less than or equal to world_size" @@ -151,13 +154,14 @@ class PipeFusionParallelConfig: pp_degree: int = 1 num_pipeline_patch: Optional[int] = None attn_layer_num_for_pp: Optional[List[int]] = (None,) + world_size: int = 1 def __post_init__(self): assert ( self.pp_degree is not None and self.pp_degree >= 1 ), "pipefusion_degree must be set and greater than 1 to use pipefusion" assert ( - self.pp_degree <= dist.get_world_size() + self.pp_degree <= self.world_size ), "pipefusion_degree must be less than or equal to world_size" if self.num_pipeline_patch is None: self.num_pipeline_patch = self.pp_degree @@ -188,6 +192,8 @@ class ParallelConfig: sp_config: SequenceParallelConfig pp_config: PipeFusionParallelConfig tp_config: TensorParallelConfig + world_size: int = 1 # FIXME: remove this + worker_cls: str = "xfuser.ray.worker.worker.Worker" def __post_init__(self): assert self.tp_config is not None, "tp_config must be set" @@ -201,10 +207,10 @@ def __post_init__(self): * self.tp_config.tp_degree * self.pp_config.pp_degree ) - world_size = dist.get_world_size() + world_size = self.world_size assert parallel_world_size == 
world_size, ( f"parallel_world_size {parallel_world_size} " - f"must be equal to world_size {dist.get_world_size()}" + f"must be equal to world_size {self.world_size}" ) assert ( world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0 @@ -236,7 +242,7 @@ class EngineConfig: fast_attn_config: FastAttnConfig def __post_init__(self): - world_size = dist.get_world_size() + world_size = self.parallel_config.world_size if self.fast_attn_config.use_fast_attn: assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn" diff --git a/xfuser/ray/pipeline/__init__.py b/xfuser/ray/pipeline/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/xfuser/ray/pipeline/base_executor.py b/xfuser/ray/pipeline/base_executor.py new file mode 100644 index 00000000..8a86eb53 --- /dev/null +++ b/xfuser/ray/pipeline/base_executor.py @@ -0,0 +1,21 @@ +# Copyright 2024 The xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/executor/executor_base.py +# Copyright (c) 2023, vLLM team. All rights reserved. +from abc import ABC, abstractmethod + +from xfuser.config.config import EngineConfig + + +class BaseExecutor(ABC): + def __init__( + self, + engine_config: EngineConfig, + ): + self.engine_config = engine_config + self.parallel_config = engine_config.parallel_config + self._init_executor() + + @abstractmethod + def _init_executor(self): + pass diff --git a/xfuser/ray/pipeline/pipeline_utils.py b/xfuser/ray/pipeline/pipeline_utils.py new file mode 100644 index 00000000..a4959a81 --- /dev/null +++ b/xfuser/ray/pipeline/pipeline_utils.py @@ -0,0 +1,113 @@ +# Copyright 2024 The xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/executor/gpu_executor.py +# Copyright (c) 2023, vLLM team. All rights reserved. 
+import ray
+from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
+from itertools import islice, repeat
+from typing import Any, Dict, List, Optional, Tuple
+
+from xfuser.ray.pipeline.base_executor import BaseExecutor
+from xfuser.ray.pipeline.ray_utils import initialize_ray_cluster
+from xfuser.logger import init_logger
+from xfuser.ray.worker.worker_wrappers import RayWorkerWrapper
+from xfuser.config.config import InputConfig, EngineConfig
+logger = init_logger(__name__)
+
+
+class GPUExecutor(BaseExecutor):
+    def _init_executor(self):
+        pass
+
+
+class RayDiffusionPipeline(GPUExecutor):
+    workers = []
+    def _init_executor(self):
+        self._init_ray_workers()
+        self._run_workers("init_worker_distributed_environment")
+
+    def _init_ray_workers(self):
+        placement_group = initialize_ray_cluster(self.engine_config.parallel_config)
+
+        # Create one worker wrapper per GPU bundle; the actual worker is loaded lazily.
+        self.workers = []
+        for bundle_id, bundle in enumerate(placement_group.bundle_specs):
+            # Skip bundles without GPUs
+            if not bundle.get("GPU", 0):
+                continue
+
+            scheduling_strategy = PlacementGroupSchedulingStrategy(
+                placement_group=placement_group,
+                placement_group_bundle_index=bundle_id,
+                placement_group_capture_child_tasks=True,
+            )
+            worker = ray.remote(
+                num_cpus=0,
+                num_gpus=1,
+                scheduling_strategy=scheduling_strategy,
+            )(RayWorkerWrapper).remote(self.engine_config, bundle_id)
+            self.workers.append(worker)
+
+        self.node_metadata = {}
+
+    def _run_workers(
+        self,
+        method: str,
+        *args,
+        async_run_tensor_parallel_workers_only: bool = False,
+        all_args: Optional[List[Tuple[Any, ...]]] = None,
+        all_kwargs: Optional[List[Dict[str, Any]]] = None,
+        **kwargs,
+    ) -> Any:
+        """Run the given method on all Ray workers and collect the
+        results.
+
+        Args:
+        - async_run_tensor_parallel_workers_only: If True, the method is
+          launched asynchronously and a list of futures is returned rather
+          than blocking on the results. Despite the name, it applies to
+          every Ray worker managed by this executor.
+        - args/kwargs: All workers share the same args/kwargs.
+        - all_args/all_kwargs: args/kwargs for each worker are specified
+          individually.
+        """
+
+        count = len(self.workers)
+        # All workers play the same role here; there is no separate driver
+        # worker whose args would need to be skipped, so iteration over
+        # all_args/all_kwargs always starts at index 0.
+        first_worker_args_index: int = 0
+        all_worker_args = repeat(args, count) if all_args is None \
+            else islice(all_args, first_worker_args_index, None)
+        all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
+            else islice(all_kwargs, first_worker_args_index, None)
+
+        # Start the ray workers first.
+        ray_workers = self.workers
+        ray_worker_outputs = [
+            worker.execute_method.remote(method, *worker_args, **worker_kwargs)
+            for (worker, worker_args, worker_kwargs
+                 ) in zip(ray_workers, all_worker_args, all_worker_kwargs)
+        ]
+
+        if async_run_tensor_parallel_workers_only:
+            # Just return futures
+            return ray_worker_outputs
+
+        # Get the results of the ray workers.
+ if self.workers: + ray_worker_outputs = ray.get(ray_worker_outputs) + + return ray_worker_outputs + + @classmethod + def from_pretrained(cls,PipelineClass,pretrained_model_name_or_path: str,engine_config: EngineConfig,**kwargs): + pipeline = cls(engine_config) + pipeline._run_workers("from_pretrained",PipelineClass,pretrained_model_name_or_path,engine_config,**kwargs) + return pipeline + + def prepare_run(self, input_config: InputConfig, steps: int = 3, sync_steps: int = 1): + self._run_workers("prepare_run",input_config,steps,sync_steps) + + def __call__(self,**kwargs): + return self._run_workers("execute",**kwargs) diff --git a/xfuser/ray/pipeline/ray_utils.py b/xfuser/ray/pipeline/ray_utils.py new file mode 100644 index 00000000..e90c7b68 --- /dev/null +++ b/xfuser/ray/pipeline/ray_utils.py @@ -0,0 +1,296 @@ +# Copyright 2024 The xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/executor/ray_utils.py +# Copyright (c) 2023, vLLM team. All rights reserved. +import time +import socket +from typing import Dict, List, Optional +from collections import defaultdict + +from xfuser.config import ParallelConfig +from xfuser.logger import init_logger +from xfuser.envs import environment_variables + +logger = init_logger(__name__) + +PG_WAIT_TIMEOUT = 1800 + +try: + import ray + from ray.util import placement_group_table + from ray.util.placement_group import PlacementGroup + +except ImportError as e: + ray = None # type: ignore + ray_import_err = e + + +def ray_is_available() -> bool: + """Returns True if Ray is available.""" + return ray is not None + + +def assert_ray_available(): + """Raise an exception if Ray is not available.""" + if ray is None: + raise ValueError( + "Failed to import Ray, please install Ray with " "`pip install ray`." + ) from ray_import_err + + +def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): + """Wait until a placement group is ready. + + It prints the informative log messages if the placement group is + not created within time. + + """ + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + placement_group_specs = current_placement_group.bundle_specs + + s = time.time() + pg_ready_ref = current_placement_group.ready() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval) + if len(ready) > 0: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for creating a placement group of specs for " + "%d seconds. specs=%s. Check " + "`ray status` to see if you have enough resources.", + int(time.time() - s), + placement_group_specs, + ) + + try: + ray.get(pg_ready_ref, timeout=0) + except ray.exceptions.GetTimeoutError: + raise ValueError( + "Cannot provide a placement group of " + f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " + "`ray status` to make sure the cluster has enough resources." + ) from None + + +def _wait_until_pg_removed(current_placement_group: "PlacementGroup"): + ray.util.remove_placement_group(current_placement_group) + s = time.time() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + pg = ray.util.get_current_placement_group() + if pg is None: + break + + # Exponential backoff for warning print. 
+ wait_interval *= 2 + logger.info( + "Waiting for removing a placement group of specs for " "%d seconds.", + int(time.time() - s), + ) + time.sleep(wait_interval) + + +def initialize_ray_cluster( + parallel_config: ParallelConfig, + ray_address: Optional[str] = None, +): + """Initialize the distributed cluster with Ray. + + it will connect to the Ray cluster and create a placement group + for the workers, which includes the specification of the resources + for each distributed worker. + + Args: + parallel_config: The configurations for parallel execution. + ray_address: The address of the Ray cluster. If None, uses + the default Ray cluster address. + """ + assert_ray_available() + + # Connect to a ray cluster. + ray.init(address=ray_address, ignore_reinit_error=True) + + device_str = "GPU" + # Create placement group for worker processes + current_placement_group = ray.util.get_current_placement_group() + if current_placement_group: + # We are in a placement group + bundles = current_placement_group.bundle_specs + # Verify that we can use the placement group. + device_bundles = 0 + for bundle in bundles: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: + raise ValueError( + "Placement group bundle cannot have more than 1 " f"{device_str}." + ) + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group." + f"Required number of devices: {parallel_config.world_size}. " + f"Total number of devices: {device_bundles}." + ) + else: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if parallel_config.world_size > num_devices_in_cluster: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group." + ) + # Create a new placement group + placement_group_specs: List[Dict[str, float]] = [ + {device_str: 1.0} for _ in range(parallel_config.world_size) + ] + + # By default, Ray packs resources as much as possible. 
+ current_placement_group = ray.util.placement_group( + placement_group_specs, strategy="PACK" + ) + _wait_until_pg_ready(current_placement_group) + + assert current_placement_group is not None + _verify_bundles(current_placement_group, parallel_config, device_str) + # Set the placement group in the parallel config + return current_placement_group + + +def get_num_nodes_in_placement_group() -> int: + pg_table = ray.util.placement_group_table() + current_pg = ray.util.get_current_placement_group() + num_nodes = 0 + + if current_pg: + nodes_in_pg = set() + for pg_key, pg in pg_table.items(): + if pg_key == current_pg.id.hex(): + for _, node in pg["bundles_to_node_id"].items(): + nodes_in_pg.add(node) + num_nodes = len(nodes_in_pg) + + return num_nodes + + +def get_ip() -> str: + host_ip = environment_variables["MASTER_ADDR"]() + if host_ip: + return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + logger.warning( + "Failed to get the IP address, using 0.0.0.0 by default." + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", + stacklevel=2, + ) + return "0.0.0.0" + + +def get_open_port() -> int: + port = environment_variables["MASTER_PORT"]() + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info( + "Port %d is already in use, trying port %d", port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def get_distributed_init_method(ip: str, port: int) -> str: + # Brackets are not permitted in ipv4 addresses, + # see https://github.com/python/cpython/issues/103848 + return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" + + +def _verify_bundles( + placement_group: "PlacementGroup", parallel_config: ParallelConfig, device_str: str +): + """Verify a given placement group has bundles located in the right place. + + There are 2 rules. + - Warn if all tensor parallel workers cannot fit in a single node. + - Fail if driver node is not included in a placement group. + """ + assert ( + ray.is_initialized() + ), "Ray is not initialized although distributed-executor-backend is ray." 
+    pg_data = placement_group_table(placement_group)
+    # bundle_idx -> node_id
+    bundle_to_node_ids = pg_data["bundles_to_node_id"]
+    # bundle_idx -> bundle (e.g., {"GPU": 1})
+    bundles = pg_data["bundles"]
+    # node_id -> List of bundle (e.g., {"GPU": 1})
+    node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list)
+
+    for bundle_idx, node_id in bundle_to_node_ids.items():
+        node_id_to_bundle[node_id].append(bundles[bundle_idx])
+    driver_node_id = ray.get_runtime_context().get_node_id()
+
+    if driver_node_id not in node_id_to_bundle:
+        raise RuntimeError(
+            f"driver node id {driver_node_id} is not included in a placement "
+            f"group {placement_group.id}. Node id -> bundles "
+            f"{node_id_to_bundle}. "
+            "You don't have enough GPUs available on the current node. Check "
+            "`ray status` to see if you have available GPUs on node "
+            f"{driver_node_id} before starting the xDiT engine."
+        )
+
+    for node_id, bundles in node_id_to_bundle.items():
+        if len(bundles) < parallel_config.sp_degree:
+            logger.warning(
+                "sequence parallel degree=%d "
+                "is bigger than the number of reserved %ss (%d "
+                "%ss) on node %s. Sequence parallel workers may be "
+                "spread across 2+ nodes, which can degrade performance "
+                "unless you have a fast interconnect across nodes, such as "
+                "InfiniBand. To resolve this issue, make sure you have more "
+                "than %d GPUs available on each node.",
+                parallel_config.sp_degree,
+                device_str,
+                len(bundles),
+                device_str,
+                node_id,
+                parallel_config.sp_degree,
+            )
diff --git a/xfuser/ray/worker/__init__.py b/xfuser/ray/worker/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/xfuser/ray/worker/utils.py b/xfuser/ray/worker/utils.py
new file mode 100644
index 00000000..97d2d772
--- /dev/null
+++ b/xfuser/ray/worker/utils.py
@@ -0,0 +1,31 @@
+# Copyright 2024 The xDiT team.
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/main/vllm/utils.py
+# Copyright (c) 2023, vLLM team. All rights reserved.
+import os
+from typing import Dict, Any
+import importlib
+from xfuser.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def resolve_obj_by_qualname(qualname: str) -> Any:
+    """
+    Resolve an object by its fully qualified name.
+ """ + module_name, obj_name = qualname.rsplit(".", 1) + module = importlib.import_module(module_name) + return getattr(module, obj_name) + + +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s " "from '%s' to '%s'", + k, + os.environ[k], + v, + ) + os.environ[k] = v diff --git a/xfuser/ray/worker/worker.py b/xfuser/ray/worker/worker.py new file mode 100644 index 00000000..574ea01f --- /dev/null +++ b/xfuser/ray/worker/worker.py @@ -0,0 +1,65 @@ +from abc import ABC, abstractmethod + +from xfuser.core.distributed import ( + get_world_group, +) +from xfuser.config.config import EngineConfig, InputConfig,ParallelConfig +from xfuser.core.distributed import init_distributed_environment + +class WorkerBase(ABC): + def __init__( + self, + ) -> None: + pass + + @abstractmethod + def from_pretrained( + self, PipelineClass, pretrained_model_name_or_path: str, engine_config: EngineConfig,**kwargs, + ): + raise NotImplementedError + + @abstractmethod + def prepare_run(self,input_config: InputConfig,steps: int = 3,sync_steps: int = 1): + raise NotImplementedError + @abstractmethod + def execute(self, **kwargs): + raise NotImplementedError + + +class Worker(WorkerBase): + """ + A worker class that executes (a partition of) the model on a GPU. + """ + parallel_config: ParallelConfig + def __init__( + self, + parallel_config: ParallelConfig, + rank: int, + ) -> None: + WorkerBase.__init__(self) + self.parallel_config = parallel_config + self.rank = rank + self.pipe = None + + def init_worker_distributed_environment(self): + init_distributed_environment( + rank=self.rank, + world_size=self.parallel_config.world_size, + ) + + def from_pretrained(self,PipelineClass, pretrained_model_name_or_path: str, engine_config: EngineConfig,**kwargs,): + local_rank = get_world_group().local_rank + pipe = PipelineClass.from_pretrained( + pretrained_model_name_or_path=pretrained_model_name_or_path, + engine_config=engine_config, + **kwargs + ).to(f"cuda:{local_rank}") + self.pipe = pipe + return + + def prepare_run(self,input_config: InputConfig,steps: int = 3,sync_steps: int = 1): + self.pipe.prepare_run(input_config,steps,sync_steps) + + def execute(self, **kwargs): + output = self.pipe(**kwargs) + return output diff --git a/xfuser/ray/worker/worker_wrappers.py b/xfuser/ray/worker/worker_wrappers.py new file mode 100644 index 00000000..0f4a4a78 --- /dev/null +++ b/xfuser/ray/worker/worker_wrappers.py @@ -0,0 +1,41 @@ +# Copyright 2024 The xDiT team. +# Adapted from +# https://github.com/vllm-project/vllm/blob/main/vllm/worker/worker_base.py +# Copyright (c) 2023, vLLM team. All rights reserved. 
+import os
+from abc import ABC
+from typing import Any, Dict
+
+from xfuser.ray.worker.utils import update_environment_variables, resolve_obj_by_qualname
+from xfuser.config.config import EngineConfig
+
+class BaseWorkerWrapper(ABC):
+    def __init__(self, worker_cls: str):
+        self.worker_cls = worker_cls
+        self.worker = None
+
+    # Lazily import and instantiate the configured worker class.
+    def init_worker(self, *args, **kwargs):
+        worker_class = resolve_obj_by_qualname(
+            self.worker_cls)
+        self.worker = worker_class(*args, **kwargs)
+        assert self.worker is not None
+
+    def execute_method(self, method: str, *args, **kwargs) -> Any:
+        executor = getattr(self, method, None) or getattr(
+            self.worker, method, None)
+        if not executor:
+            raise AttributeError(
+                f"Method {method} not found in Worker class")
+        return executor(*args, **kwargs)
+
+    def update_environs(self, environs: Dict[str, str]):
+        if "CUDA_VISIBLE_DEVICES" in environs and "CUDA_VISIBLE_DEVICES" in os.environ:
+            del os.environ["CUDA_VISIBLE_DEVICES"]
+        update_environment_variables(environs)
+
+
+class RayWorkerWrapper(BaseWorkerWrapper):
+    def __init__(self, engine_config: EngineConfig, rank: int) -> None:
+        super().__init__(engine_config.parallel_config.worker_cls)
+        self.init_worker(engine_config.parallel_config, rank)
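
For reference, a minimal sketch of how the new Ray path is driven end to end, condensed from examples/ray/ray_flux_example.py and examples/ray/ray_run.sh above. The model path, GPU count, and parallel degrees below are illustrative; per the ParallelConfig assert in xfuser/config/config.py, the product of the parallel degrees must equal --ray_world_size.

    # launched e.g. as:
    #   python ray_flux_example.py --model /path/to/FLUX.1-dev --use_ray --ray_world_size 2 \
    #       --pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1
    import torch
    from xfuser import xFuserArgs
    from xfuser.config import FlexibleArgumentParser
    from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
    from xfuser.model_executor.pipelines import xFuserFluxPipeline

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_config, input_config = xFuserArgs.from_cli_args(args).create_config()

    # RayDiffusionPipeline creates one RayWorkerWrapper per GPU bundle in the
    # placement group and broadcasts from_pretrained / prepare_run / __call__
    # to every worker via _run_workers.
    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=xFuserFluxPipeline,
        pretrained_model_name_or_path=engine_config.model_config.model,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
    )
    pipe.prepare_run(input_config)
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
    )
    # output holds one result per worker; the last entry carries the images.
    images = output[-1].images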