diff --git a/xfuser/core/distributed/parallel_state.py b/xfuser/core/distributed/parallel_state.py
index 67b434e1..81768406 100644
--- a/xfuser/core/distributed/parallel_state.py
+++ b/xfuser/core/distributed/parallel_state.py
@@ -8,14 +8,13 @@ import torch
 import torch.distributed
 
 import xfuser.envs as envs
-import os
 from xfuser.logger import init_logger
 from .group_coordinator import (
     GroupCoordinator,
     PipelineGroupCoordinator,
     SequenceParallelGroupCoordinator,
 )
-from .utils import RankGenerator, generate_masked_orthogonal_rank_groups
+from .utils import RankGenerator
 
 env_info = envs.PACKAGES_CHECKER.get_packages_info()
 HAS_LONG_CTX_ATTN = env_info["has_long_ctx_attn"]
@@ -281,7 +280,7 @@ def init_model_parallel_group(
         local_rank=local_rank,
         torch_distributed_backend=backend,
     )
-    
+
 def init_dit_group(
     dit_parallel_size: int,
     backend: str,
@@ -290,7 +289,7 @@ def init_dit_group(
     _DIT = torch.distributed.new_group(
         ranks=list(range(dit_parallel_size)), backend=backend
     )
-    
+
 def get_dit_group():
     assert _DIT is not None, "DIT group is not initialized"
     return _DIT
@@ -338,12 +337,12 @@ def initialize_model_parallel(
         dp_degree (2) * cfg_degree (2) * sp_degree (2) * pp_degree (2) = 16.
-    The present function will create 2 data parallel-groups,
+    The present function will create 8 data-parallel groups,
     8 CFG group, 8 pipeline-parallel group, and
     8 sequence-parallel groups:
-    2 data-parallel groups:
-        [g0, g1, g2, g3, g4, g5, g6, g7],
-        [g8, g9, g10, g11, g12, g13, g14, g15]
+    8 data-parallel groups:
+        [g0, g8], [g1, g9], [g2, g10], [g3, g11],
+        [g4, g12], [g5, g13], [g6, g14], [g7, g15]
     8 CFG-parallel groups:
         [g0, g4], [g1, g5], [g2, g6], [g3, g7],
         [g8, g12], [g9, g13], [g10, g14], [g11, g15]
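
As a sanity check on the corrected docstring, here is a minimal standalone sketch of the grouping arithmetic. It is my own illustration, not xfuser's actual `RankGenerator` implementation, and the dimension order (dp outermost, then cfg) is inferred from the group strides visible in the diff:

```python
def groups_along(dim_size: int, stride: int, world_size: int) -> list[list[int]]:
    """Collect ranks that differ only along one parallel dimension.

    For a dimension of size `dim_size` laid out with the given `stride`,
    each group is [base + k * stride for k in range(dim_size)], where
    `base` is any rank whose coordinate in that dimension is 0.
    """
    groups = []
    for base in range(world_size):
        if (base // stride) % dim_size == 0:  # base starts a group
            groups.append([base + k * stride for k in range(dim_size)])
    return groups


WORLD_SIZE = 16
DP, CFG = 2, 2  # dp_degree and cfg_degree from the docstring example

# dp varies slowest: stride = 16 // 2 = 8 -> [g0, g8], [g1, g9], ...
assert groups_along(DP, WORLD_SIZE // DP, WORLD_SIZE) == [
    [0, 8], [1, 9], [2, 10], [3, 11],
    [4, 12], [5, 13], [6, 14], [7, 15],
]

# cfg sits just inside dp: stride = 16 // (2 * 2) = 4 -> [g0, g4], ...
assert groups_along(CFG, WORLD_SIZE // (DP * CFG), WORLD_SIZE) == [
    [0, 4], [1, 5], [2, 6], [3, 7],
    [8, 12], [9, 13], [10, 14], [11, 15],
]
print("group lists match the corrected docstring")
```

Both assertions pass: splitting along the outermost axis yields the stride-8 pairs [g0, g8] through [g7, g15], matching the corrected docstring rather than the old contiguous blocks [g0 ... g7] and [g8 ... g15], which describe a single split of the world rather than per-rank data-parallel groups.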