From a2a1c9b9845c3d4f80e393097d296e4930fce40a Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Thu, 5 Mar 2026 10:20:40 -0700
Subject: [PATCH 1/3] [training] fix: validate
 hierarchical_context_parallel_sizes for a2a+p2p cp_comm_type
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Prevent silent training degradation when cp_comm_type='a2a+p2p' is used
without hierarchical_context_parallel_sizes. Without this validation,
context parallel communication is silently disabled — each CP rank
attends only to its local chunk, producing artificially high throughput
but broken training.

Signed-off-by: yaoyu-33
Made-with: Cursor
---
 3rdparty/Megatron-LM                       | 2 +-
 src/megatron/bridge/training/initialize.py | 8 ++++++++
 2 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM
index 23dd639cf3..26857798c7 160000
--- a/3rdparty/Megatron-LM
+++ b/3rdparty/Megatron-LM
@@ -1 +1 @@
-Subproject commit 23dd639cf3de30f3b9d8d0fae71ee31180be9ddd
+Subproject commit 26857798c716766ea8dda69875cb9373e7d0d1d6
diff --git a/src/megatron/bridge/training/initialize.py b/src/megatron/bridge/training/initialize.py
index 5229906b0e..eda8c05b5e 100644
--- a/src/megatron/bridge/training/initialize.py
+++ b/src/megatron/bridge/training/initialize.py
@@ -388,6 +388,14 @@ def _create_pg_collection(
     get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
 ) -> ProcessGroupCollection:
     """Create all process groups via HyperCommGrid and return a ProcessGroupCollection."""
+    hcp_sizes = getattr(model_config, "hierarchical_context_parallel_sizes", None)
+    if hcp_sizes is not None:
+        raise NotImplementedError(
+            "Decentralized process groups (use_decentralized_pg=True) do not support "
+            "hierarchical_context_parallel_sizes. Use cp_comm_type='a2a' or 'p2p' instead, "
+            "or set use_decentralized_pg=False to use the MPU path which supports 'a2a+p2p'."
+        )
+
     world_size = torch.distributed.get_world_size()
     tp_size = int(model_config.tensor_model_parallel_size)
     pp_size = int(model_config.pipeline_model_parallel_size)

From b32f868bda585afa718c42e25f80d4b1e387a3bf Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Thu, 5 Mar 2026 12:46:45 -0700
Subject: [PATCH 2/3] [training] fix: move HCP/cp_comm_type validation to
 bridge, revert mcore submodule

Revert the Megatron-LM submodule bump and move the
hierarchical_context_parallel_sizes / cp_comm_type validations into
ConfigContainer._validate_cp_comm_type() on the bridge side instead.

Signed-off-by: Yu Yao
Signed-off-by: yaoyu-33
Made-with: Cursor
---
 3rdparty/Megatron-LM                   |  2 +-
 src/megatron/bridge/training/config.py | 36 ++++++++++++++++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/3rdparty/Megatron-LM b/3rdparty/Megatron-LM
index 26857798c7..23dd639cf3 160000
--- a/3rdparty/Megatron-LM
+++ b/3rdparty/Megatron-LM
@@ -1 +1 @@
-Subproject commit 26857798c716766ea8dda69875cb9373e7d0d1d6
+Subproject commit 23dd639cf3de30f3b9d8d0fae71ee31180be9ddd
diff --git a/src/megatron/bridge/training/config.py b/src/megatron/bridge/training/config.py
index c6c57ca733..6b81cf750f 100644
--- a/src/megatron/bridge/training/config.py
+++ b/src/megatron/bridge/training/config.py
@@ -1630,6 +1630,8 @@ def validate(self) -> None:
                 "When finetuning with CP>1, average_in_collective must be False"
             )
 
+        self._validate_cp_comm_type()
+
         if (
             isinstance(self.dataset, FinetuningDatasetConfig)
             and self.dataset.packed_sequence_specs is not None
@@ -1688,6 +1690,40 @@ def validate(self) -> None:
                 )
                 setattr(self.validation, f.name, train_val)
 
