[Ray] Add Support for Disaggregating VAE and DiT #422

Merged (5 commits), Jan 11, 2025
21 changes: 12 additions & 9 deletions examples/ray/ray_flux_example.py
@@ -6,7 +6,7 @@
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
from xfuser.model_executor.pipelines import xFuserFluxPipeline

def main():
os.environ["MASTER_ADDR"] = "localhost"
@@ -50,14 +50,17 @@ def main():
print(f"elapsed time:{elapsed_time}")
if not os.path.exists("results"):
os.mkdir("results")
# output is a list of results from each worker, we take the last one
for i, image in enumerate(output[-1].images):
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)

for i, images in enumerate(output):
if images is not None:
image = images[0]
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)
break


if __name__ == "__main__":
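With the VAE disaggregated into its own Ray worker, `output` contains one entry per worker and only the VAE worker returns decoded images, which is why the new loop saves the first non-`None` entry. A minimal sketch of that pattern (the per-worker output layout is inferred from the example above; the helper name is hypothetical):

```python
# Sketch of the per-worker output handling used above (helper name is hypothetical).
# `output` is assumed to be a list with one entry per Ray worker; DiT-only workers
# return None, while the VAE worker returns the decoded images.
def save_first_available_result(output, model_name):
    for worker_idx, images in enumerate(output):
        if images is None:
            continue  # this worker did not run the VAE
        path = f"./results/{model_name}_result_{worker_idx}.png"
        images[0].save(path)
        print(f"image {worker_idx} saved to {path}")
        break  # the first worker that returns images is enough
```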
11 changes: 8 additions & 3 deletions examples/ray/ray_run.sh
@@ -29,9 +29,10 @@ mkdir -p ./results
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


N_GPUS=2
N_GPUS=3 # world size
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

VAE_PARALLEL_SIZE=1
DIT_PARALLEL_SIZE=2
# CFG_ARGS="--use_cfg_parallel"

# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
@@ -49,7 +50,8 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1
# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

export CUDA_VISIBLE_DEVICES=0,1
# It is necessary to set CUDA_VISIBLE_DEVICES for the ray driver and workers.
export CUDA_VISIBLE_DEVICES=4,5,6,7

python ./examples/ray/$SCRIPT \
--model $MODEL_ID \
@@ -66,3 +68,6 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
--use_parallel_vae \
--dit_parallel_size $DIT_PARALLEL_SIZE \
--vae_parallel_size $VAE_PARALLEL_SIZE
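The new knobs in ray_run.sh have to stay consistent with each other: the Ray world size covers both the DiT group and the VAE group, and `CUDA_VISIBLE_DEVICES` must expose at least that many GPUs to the driver and workers. A small sanity-check sketch of that bookkeeping, using the values from the script above (illustration only, not part of the script):

```python
# Sanity check for the ray_run.sh settings above (illustrative sketch).
n_gpus = 3                       # N_GPUS, the Ray world size
dit_parallel_size = 2            # DIT_PARALLEL_SIZE
vae_parallel_size = 1            # VAE_PARALLEL_SIZE
visible_devices = [4, 5, 6, 7]   # CUDA_VISIBLE_DEVICES

# DiT and VAE ranks must partition the Ray world size (enforced in args.py below).
assert dit_parallel_size + vae_parallel_size == n_gpus
# Ray can only place workers on GPUs it is allowed to see.
assert len(visible_devices) >= n_gpus
```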
50 changes: 25 additions & 25 deletions examples/ray/ray_sd3_example.py
@@ -6,21 +6,7 @@
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
is_dp_last_group,
get_data_parallel_rank,
get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size
from xfuser.model_executor.pipelines import xFuserStableDiffusion3Pipeline


def main():
@@ -32,7 +18,19 @@ def main():
engine_config, input_config = engine_args.create_config()
model_name = engine_config.model_config.model.split("/")[-1]
PipelineClass = xFuserStableDiffusion3Pipeline
text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)

# Equivalent to:
# text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16)
# but the text encoder is loaded inside the Ray worker rather than in the driver.
encoder_kwargs = {
'text_encoder_3': {
'model_class': T5EncoderModel,
'pretrained_model_name_or_path': engine_config.model_config.model,
'subfolder': 'text_encoder_3',
'torch_dtype': torch.float16
},
}

if args.use_fp8_t5_encoder:
from optimum.quanto import freeze, qfloat8, quantize
print(f"rank {local_rank} quantizing text encoder 2")
@@ -44,7 +42,7 @@ def main():
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.float16,
text_encoder_3=text_encoder_3,
**encoder_kwargs
)
pipe.prepare_run(input_config)

@@ -63,14 +61,16 @@ def main():
print(f"elapsed time:{elapsed_time}")
if not os.path.exists("results"):
os.mkdir("results")
# output is a list of results from each worker, we take the last one
for i, image in enumerate(output[-1].images):
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)
for i, images in enumerate(output):
if images is not None:
image = images[0]
image.save(
f"./results/{model_name}_result_{i}.png"
)
print(
f"image {i} saved to ./results/{model_name}_result_{i}.png"
)
break


if __name__ == "__main__":
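Instead of instantiating `T5EncoderModel` in the driver, the example now passes a loading spec (`encoder_kwargs`) so the encoder can be built inside the Ray worker that actually needs it. A hedged sketch of how such a spec might be consumed; the helper below is hypothetical and not the pipeline's actual code:

```python
import torch
from transformers import T5EncoderModel

def build_component_from_spec(spec):
    """Hypothetical helper: instantiate a submodule from an encoder_kwargs-style spec."""
    spec = dict(spec)                      # copy so the caller's dict is untouched
    model_class = spec.pop("model_class")  # e.g. T5EncoderModel
    return model_class.from_pretrained(**spec)

# Spec with the same structure as in ray_sd3_example.py above; the model path is a
# placeholder standing in for engine_config.model_config.model.
spec = {
    "model_class": T5EncoderModel,
    "pretrained_model_name_or_path": "stabilityai/stable-diffusion-3-medium-diffusers",
    "subfolder": "text_encoder_3",
    "torch_dtype": torch.float16,
}
# text_encoder_3 = build_component_from_spec(spec)  # would run inside the worker
```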
29 changes: 24 additions & 5 deletions xfuser/config/args.py
@@ -82,6 +82,8 @@ class xFuserArgs:
# ray arguments
use_ray: bool = False
ray_world_size: int = 1
vae_parallel_size: int = 0
dit_parallel_size: int = 1
# pipefusion parallel
pipefusion_parallel_degree: int = 1
num_pipeline_patch: Optional[int] = None
@@ -165,6 +167,12 @@ def add_cli_args(parser: FlexibleArgumentParser):
default=1,
help="The number of ray workers (world_size for ray)",
)
parallel_group.add_argument(
"--dit_parallel_size",
type=int,
default=0,
help="The number of processes for DIT parallelization.",
)
parallel_group.add_argument(
"--use_cfg_parallel",
action="store_true",
@@ -210,6 +218,12 @@ def add_cli_args(parser: FlexibleArgumentParser):
default=1,
help="Tensor parallel degree.",
)
parallel_group.add_argument(
"--vae_parallel_size",
type=int,
default=0,
help="Number of processes for VAE parallelization. 0: no seperate process for VAE, 1: run VAE in a separate process, >1: distribute VAE across multiple processes.",
)
parallel_group.add_argument(
"--split_scheme",
type=str,
@@ -345,7 +359,10 @@ def create_config(
self.world_size = self.ray_world_size
else:
self.world_size = torch.distributed.get_world_size()


if self.dit_parallel_size == 0 and (not self.use_parallel_vae or self.vae_parallel_size == 0):
self.dit_parallel_size = self.world_size
assert self.dit_parallel_size+self.vae_parallel_size == self.world_size, f"DIT parallel size {self.dit_parallel_size} and VAE parallel size {self.vae_parallel_size} must sum to world size {self.world_size}"
model_config = ModelConfig(
model=self.model,
download_dir=self.download_dir,
@@ -366,25 +383,27 @@ def create_config(
dp_config=DataParallelConfig(
dp_degree=self.data_parallel_degree,
use_cfg_parallel=self.use_cfg_parallel,
world_size=self.world_size,
dit_parallel_size=self.dit_parallel_size,
),
sp_config=SequenceParallelConfig(
ulysses_degree=self.ulysses_degree,
ring_degree=self.ring_degree,
world_size=self.world_size,
dit_parallel_size=self.dit_parallel_size,
),
tp_config=TensorParallelConfig(
tp_degree=self.tensor_parallel_degree,
split_scheme=self.split_scheme,
world_size=self.world_size,
dit_parallel_size=self.dit_parallel_size,
),
pp_config=PipeFusionParallelConfig(
pp_degree=self.pipefusion_parallel_degree,
num_pipeline_patch=self.num_pipeline_patch,
attn_layer_num_for_pp=self.attn_layer_num_for_pp,
world_size=self.world_size,
dit_parallel_size=self.dit_parallel_size,
),
world_size=self.world_size,
dit_parallel_size=self.dit_parallel_size,
vae_parallel_size=self.vae_parallel_size,
)

fast_attn_config = FastAttnConfig(
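The new logic in `create_config` defaults `dit_parallel_size` to the full world size when no separate VAE process is requested, and otherwise requires the DiT and VAE groups to partition the Ray world size. A standalone restatement of that check with the values from `ray_run.sh` (sketch only, not repository code):

```python
# Standalone restatement of the world-size bookkeeping added to create_config.
world_size = 3            # ray_world_size / N_GPUS
use_parallel_vae = True   # --use_parallel_vae
vae_parallel_size = 1     # --vae_parallel_size
dit_parallel_size = 2     # --dit_parallel_size

# If no separate VAE process is requested and dit_parallel_size was left at 0,
# every rank belongs to the DiT group.
if dit_parallel_size == 0 and (not use_parallel_vae or vae_parallel_size == 0):
    dit_parallel_size = world_size

# With disaggregation, the two groups must sum to the world size.
assert dit_parallel_size + vae_parallel_size == world_size  # 2 + 1 == 3
```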
52 changes: 26 additions & 26 deletions xfuser/config/config.py
@@ -86,7 +86,7 @@ def __post_init__(self):
class DataParallelConfig:
dp_degree: int = 1
use_cfg_parallel: bool = False
world_size: int = 1
dit_parallel_size: int = 1

def __post_init__(self):
assert self.dp_degree >= 1, "dp_degree must greater than or equal to 1"
@@ -96,20 +96,20 @@ def __post_init__(self):
self.cfg_degree = 2
else:
self.cfg_degree = 1
assert self.dp_degree * self.cfg_degree <= self.world_size, (
assert self.dp_degree * self.cfg_degree <= self.dit_parallel_size, (
"dp_degree * cfg_degree must be less than or equal to "
"world_size because of classifier free guidance"
"dit_parallel_size because of classifier free guidance"
)
assert (
self.world_size % (self.dp_degree * self.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"
self.dit_parallel_size % (self.dp_degree * self.cfg_degree) == 0
), "dit_parallel_size must be divisible by dp_degree * cfg_degree"


@dataclass
class SequenceParallelConfig:
ulysses_degree: Optional[int] = None
ring_degree: Optional[int] = None
world_size: int = 1
dit_parallel_size: int = 1

def __post_init__(self):
if self.ulysses_degree is None:
@@ -140,29 +140,29 @@ def __post_init__(self):
class TensorParallelConfig:
tp_degree: int = 1
split_scheme: Optional[str] = "row"
world_size: int = 1
dit_parallel_size: int = 1

def __post_init__(self):
assert self.tp_degree >= 1, "tp_degree must greater than 1"
assert (
self.tp_degree <= self.world_size
), "tp_degree must be less than or equal to world_size"
self.tp_degree <= self.dit_parallel_size
), "tp_degree must be less than or equal to dit_parallel_size"


@dataclass
class PipeFusionParallelConfig:
pp_degree: int = 1
num_pipeline_patch: Optional[int] = None
attn_layer_num_for_pp: Optional[List[int]] = (None,)
world_size: int = 1
dit_parallel_size: int = 1

def __post_init__(self):
assert (
self.pp_degree is not None and self.pp_degree >= 1
), "pipefusion_degree must be set and greater than 1 to use pipefusion"
assert (
self.pp_degree <= self.world_size
), "pipefusion_degree must be less than or equal to world_size"
self.pp_degree <= self.dit_parallel_size
), "pipefusion_degree must be less than or equal to dit_parallel_size"
if self.num_pipeline_patch is None:
self.num_pipeline_patch = self.pp_degree
logger.info(
@@ -193,7 +193,8 @@ class ParallelConfig:
pp_config: PipeFusionParallelConfig
tp_config: TensorParallelConfig
world_size: int = 1 # FIXME: remove this
worker_cls: str = "xfuser.ray.worker.worker.Worker"
dit_parallel_size: int = 1
vae_parallel_size: int = 1 # 0 means the VAE runs in the same process as the diffusion model

def __post_init__(self):
assert self.tp_config is not None, "tp_config must be set"
@@ -207,23 +208,23 @@ def __post_init__(self):
* self.tp_config.tp_degree
* self.pp_config.pp_degree
)
world_size = self.world_size
assert parallel_world_size == world_size, (
dit_parallel_size = self.dit_parallel_size
assert parallel_world_size == dit_parallel_size, (
f"parallel_world_size {parallel_world_size} "
f"must be equal to world_size {self.world_size}"
f"must be equal to dit_parallel_size {self.dit_parallel_size}"
)
assert (
world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"
dit_parallel_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
), "dit_parallel_size must be divisible by dp_degree * cfg_degree"
assert (
world_size % self.pp_config.pp_degree == 0
), "world_size must be divisible by pp_degree"
dit_parallel_size % self.pp_config.pp_degree == 0
), "dit_parallel_size must be divisible by pp_degree"
assert (
world_size % self.sp_config.sp_degree == 0
), "world_size must be divisible by sp_degree"
dit_parallel_size % self.sp_config.sp_degree == 0
), "dit_parallel_size must be divisible by sp_degree"
assert (
world_size % self.tp_config.tp_degree == 0
), "world_size must be divisible by tp_degree"
dit_parallel_size % self.tp_config.tp_degree == 0
), "dit_parallel_size must be divisible by tp_degree"
self.dp_degree = self.dp_config.dp_degree
self.cfg_degree = self.dp_config.cfg_degree
self.sp_degree = self.sp_config.sp_degree
@@ -242,9 +243,8 @@ class EngineConfig:
fast_attn_config: FastAttnConfig

def __post_init__(self):
world_size = self.parallel_config.world_size
if self.fast_attn_config.use_fast_attn:
assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn"
assert self.parallel_config.dp_degree == self.parallel_config.dit_parallel_size, f"dit_parallel_size must be equal to dp_degree when using DiTFastAttn"

def to_dict(self):
"""Return the configs as a dictionary, for use in **kwargs."""
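After this change the product of the parallel degrees has to equal the DiT group size rather than the full world size, since the VAE ranks no longer participate in DiT parallelism. A numeric check with the `ray_run.sh` settings above (sketch only):

```python
# Numeric check of the ParallelConfig invariant after the change (sketch only).
dp_degree, cfg_degree = 1, 1          # no data/CFG parallelism in the example
ulysses_degree, ring_degree = 1, 1    # sequence-parallel degrees
tp_degree, pp_degree = 1, 2           # pipefusion_parallel_degree 2

sp_degree = ulysses_degree * ring_degree
dit_parallel_size = 2                 # DIT_PARALLEL_SIZE, not the world size of 3

parallel_world_size = dp_degree * cfg_degree * sp_degree * tp_degree * pp_degree
assert parallel_world_size == dit_parallel_size  # 1 * 1 * 1 * 1 * 2 == 2
```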
14 changes: 14 additions & 0 deletions xfuser/core/distributed/__init__.py
@@ -23,6 +23,13 @@
initialize_model_parallel,
model_parallel_is_initialized,
get_tensor_model_parallel_world_size,
get_vae_parallel_group,
get_vae_parallel_rank,
get_vae_parallel_world_size,
get_dit_world_size,
init_vae_group,
init_dit_group,
get_dit_group,
)
from .runtime_state import (
get_runtime_state,
@@ -58,4 +65,11 @@
"get_runtime_state",
"runtime_state_is_initialized",
"initialize_runtime_state",
"get_dit_world_size",
"get_vae_parallel_group",
"get_vae_parallel_rank",
"get_vae_parallel_world_size",
"init_vae_group",
"init_dit_group",
"get_dit_group",
]
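The new exports give callers outside the pipeline access to the disaggregated process groups. A hypothetical usage sketch; the zero-argument signatures are an assumption based on existing helpers such as `get_tensor_model_parallel_world_size` and are not shown in this diff:

```python
# Hypothetical usage of the newly exported helpers; the zero-argument signatures
# are assumed by analogy with the existing get_* helpers.
from xfuser.core.distributed import (
    get_dit_world_size,
    get_vae_parallel_rank,
    get_vae_parallel_world_size,
)

def describe_disaggregation():
    # Report the DiT/VAE group sizes visible to this process.
    print(f"DiT world size: {get_dit_world_size()}")
    print(f"VAE world size: {get_vae_parallel_world_size()}, "
          f"VAE rank: {get_vae_parallel_rank()}")
```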