From 9fc69883d150cff6247288e5870f289d5e8b6ed8 Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 2 Mar 2026 10:45:36 -0800 Subject: [PATCH 1/2] add num gpu validation Signed-off-by: Alexandros Koumparoulis --- .../components/distributed/mesh_utils.py | 39 +++-- nemo_automodel/recipes/_dist_setup.py | 133 +++++++++++++++++ tests/unit_tests/recipes/test_dist_setup.py | 141 +++++++++++++++++- 3 files changed, 299 insertions(+), 14 deletions(-) diff --git a/nemo_automodel/components/distributed/mesh_utils.py b/nemo_automodel/components/distributed/mesh_utils.py index 2e247e066..a4150c9b8 100644 --- a/nemo_automodel/components/distributed/mesh_utils.py +++ b/nemo_automodel/components/distributed/mesh_utils.py @@ -170,15 +170,22 @@ def _create_fsdp2_device_mesh( if dp_replicate_size is None or dp_replicate_size <= 0: dp_replicate_size = 1 - # HSDP usecase: dp_size = dp_replicate_size * dp_shard_size - assert dp_size % dp_replicate_size == 0, "dp_size must be a multiple of dp_replicate_size" - assert dp_replicate_size < dp_size or dp_replicate_size == 1, ( - "dp_replicate_size must be less than dp_size since ddp usecase is not supported by FSDP2" - ) + if dp_size % dp_replicate_size != 0: + raise ValueError( + f"dp_size ({dp_size}) must be a multiple of dp_replicate_size ({dp_replicate_size})." + ) + if dp_replicate_size > 1 and dp_replicate_size >= dp_size: + raise ValueError( + f"dp_replicate_size ({dp_replicate_size}) must be less than dp_size ({dp_size}); " + f"pure DDP replication is not supported by FSDP2." + ) - # Expert parallelism calculations dp_cp_size = dp_size * cp_size - assert dp_cp_size % ep_size == 0, f"{dp_cp_size=} must be a multiple of {ep_size=}" + if dp_cp_size % ep_size != 0: + raise ValueError( + f"(dp_size * cp_size) = {dp_size} * {cp_size} = {dp_cp_size} " + f"must be a multiple of ep_size ({ep_size})." 
+ ) if ep_size < dp_cp_size: ep_shard_size = dp_cp_size // ep_size else: @@ -196,8 +203,10 @@ def _create_fsdp2_device_mesh( MeshAxisName.TP, ) for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" + if not isinstance(shape, int): + raise TypeError(f"Expected {name} to be an int, got {type(shape).__name__}") + if shape <= 0: + raise ValueError(f"Expected {name} > 0, got {shape}") device_mesh = init_device_mesh( device_type="cuda" if backend == "nccl" else "cpu", @@ -282,8 +291,10 @@ def _create_megatron_fsdp_device_mesh( mesh_shape = (dp_size, cp_size, tp_size) mesh_names = (MeshAxisName.DP, MeshAxisName.CP, MeshAxisName.TP) for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" + if not isinstance(shape, int): + raise TypeError(f"Expected {name} to be an int, got {type(shape).__name__}") + if shape <= 0: + raise ValueError(f"Expected {name} > 0, got {shape}") # Build mesh [dp, cp, tp] device_mesh = init_device_mesh( @@ -323,8 +334,10 @@ def _create_moe_mesh( mesh_shape = (pp_size, ep_shard_size, ep_size) mesh_names = (MeshAxisName.PP, MeshAxisName.EP_SHARD, MeshAxisName.EP) for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" + if not isinstance(shape, int): + raise TypeError(f"Expected {name} to be an int, got {type(shape).__name__}") + if shape <= 0: + raise ValueError(f"Expected {name} > 0, got {shape}") moe_mesh = init_device_mesh( device_type="cuda" if backend == "nccl" else "cpu", diff --git a/nemo_automodel/recipes/_dist_setup.py b/nemo_automodel/recipes/_dist_setup.py index 68f97d954..01ef23d23 100644 --- a/nemo_automodel/recipes/_dist_setup.py +++ 
def _positive_or_one(size: Optional[int]) -> int:
    """Return ``size`` if it is a positive number, else 1 (``None``/0/negative mean "unset")."""
    return size if size and size > 0 else 1