+    def _validate_cp_comm_type(self) -> None:
+        """Validate cp_comm_type and hierarchical_context_parallel_sizes consistency."""
+        cp_comm_type = getattr(self.model, "cp_comm_type", None)
+        hcp_sizes = getattr(self.model, "hierarchical_context_parallel_sizes", None)
+        cp_size = getattr(self.model, "context_parallel_size", 1)
+
+        if cp_size > 1 and cp_comm_type is not None:
+            if isinstance(cp_comm_type, list):
+                assert len(cp_comm_type) == self.model.num_layers, (
+                    f"Length of cp_comm_type ({len(cp_comm_type)}) must equal num_layers ({self.model.num_layers})."
+                )
+            else:
+                assert isinstance(cp_comm_type, str), (
+                    f"cp_comm_type must be a str or list of str, got {type(cp_comm_type)}."
+                )
+
+        cp_comm_types = cp_comm_type if isinstance(cp_comm_type, list) else [cp_comm_type or "p2p"]
+        if any("a2a+p2p" in ct for ct in cp_comm_types):
+            assert hcp_sizes is not None, (
+                "hierarchical_context_parallel_sizes must be set when cp_comm_type "
+                "contains 'a2a+p2p'. Without it, CP communication is silently disabled "
+                "and each rank attends only to its local chunk, producing artificially "
+                "high throughput but broken training. Example: for cp=16 across 4 nodes "
+                "of 8 GPUs, set hierarchical_context_parallel_sizes=[8, 2]."
+            )
+
+        if hcp_sizes is not None:
+            from math import prod
+
+            assert prod(hcp_sizes) == cp_size, (
+                f"Product of hierarchical_context_parallel_sizes {hcp_sizes} "
+                f"(={prod(hcp_sizes)}) must equal context_parallel_size (={cp_size})."
+            )
+
 def _validate_training_scheduler_compatibility(self) -> None:
     """Cross-validation between training and scheduler configs."""
     has_train_samples = self.train.train_samples is not None

From 19f065a09798773ce8c03bf83e89ba8f10341d0f Mon Sep 17 00:00:00 2001
From: yaoyu-33
Date: Mon, 9 Mar 2026 19:31:42 -0700
Subject: [PATCH 3/3] [training, test] fix: set
 hierarchical_context_parallel_sizes=None in mock configs for
 test_decentralized_pg

MagicMock returns a truthy MagicMock instance for any unset attribute,
causing the new NotImplementedError guard in _create_pg_collection() to
fire for all tests that call the function directly without patching it.
Explicitly set hierarchical_context_parallel_sizes=None on every mock
model config that feeds into _create_pg_collection().

Signed-off-by: yaoyu-33
---
 tests/unit_tests/training/test_decentralized_pg.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/tests/unit_tests/training/test_decentralized_pg.py b/tests/unit_tests/training/test_decentralized_pg.py
index 26a00f6153..a739fedcb1 100644
--- a/tests/unit_tests/training/test_decentralized_pg.py
+++ b/tests/unit_tests/training/test_decentralized_pg.py
@@ -59,6 +59,7 @@ def mock_model_config(self):
         config.context_parallel_size = 1
         config.expert_tensor_parallel_size = None
         config.expert_model_parallel_size = 1
+        config.hierarchical_context_parallel_sizes = None
         return config
 
     @patch("megatron.bridge.training.initialize.HyperCommGrid")
@@ -692,6 +693,7 @@ def test_create_pg_collection_with_cp(self, mock_subgroups, mock_world_size, moc
         mock_model_config.context_parallel_size = 2  # CP=2
         mock_model_config.expert_tensor_parallel_size = None
         mock_model_config.expert_model_parallel_size = 1
+        mock_model_config.hierarchical_context_parallel_sizes = None
 
         # Setup mock
         mock_grid_instance = MagicMock()
@@ -726,6 +728,7 @@ def test_create_pg_collection_with_tp_cp_pp(self, mock_subgroups, mock_world_siz
         mock_model_config.context_parallel_size = 2
         mock_model_config.expert_tensor_parallel_size = None
         mock_model_config.expert_model_parallel_size = 1
+        mock_model_config.hierarchical_context_parallel_sizes = None
 
         # Setup mock
         mock_grid_instance = MagicMock()
@@ -757,6 +760,7 @@ def mock_model_config(self):
         config.context_parallel_size = 1
         config.expert_tensor_parallel_size = None
         config.expert_model_parallel_size = 1
+        config.hierarchical_context_parallel_sizes = None
         return config
 
     @patch("megatron.bridge.training.initialize.HyperCommGrid")