diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py index c6c57ca733..6b81cf750f 100644 --- a/src/megatron/bridge/training/config.py +++ b/src/megatron/bridge/training/config.py @@ -1630,6 +1630,8 @@ def validate(self) -> None: "When finetuning with CP>1, average_in_collective must be False" ) + self._validate_cp_comm_type() + if ( isinstance(self.dataset, FinetuningDatasetConfig) and self.dataset.packed_sequence_specs is not None @@ -1688,6 +1690,40 @@ def validate(self) -> None: ) setattr(self.validation, f.name, train_val) + def _validate_cp_comm_type(self) -> None: + """Validate cp_comm_type and hierarchical_context_parallel_sizes consistency.""" + cp_comm_type = getattr(self.model, "cp_comm_type", None) + hcp_sizes = getattr(self.model, "hierarchical_context_parallel_sizes", None) + cp_size = getattr(self.model, "context_parallel_size", 1) + + if cp_size > 1 and cp_comm_type is not None: + if isinstance(cp_comm_type, list): + assert len(cp_comm_type) == self.model.num_layers, ( + f"Length of cp_comm_type ({len(cp_comm_type)}) must equal num_layers ({self.model.num_layers})." + ) + else: + assert isinstance(cp_comm_type, str), ( + f"cp_comm_type must be a str or list of str, got {type(cp_comm_type)}." + ) + + cp_comm_types = cp_comm_type if isinstance(cp_comm_type, list) else [cp_comm_type or "p2p"] + if any("a2a+p2p" in ct for ct in cp_comm_types): + assert hcp_sizes is not None, ( + "hierarchical_context_parallel_sizes must be set when cp_comm_type " + "contains 'a2a+p2p'. Without it, CP communication is silently disabled " + "and each rank attends only to its local chunk, producing artificially " + "high throughput but broken training. Example: for cp=16 across 2 nodes " + "of 8 GPUs, set hierarchical_context_parallel_sizes=[8, 2]." 
+ ) + + if hcp_sizes is not None: + from math import prod + + assert prod(hcp_sizes) == cp_size, ( + f"Product of hierarchical_context_parallel_sizes {hcp_sizes} " + f"(={prod(hcp_sizes)}) must equal context_parallel_size (={cp_size})." + ) + def _validate_training_scheduler_compatibility(self) -> None: """Cross-validation between training and scheduler configs.""" has_train_samples = self.train.train_samples is not None diff --git a/src/megatron/bridge/training/initialize.py b/src/megatron/bridge/training/initialize.py index 5229906b0e..eda8c05b5e 100644 --- a/src/megatron/bridge/training/initialize.py +++ b/src/megatron/bridge/training/initialize.py @@ -388,6 +388,14 @@ def _create_pg_collection( get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None, ) -> ProcessGroupCollection: """Create all process groups via HyperCommGrid and return a ProcessGroupCollection.""" + hcp_sizes = getattr(model_config, "hierarchical_context_parallel_sizes", None) + if hcp_sizes is not None: + raise NotImplementedError( + "Decentralized process groups (use_decentralized_pg=True) do not support " + "hierarchical_context_parallel_sizes. Use cp_comm_type='a2a' or 'p2p' instead, " + "or set use_decentralized_pg=False to use the MPU path which supports 'a2a+p2p'." 
+ ) + world_size = torch.distributed.get_world_size() tp_size = int(model_config.tensor_model_parallel_size) pp_size = int(model_config.pipeline_model_parallel_size) diff --git a/tests/unit_tests/training/test_decentralized_pg.py b/tests/unit_tests/training/test_decentralized_pg.py index 26a00f6153..a739fedcb1 100644 --- a/tests/unit_tests/training/test_decentralized_pg.py +++ b/tests/unit_tests/training/test_decentralized_pg.py @@ -59,6 +59,7 @@ def mock_model_config(self): config.context_parallel_size = 1 config.expert_tensor_parallel_size = None config.expert_model_parallel_size = 1 + config.hierarchical_context_parallel_sizes = None return config @patch("megatron.bridge.training.initialize.HyperCommGrid") @@ -692,6 +693,7 @@ def test_create_pg_collection_with_cp(self, mock_subgroups, mock_world_size, moc mock_model_config.context_parallel_size = 2 # CP=2 mock_model_config.expert_tensor_parallel_size = None mock_model_config.expert_model_parallel_size = 1 + mock_model_config.hierarchical_context_parallel_sizes = None # Setup mock mock_grid_instance = MagicMock() @@ -726,6 +728,7 @@ def test_create_pg_collection_with_tp_cp_pp(self, mock_subgroups, mock_world_siz mock_model_config.context_parallel_size = 2 mock_model_config.expert_tensor_parallel_size = None mock_model_config.expert_model_parallel_size = 1 + mock_model_config.hierarchical_context_parallel_sizes = None # Setup mock mock_grid_instance = MagicMock() @@ -757,6 +760,7 @@ def mock_model_config(self): config.context_parallel_size = 1 config.expert_tensor_parallel_size = None config.expert_model_parallel_size = 1 + config.hierarchical_context_parallel_sizes = None return config @patch("megatron.bridge.training.initialize.HyperCommGrid")