def _divisors_from_two(value: int, stop: int) -> list:
    """Return every ``d`` in ``[2, stop)`` that evenly divides ``value``."""
    return [d for d in range(2, stop) if value % d == 0]


def _infer_dp_size(world_size: int, tp: int, pp: int, cp: int) -> int:
    """Check that ``tp * pp * cp`` fits into and divides ``world_size``; return the quotient."""
    # tp * pp * cp is the minimum number of GPUs required (dp >= 1 is implicit)
    explicit_product = tp * pp * cp
    if explicit_product > world_size:
        raise ValueError(
            f"Not enough GPUs: tp_size * pp_size * cp_size = "
            f"{tp} * {pp} * {cp} = {explicit_product}, "
            f"but only {world_size} GPU(s) available.\n"
            f" The minimum number of GPUs required is "
            f"tp_size * pp_size * cp_size = {explicit_product}.\n"
            f" Either reduce your parallelism sizes or increase the number of GPUs."
        )

    if world_size % explicit_product != 0:
        # Suggest the nearest valid GPU counts. Both are positive here because
        # explicit_product <= world_size was established just above.
        lower = (world_size // explicit_product) * explicit_product
        suggestions = (lower, lower + explicit_product)
        raise ValueError(
            f"num_gpus ({world_size}) is not divisible by "
            f"tp_size * pp_size * cp_size = {tp} * {pp} * {cp} = {explicit_product}.\n"
            f" data-parallel degree (dp_size) is computed as: "
            f"num_gpus / (tp_size * pp_size * cp_size), which must be a whole number.\n"
            f" To fix, either:\n"
            f" - change num_gpus to a multiple of {explicit_product} "
            f"(nearest valid: {', '.join(map(str, suggestions))}), or\n"
            f" - adjust tp_size, pp_size, or cp_size so their product divides {world_size}."
        )

    return world_size // explicit_product


def _check_explicit_dp_size(world_size: int, tp: int, pp: int, cp: int, dp_size: int, inferred_dp: int) -> None:
    """Check that an explicitly configured ``dp_size`` multiplies out to ``world_size``."""
    expected_world = tp * pp * cp * dp_size
    if expected_world != world_size:
        raise ValueError(
            f"Parallelism dimensions do not match the number of GPUs.\n"
            f" tp_size * pp_size * cp_size * dp_size = "
            f"{tp} * {pp} * {cp} * {dp_size} = {expected_world}, "
            f"but num_gpus = {world_size}.\n"
            f" To fix, either:\n"
            f" - remove dp_size (set to null) to auto-infer it as {inferred_dp}, or\n"
            f" - change num_gpus to {expected_world}, or\n"
            f" - adjust dp_size to {inferred_dp}."
        )


def _check_hsdp(dp_replicate_size: int, dp: int) -> None:
    """Check that ``dp_replicate_size`` divides ``dp`` and is strictly smaller than it."""
    if dp % dp_replicate_size != 0:
        # Only proper divisors are offered: dp itself is rejected by the check below.
        valid_values = _divisors_from_two(dp, dp)
        hint = f"valid dp_replicate_size values: {valid_values}" if valid_values else "increase dp_size first"
        raise ValueError(
            f"dp_replicate_size ({dp_replicate_size}) does not evenly divide "
            f"dp_size ({dp}).\n"
            f" For HSDP, dp_size must be a multiple of dp_replicate_size.\n"
            f" To fix: {hint}."
        )
    if dp_replicate_size >= dp:
        raise ValueError(
            f"dp_replicate_size ({dp_replicate_size}) must be strictly less than "
            f"dp_size ({dp}).\n"
            f" Pure DDP replication is not supported with FSDP2; there must be "
            f"at least 2 sharding groups.\n"
            f" To fix: reduce dp_replicate_size or increase the number of GPUs."
        )


def _check_expert_parallel(ep: int, dp: int, cp: int) -> None:
    """Check that ``ep`` does not exceed, and evenly divides, ``dp * cp``."""
    dp_cp = dp * cp
    if dp_cp < ep:
        raise ValueError(
            f"ep_size ({ep}) exceeds dp_size * cp_size = "
            f"{dp} * {cp} = {dp_cp}.\n"
            f" Expert-parallel degree cannot exceed the data-parallel * "
            f"context-parallel degree.\n"
            f" To fix: reduce ep_size to at most {dp_cp}, "
            f"or increase the number of GPUs."
        )
    if dp_cp % ep != 0:
        valid_ep = _divisors_from_two(dp_cp, dp_cp + 1)
        raise ValueError(
            f"(dp_size * cp_size) = {dp} * {cp} = {dp_cp} "
            f"is not divisible by ep_size ({ep}).\n"
            f" ep_size must evenly divide (dp_size * cp_size).\n"
            f" Valid ep_size values for this configuration: {valid_ep}."
        )


def validate_num_gpus(
    *,
    world_size: int,
    tp_size: Optional[int],
    pp_size: Optional[int],
    cp_size: Optional[int],
    ep_size: Optional[int],
    dp_size: Optional[int],
    dp_replicate_size: Optional[int],
) -> None:
    """Validate that parallelism dimensions are compatible with the number of GPUs.

    This runs **before** device-mesh creation so that users see a single,
    actionable error instead of a cryptic ``init_device_mesh`` crash.

    Checks are applied in this order, and the first failure is raised:

    1. ``world_size`` is positive.
    2. None of ``tp_size``/``pp_size``/``cp_size``/``ep_size`` is negative.
    3. ``tp * pp * cp`` is at most ``world_size`` and divides it.
    4. An explicit ``dp_size`` satisfies ``tp * pp * cp * dp_size == world_size``.
    5. ``dp_replicate_size`` (when > 1) divides the data-parallel degree and is
       strictly smaller than it.
    6. ``ep_size`` (when > 1) is at most ``dp * cp`` and divides it.

    Args:
        world_size: Total number of GPUs (ranks) in the job.
        tp_size: Tensor-parallel degree; ``None`` or 0 is treated as 1.
        pp_size: Pipeline-parallel degree; ``None`` or 0 is treated as 1.
        cp_size: Context-parallel degree; ``None`` or 0 is treated as 1.
        ep_size: Expert-parallel degree; ``None`` or 0 is treated as 1.
        dp_size: Data-parallel degree. ``None`` or a non-positive value means
            "infer it as ``world_size // (tp * pp * cp)``".
        dp_replicate_size: HSDP replication degree. Only validated when it is
            greater than 1.

    Raises:
        ValueError: With a message that explains the mismatch and suggests
            concrete fixes.
    """
    tp = _positive_or_one(tp_size)
    pp = _positive_or_one(pp_size)
    cp = _positive_or_one(cp_size)
    ep = _positive_or_one(ep_size)

    if world_size <= 0:
        raise ValueError(
            f"num_gpus (world_size) must be a positive integer, got {world_size}.\n"
            f" Set the WORLD_SIZE environment variable or use "
            f"torchrun --nproc_per_node to configure the number of GPUs."
        )

    # Negative values are rejected explicitly so they are not silently
    # normalised to 1 like None/0 are.
    raw_sizes = (("tp_size", tp_size), ("pp_size", pp_size), ("cp_size", cp_size), ("ep_size", ep_size))
    for name, val in raw_sizes:
        if val is not None and val < 0:
            raise ValueError(f"{name} must be a non-negative integer, got {val}.")

    dp = _infer_dp_size(world_size, tp, pp, cp)

    # When dp_size is explicitly set, it must be consistent with world_size.
    # A dp_size that passes this check necessarily equals the inferred value,
    # so ``dp`` needs no reassignment afterwards.
    if dp_size is not None and dp_size > 0:
        _check_explicit_dp_size(world_size, tp, pp, cp, dp_size, dp)

    # HSDP: dp_replicate_size must evenly divide dp_size
    if dp_replicate_size is not None and dp_replicate_size > 1:
        _check_hsdp(dp_replicate_size, dp)

    # EP: (dp_size * cp_size) must be divisible by ep_size
    if ep > 1:
        _check_expert_parallel(ep, dp, cp)
dp_replicate_size=None) + + def test_8gpu_tp2_pp2(self): + validate_num_gpus(world_size=8, tp_size=2, pp_size=2, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None) + + def test_8gpu_tp8(self): + validate_num_gpus(world_size=8, tp_size=8, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None) + + def test_16gpu_tp2_pp2_cp2(self): + validate_num_gpus( + world_size=16, tp_size=2, pp_size=2, cp_size=2, ep_size=1, dp_size=None, dp_replicate_size=None + ) + + def test_explicit_dp_size(self): + validate_num_gpus(world_size=8, tp_size=2, pp_size=1, cp_size=1, ep_size=1, dp_size=4, dp_replicate_size=None) + + def test_hsdp(self): + validate_num_gpus(world_size=16, tp_size=2, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=2) + + def test_ep(self): + validate_num_gpus(world_size=8, tp_size=1, pp_size=1, cp_size=1, ep_size=2, dp_size=None, dp_replicate_size=None) + + def test_ep_with_cp(self): + validate_num_gpus(world_size=8, tp_size=1, pp_size=1, cp_size=2, ep_size=4, dp_size=None, dp_replicate_size=None) + + def test_full_parallelism(self): + validate_num_gpus( + world_size=64, tp_size=2, pp_size=2, cp_size=2, ep_size=4, dp_size=None, dp_replicate_size=2 + ) + + def test_none_sizes_treated_as_1(self): + validate_num_gpus( + world_size=4, tp_size=None, pp_size=None, cp_size=None, ep_size=None, dp_size=None, dp_replicate_size=None + ) + + def test_zero_sizes_treated_as_1(self): + validate_num_gpus( + world_size=4, tp_size=0, pp_size=0, cp_size=0, ep_size=0, dp_size=None, dp_replicate_size=None + ) + + +# --------------------------------------------------------------------------- +# validate_num_gpus – error cases +# --------------------------------------------------------------------------- + + +class TestValidateNumGpusErrors: + def test_zero_world_size(self): + with pytest.raises(ValueError, match="must be a positive integer"): + validate_num_gpus( + world_size=0, tp_size=1, pp_size=1, cp_size=1, ep_size=1, dp_size=None, 
dp_replicate_size=None + ) + + def test_negative_world_size(self): + with pytest.raises(ValueError, match="must be a positive integer"): + validate_num_gpus( + world_size=-1, tp_size=1, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None + ) + + def test_tp_exceeds_gpus(self): + with pytest.raises(ValueError, match="Not enough GPUs"): + validate_num_gpus( + world_size=4, tp_size=8, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None + ) + + def test_product_exceeds_gpus(self): + with pytest.raises(ValueError, match="Not enough GPUs"): + validate_num_gpus( + world_size=4, tp_size=2, pp_size=2, cp_size=2, ep_size=1, dp_size=None, dp_replicate_size=None + ) + + def test_world_not_divisible(self): + with pytest.raises(ValueError, match="not divisible"): + validate_num_gpus( + world_size=7, tp_size=2, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None + ) + + def test_world_not_divisible_suggests_fix(self): + with pytest.raises(ValueError, match="nearest valid"): + validate_num_gpus( + world_size=5, tp_size=2, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None + ) + + def test_explicit_dp_mismatch(self): + with pytest.raises(ValueError, match="do not match"): + validate_num_gpus( + world_size=8, tp_size=2, pp_size=1, cp_size=1, ep_size=1, dp_size=2, dp_replicate_size=None + ) + + def test_explicit_dp_mismatch_suggests_auto(self): + with pytest.raises(ValueError, match="auto-infer"): + validate_num_gpus( + world_size=8, tp_size=2, pp_size=1, cp_size=1, ep_size=1, dp_size=2, dp_replicate_size=None + ) + + def test_hsdp_not_divisible(self): + with pytest.raises(ValueError, match="does not evenly divide"): + validate_num_gpus( + world_size=8, tp_size=1, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=3 + ) + + def test_hsdp_replicate_equals_dp(self): + with pytest.raises(ValueError, match="strictly less than"): + validate_num_gpus( + world_size=8, tp_size=1, pp_size=1, cp_size=1, 
ep_size=1, dp_size=None, dp_replicate_size=8 + ) + + def test_ep_exceeds_dp_cp(self): + with pytest.raises(ValueError, match="exceeds"): + validate_num_gpus( + world_size=4, tp_size=2, pp_size=1, cp_size=1, ep_size=4, dp_size=None, dp_replicate_size=None + ) + + def test_ep_not_divisible(self): + with pytest.raises(ValueError, match="not divisible by ep_size"): + validate_num_gpus( + world_size=12, tp_size=1, pp_size=1, cp_size=1, ep_size=5, dp_size=None, dp_replicate_size=None + ) + + def test_ep_not_divisible_suggests_valid(self): + with pytest.raises(ValueError, match="Valid ep_size"): + validate_num_gpus( + world_size=12, tp_size=1, pp_size=1, cp_size=1, ep_size=5, dp_size=None, dp_replicate_size=None + ) + + def test_negative_tp_size(self): + with pytest.raises(ValueError, match="non-negative"): + validate_num_gpus( + world_size=8, tp_size=-1, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None + ) From 06c4170fb836452c3e08161c76e8de39902f4cbd Mon Sep 17 00:00:00 2001 From: Alexandros Koumparoulis Date: Mon, 2 Mar 2026 10:47:19 -0800 Subject: [PATCH 2/2] add product tests Signed-off-by: Alexandros Koumparoulis --- tests/unit_tests/recipes/test_dist_setup.py | 35 +++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/tests/unit_tests/recipes/test_dist_setup.py b/tests/unit_tests/recipes/test_dist_setup.py index aa52f43b0..a806f4a73 100644 --- a/tests/unit_tests/recipes/test_dist_setup.py +++ b/tests/unit_tests/recipes/test_dist_setup.py @@ -493,3 +493,38 @@ def test_negative_tp_size(self): validate_num_gpus( world_size=8, tp_size=-1, pp_size=1, cp_size=1, ep_size=1, dp_size=None, dp_replicate_size=None ) + + # -- total-product-of-axes checks ---------------------------------------- + + def test_all_axes_product_exceeds_world_size(self): + """tp * pp * cp * dp = 2 * 2 * 2 * 4 = 32 ≠ 16 GPUs.""" + with pytest.raises(ValueError, match="do not match"): + validate_num_gpus( + world_size=16, tp_size=2, pp_size=2, cp_size=2, ep_size=1, 
dp_size=4, dp_replicate_size=None + ) + + def test_all_axes_product_less_than_world_size(self): + """tp * pp * cp * dp = 2 * 2 * 1 * 1 = 4 ≠ 8 GPUs.""" + with pytest.raises(ValueError, match="do not match"): + validate_num_gpus( + world_size=8, tp_size=2, pp_size=2, cp_size=1, ep_size=1, dp_size=1, dp_replicate_size=None + ) + + def test_all_axes_product_matches_world_size(self): + """tp * pp * cp * dp = 2 * 2 * 2 * 2 = 16 == 16 GPUs — should pass.""" + validate_num_gpus( + world_size=16, tp_size=2, pp_size=2, cp_size=2, ep_size=1, dp_size=2, dp_replicate_size=None + ) + + def test_all_axes_with_ep_product_matches(self): + """tp * pp * cp * dp = 2 * 2 * 2 * 4 = 32 GPUs, ep=4 divides dp*cp=8 — should pass.""" + validate_num_gpus( + world_size=32, tp_size=2, pp_size=2, cp_size=2, ep_size=4, dp_size=4, dp_replicate_size=None + ) + + def test_all_axes_with_ep_product_mismatch(self): + """tp * pp * cp * dp = 2 * 2 * 2 * 4 = 32 ≠ 64 GPUs.""" + with pytest.raises(ValueError, match="do not match"): + validate_num_gpus( + world_size=64, tp_size=2, pp_size=2, cp_size=2, ep_size=4, dp_size=4, dp_replicate_size=None + )