Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions src/megatron/bridge/training/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1630,6 +1630,8 @@ def validate(self) -> None:
"When finetuning with CP>1, average_in_collective must be False"
)

self._validate_cp_comm_type()

if (
isinstance(self.dataset, FinetuningDatasetConfig)
and self.dataset.packed_sequence_specs is not None
Expand Down Expand Up @@ -1688,6 +1690,40 @@ def validate(self) -> None:
)
setattr(self.validation, f.name, train_val)

def _validate_cp_comm_type(self) -> None:
"""Validate cp_comm_type and hierarchical_context_parallel_sizes consistency."""
cp_comm_type = getattr(self.model, "cp_comm_type", None)
hcp_sizes = getattr(self.model, "hierarchical_context_parallel_sizes", None)
cp_size = getattr(self.model, "context_parallel_size", 1)

if cp_size > 1 and cp_comm_type is not None:
if isinstance(cp_comm_type, list):
assert len(cp_comm_type) == self.model.num_layers, (
f"Length of cp_comm_type ({len(cp_comm_type)}) must equal num_layers ({self.model.num_layers})."
)
else:
assert isinstance(cp_comm_type, str), (
f"cp_comm_type must be a str or list of str, got {type(cp_comm_type)}."
)

cp_comm_types = cp_comm_type if isinstance(cp_comm_type, list) else [cp_comm_type or "p2p"]
if any("a2a+p2p" in ct for ct in cp_comm_types):
assert hcp_sizes is not None, (
"hierarchical_context_parallel_sizes must be set when cp_comm_type "
"contains 'a2a+p2p'. Without it, CP communication is silently disabled "
"and each rank attends only to its local chunk, producing artificially "
"high throughput but broken training. Example: for cp=16 across 4 nodes "
"of 8 GPUs, set hierarchical_context_parallel_sizes=[8, 2]."
)

if hcp_sizes is not None:
from math import prod

assert prod(hcp_sizes) == cp_size, (
f"Product of hierarchical_context_parallel_sizes {hcp_sizes} "
f"(={prod(hcp_sizes)}) must equal context_parallel_size (={cp_size})."
)

def _validate_training_scheduler_compatibility(self) -> None:
"""Cross-validation between training and scheduler configs."""
has_train_samples = self.train.train_samples is not None
Expand Down
8 changes: 8 additions & 0 deletions src/megatron/bridge/training/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,14 @@ def _create_pg_collection(
get_position_embedding_ranks: Optional[Callable[[list[int], Optional[int]], list[int]]] = None,
) -> ProcessGroupCollection:
"""Create all process groups via HyperCommGrid and return a ProcessGroupCollection."""
hcp_sizes = getattr(model_config, "hierarchical_context_parallel_sizes", None)
if hcp_sizes is not None:
raise NotImplementedError(
"Decentralized process groups (use_decentralized_pg=True) do not support "
"hierarchical_context_parallel_sizes. Use cp_comm_type='a2a' or 'p2p' instead, "
"or set use_decentralized_pg=False to use the MPU path which supports 'a2a+p2p'."
)

world_size = torch.distributed.get_world_size()
tp_size = int(model_config.tensor_model_parallel_size)
pp_size = int(model_config.pipeline_model_parallel_size)
Expand Down
4 changes: 4 additions & 0 deletions tests/unit_tests/training/test_decentralized_pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def mock_model_config(self):
config.context_parallel_size = 1
config.expert_tensor_parallel_size = None
config.expert_model_parallel_size = 1
config.hierarchical_context_parallel_sizes = None
return config

@patch("megatron.bridge.training.initialize.HyperCommGrid")
Expand Down Expand Up @@ -692,6 +693,7 @@ def test_create_pg_collection_with_cp(self, mock_subgroups, mock_world_size, moc
mock_model_config.context_parallel_size = 2 # CP=2
mock_model_config.expert_tensor_parallel_size = None
mock_model_config.expert_model_parallel_size = 1
mock_model_config.hierarchical_context_parallel_sizes = None

# Setup mock
mock_grid_instance = MagicMock()
Expand Down Expand Up @@ -726,6 +728,7 @@ def test_create_pg_collection_with_tp_cp_pp(self, mock_subgroups, mock_world_siz
mock_model_config.context_parallel_size = 2
mock_model_config.expert_tensor_parallel_size = None
mock_model_config.expert_model_parallel_size = 1
mock_model_config.hierarchical_context_parallel_sizes = None

# Setup mock
mock_grid_instance = MagicMock()
Expand Down Expand Up @@ -757,6 +760,7 @@ def mock_model_config(self):
config.context_parallel_size = 1
config.expert_tensor_parallel_size = None
config.expert_model_parallel_size = 1
config.hierarchical_context_parallel_sizes = None
return config

@patch("megatron.bridge.training.initialize.HyperCommGrid")
Expand Down
Loading