Commit dd38452: Separate VAE and DiT

lihuahua123 committed Jan 5, 2025 (1 parent: d6868c1)

Showing 15 changed files with 420 additions and 94 deletions.
6 changes: 3 additions & 3 deletions examples/ray/ray_flux_example.py
@@ -51,9 +51,9 @@ def main():
if not os.path.exists("results"):
os.mkdir("results")

- for i, result in enumerate(output):
- if result is not None:
- image = result.images[0]
+ for i, images in enumerate(output):
+ if images is not None:
+ image = images[0]
image.save(
f"./results/{model_name}_result_{i}.png"
)
9 changes: 5 additions & 4 deletions examples/ray/ray_run.sh
@@ -29,9 +29,9 @@ mkdir -p ./results
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


- N_GPUS=2
+ N_GPUS=3 # world size
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

+ VAE_PARALLEL_SIZE=1
# CFG_ARGS="--use_cfg_parallel"

# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
@@ -50,7 +50,7 @@ PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

# It is necessary to set CUDA_VISIBLE_DEVICES for the ray driver and workers.
- export CUDA_VISIBLE_DEVICES=0,1
+ export CUDA_VISIBLE_DEVICES=4,5,6,7

python ./examples/ray/$SCRIPT \
--model $MODEL_ID \
@@ -67,4 +67,5 @@ $CFG_ARGS \
$PARALLLEL_VAE \
$COMPILE_FLAG \
$QUANTIZE_FLAG \
- --use_parallel_vae \
+ --use_parallel_vae \
+ --vae_parallel_size $VAE_PARALLEL_SIZE
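The new values encode the rank split introduced by this commit: the DiT side needs pipefusion_parallel_degree x ulysses_degree x ring_degree ranks, and the VAE workers occupy extra ranks on top of that. A minimal sketch of the accounting behind N_GPUS=3, in plain Python and not part of the commit:

# Degrees as set in ray_run.sh above.
pipefusion_parallel_degree = 2
ulysses_degree = 1
ring_degree = 1
vae_parallel_size = 1  # VAE_PARALLEL_SIZE

dit_world_size = pipefusion_parallel_degree * ulysses_degree * ring_degree  # 2 DiT ranks
n_gpus = dit_world_size + vae_parallel_size  # 3 ranks in total, matching N_GPUS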
19 changes: 14 additions & 5 deletions xfuser/config/args.py
@@ -82,6 +82,7 @@ class xFuserArgs:
# ray arguments
use_ray: bool = False
ray_world_size: int = 1
+ vae_parallel_size: int = 0
# pipefusion parallel
pipefusion_parallel_degree: int = 1
num_pipeline_patch: Optional[int] = None
@@ -210,6 +211,12 @@ def add_cli_args(parser: FlexibleArgumentParser):
default=1,
help="Tensor parallel degree.",
)
+ parallel_group.add_argument(
+ "--vae_parallel_size",
+ type=int,
+ default=0,
+ help="VAE parallel size.",
+ )
parallel_group.add_argument(
"--split_scheme",
type=str,
@@ -345,7 +352,7 @@ def create_config(
self.world_size = self.ray_world_size
else:
self.world_size = torch.distributed.get_world_size()

+ self.dit_world_size = self.world_size - self.vae_parallel_size # FIXME: Lack of scalability
model_config = ModelConfig(
model=self.model,
download_dir=self.download_dir,
@@ -366,25 +373,27 @@
dp_config=DataParallelConfig(
dp_degree=self.data_parallel_degree,
use_cfg_parallel=self.use_cfg_parallel,
- world_size=self.world_size,
+ world_size=self.dit_world_size,
),
sp_config=SequenceParallelConfig(
ulysses_degree=self.ulysses_degree,
ring_degree=self.ring_degree,
- world_size=self.world_size,
+ world_size=self.dit_world_size,
),
tp_config=TensorParallelConfig(
tp_degree=self.tensor_parallel_degree,
split_scheme=self.split_scheme,
- world_size=self.world_size,
+ world_size=self.dit_world_size,
),
pp_config=PipeFusionParallelConfig(
pp_degree=self.pipefusion_parallel_degree,
num_pipeline_patch=self.num_pipeline_patch,
attn_layer_num_for_pp=self.attn_layer_num_for_pp,
- world_size=self.world_size,
+ world_size=self.dit_world_size,
),
world_size=self.world_size,
+ dit_world_size=self.dit_world_size,
+ vae_parallel_size=self.vae_parallel_size,
)

fast_attn_config = FastAttnConfig(
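The create_config change above is where the split actually happens: the DiT-side sub-configs are now sized against the world size minus the VAE ranks. A rough illustration with made-up numbers, not taken from the repository:

# Hypothetical 3-rank Ray launch with one dedicated VAE worker.
ray_world_size = 3
vae_parallel_size = 1

world_size = ray_world_size                      # total ranks, including VAE workers
dit_world_size = world_size - vae_parallel_size  # 2, passed as world_size to the dp/sp/tp/pp configs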
5 changes: 3 additions & 2 deletions xfuser/config/config.py
@@ -193,7 +193,8 @@ class ParallelConfig:
pp_config: PipeFusionParallelConfig
tp_config: TensorParallelConfig
world_size: int = 1 # FIXME: remove this
- worker_cls: str = "xfuser.ray.worker.worker.Worker"
+ dit_world_size: int = 1 # FIXME: remove this
+ vae_parallel_size: int = 1 # 0 means the vae is in the same process with diffusion

def __post_init__(self):
assert self.tp_config is not None, "tp_config must be set"
@@ -207,7 +208,7 @@ def __post_init__(self):
* self.tp_config.tp_degree
* self.pp_config.pp_degree
)
- world_size = self.world_size
+ world_size = self.dit_world_size
assert parallel_world_size == world_size, (
f"parallel_world_size {parallel_world_size} "
f"must be equal to world_size {self.world_size}"
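Because ParallelConfig now carries dit_world_size, the __post_init__ check compares the product of the parallel degrees against the DiT ranks only, so adding VAE workers no longer trips the assertion. A small sketch of that check with illustrative degrees; the factor list mirrors get_dit_world_size in parallel_state.py further below:

# Illustrative degrees: 2-stage PipeFusion, everything else 1.
dp_degree, cfg_degree, sp_degree, tp_degree, pp_degree = 1, 1, 1, 1, 2
dit_world_size = 2  # world_size (3) minus vae_parallel_size (1)

parallel_world_size = dp_degree * cfg_degree * sp_degree * tp_degree * pp_degree
assert parallel_world_size == dit_world_size  # checking against the full world size (3) would fail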
14 changes: 14 additions & 0 deletions xfuser/core/distributed/__init__.py
@@ -23,6 +23,13 @@
initialize_model_parallel,
model_parallel_is_initialized,
get_tensor_model_parallel_world_size,
+ get_vae_parallel_group,
+ get_vae_parallel_rank,
+ get_vae_parallel_world_size,
+ get_dit_world_size,
+ init_vae_group,
+ init_dit_group,
+ get_dit_group,
)
from .runtime_state import (
get_runtime_state,
@@ -58,4 +65,11 @@
"get_runtime_state",
"runtime_state_is_initialized",
"initialize_runtime_state",
+ "get_dit_world_size",
+ "get_vae_parallel_group",
+ "get_vae_parallel_rank",
+ "get_vae_parallel_world_size",
+ "init_vae_group",
+ "init_dit_group",
+ "get_dit_group",
]
5 changes: 5 additions & 0 deletions xfuser/core/distributed/group_coordinator.py
@@ -134,6 +134,11 @@ def __init__(
else:
self.device = torch.device("cpu")

+ @property
+ def size(self):
+ """Return the size of the process group (alias for world_size)"""
+ return self.world_size

@property
def first_rank(self):
"""Return the global rank of the first process in the group"""
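The new size property only mirrors world_size, presumably so callers can use the shorter, process-group-like spelling. A hypothetical usage sketch; get_world_group is defined in parallel_state.py, shown below:

from xfuser.core.distributed.parallel_state import get_world_group

group = get_world_group()
assert group.size == group.world_size  # same rank count, two spellings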
82 changes: 67 additions & 15 deletions xfuser/core/distributed/parallel_state.py
@@ -8,7 +8,7 @@
import torch
import torch.distributed
import xfuser.envs as envs

+ import os
from xfuser.logger import init_logger
from .group_coordinator import (
GroupCoordinator,
@@ -30,7 +30,8 @@
_PP: Optional[PipelineGroupCoordinator] = None
_CFG: Optional[GroupCoordinator] = None
_DP: Optional[GroupCoordinator] = None

+ _DIT: Optional[GroupCoordinator] = None
+ _VAE: Optional[GroupCoordinator] = None

# * QUERY
def get_world_group() -> GroupCoordinator:
@@ -155,6 +156,26 @@ def is_dp_last_group():
and get_pipeline_parallel_rank() == (get_pipeline_parallel_world_size() - 1)
)

+ def get_dit_world_size():
+ """Return world size for the DiT model (excluding VAE)."""
+ return (get_data_parallel_world_size() *
+ get_classifier_free_guidance_world_size() *
+ get_sequence_parallel_world_size() *
+ get_pipeline_parallel_world_size() *
+ get_tensor_model_parallel_world_size())

+ # Add VAE getter functions
+ def get_vae_parallel_group() -> GroupCoordinator:
+ assert _VAE is not None, "VAE parallel group is not initialized"
+ return _VAE

+ def get_vae_parallel_world_size():
+ """Return world size for the VAE parallel group."""
+ return get_vae_parallel_group().world_size

+ def get_vae_parallel_rank():
+ """Return my rank for the VAE parallel group."""
+ return get_vae_parallel_group().rank_in_group

# * SET

@@ -215,7 +236,6 @@ def init_distributed_environment(
_WORLD.world_size == torch.distributed.get_world_size()
), "world group already initialized with a different world size"


def model_parallel_is_initialized():
"""Check if tensor and pipeline parallel groups are initialized."""
return (
@@ -260,7 +280,32 @@ def init_model_parallel_group(
local_rank=local_rank,
torch_distributed_backend=backend,
)


+ def init_dit_group(
+ dit_world_size: int,
+ backend: str,
+ ):
+ global _DIT
+ _DIT = torch.distributed.new_group(
+ ranks=list(range(dit_world_size)), backend=backend
+ )

+ def get_dit_group():
+ assert _DIT is not None, "DIT group is not initialized"
+ return _DIT

+ def init_vae_group(
+ dit_world_size: int,
+ vae_parallel_size: int,
+ backend: str,
+ ):
+ # Initialize VAE group first
+ global _VAE
+ assert _VAE is None, "VAE parallel group is already initialized"
+ vae_ranks = list(range(dit_world_size, dit_world_size + vae_parallel_size))
+ _VAE = torch.distributed.new_group(
+ ranks=vae_ranks, backend=backend
+ )

def initialize_model_parallel(
data_parallel_degree: int = 1,
@@ -270,6 +315,7 @@ def initialize_model_parallel(
ring_degree: int = 1,
tensor_parallel_degree: int = 1,
pipeline_parallel_degree: int = 1,
+ vae_parallel_size: int = 0,
backend: Optional[str] = None,
) -> None:
"""
@@ -315,17 +361,15 @@
assert torch.distributed.is_initialized()
world_size: int = torch.distributed.get_world_size()
backend = backend or torch.distributed.get_backend(get_world_group().device_group)
+ dit_world_size = (data_parallel_degree *
+ classifier_free_guidance_degree *
+ sequence_parallel_degree *
+ pipeline_parallel_degree *
+ tensor_parallel_degree)

- if (
- world_size
- != data_parallel_degree
- * classifier_free_guidance_degree
- * sequence_parallel_degree
- * tensor_parallel_degree
- * pipeline_parallel_degree
- ):
+ if world_size < dit_world_size:
raise RuntimeError(
- f"world_size ({world_size}) is not equal to "
+ f"world_size ({world_size}) is less than "
f"tensor_parallel_degree ({tensor_parallel_degree}) x "
f"pipeline_parallel_degree ({pipeline_parallel_degree}) x"
f"sequence_parallel_degree ({sequence_parallel_degree}) x"
@@ -344,7 +388,6 @@
)
global _DP
assert _DP is None, "data parallel group is already initialized"

_DP = init_model_parallel_group(
group_ranks=rank_generator.get_ranks("dp"),
local_rank=get_world_group().local_rank,
@@ -382,8 +425,9 @@
sp_ulysses_degree=ulysses_degree,
sp_ring_degree=ring_degree,
rank=get_world_group().rank_in_group,
- world_size=get_world_group().world_size,
+ world_size=dit_world_size
)

_SP = init_model_parallel_group(
group_ranks=rank_generator.get_ranks("sp"),
local_rank=get_world_group().local_rank,
@@ -409,6 +453,9 @@
parallel_mode="tensor",
)

+ if vae_parallel_size > 0:
+ init_vae_group(dit_world_size, vae_parallel_size, backend)
+ init_dit_group(dit_world_size, backend)

def destroy_model_parallel():
"""Set the groups to none and destroy them."""
@@ -437,6 +484,11 @@ def destroy_model_parallel():
_PP.destroy()
_PP = None

+ global _VAE
+ if _VAE:
+ _VAE.destroy()
+ _VAE = None


def destroy_distributed_environment():
global _WORLD
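Taken together, these changes carve the global ranks into a leading DiT block (ranks 0 through dit_world_size - 1) and a trailing VAE block, each with its own torch.distributed group. A hedged sketch of how a worker process might use the new getters after initialize_model_parallel(..., vae_parallel_size=...) has run; the surrounding distributed setup is assumed, and this usage is illustrative rather than taken from the repository:

import torch.distributed as dist

from xfuser.core.distributed import get_dit_world_size, get_vae_parallel_group

rank = dist.get_rank()
if rank < get_dit_world_size():
    # DiT block: this rank joins the dp/cfg/sp/pp/tp groups as before.
    pass
else:
    # VAE block: this rank communicates over the dedicated VAE process group.
    vae_group = get_vae_parallel_group()
    vae_rank = dist.get_rank(group=vae_group)        # rank within the VAE group
    vae_size = dist.get_world_size(group=vae_group)  # equals vae_parallel_size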
1 change: 1 addition & 0 deletions xfuser/core/distributed/runtime_state.py
@@ -71,6 +71,7 @@ def _check_distributed_env(
ring_degree=parallel_config.ring_degree,
tensor_parallel_degree=parallel_config.tp_degree,
pipeline_parallel_degree=parallel_config.pp_degree,
+ vae_parallel_size=parallel_config.vae_parallel_size,
)

def destory_distributed_env(self):
(Diffs for the remaining 7 changed files did not load and are not shown.)