39 changes: 26 additions & 13 deletions nemo_automodel/components/distributed/mesh_utils.py
@@ -170,15 +170,22 @@ def _create_fsdp2_device_mesh(
     if dp_replicate_size is None or dp_replicate_size <= 0:
         dp_replicate_size = 1
 
     # HSDP usecase: dp_size = dp_replicate_size * dp_shard_size
-    assert dp_size % dp_replicate_size == 0, "dp_size must be a multiple of dp_replicate_size"
-    assert dp_replicate_size < dp_size or dp_replicate_size == 1, (
-        "dp_replicate_size must be less than dp_size since ddp usecase is not supported by FSDP2"
-    )
+    if dp_size % dp_replicate_size != 0:
+        raise ValueError(
+            f"dp_size ({dp_size}) must be a multiple of dp_replicate_size ({dp_replicate_size})."
+        )
+    if dp_replicate_size > 1 and dp_replicate_size >= dp_size:
+        raise ValueError(
+            f"dp_replicate_size ({dp_replicate_size}) must be less than dp_size ({dp_size}); "
+            f"pure DDP replication is not supported by FSDP2."
+        )
 
     # Expert parallelism calculations
     dp_cp_size = dp_size * cp_size
-    assert dp_cp_size % ep_size == 0, f"{dp_cp_size=} must be a multiple of {ep_size=}"
+    if dp_cp_size % ep_size != 0:
+        raise ValueError(
+            f"(dp_size * cp_size) = {dp_size} * {cp_size} = {dp_cp_size} "
+            f"must be a multiple of ep_size ({ep_size})."
+        )
     if ep_size < dp_cp_size:
         ep_shard_size = dp_cp_size // ep_size
     else:
@@ -196,8 +203,10 @@ def _create_fsdp2_device_mesh(
         MeshAxisName.TP,
     )
     for shape, name in zip(mesh_shape, mesh_names):
-        assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}"
-        assert shape > 0, f"Expected {name} > 0, got {shape}"
+        if not isinstance(shape, int):
+            raise TypeError(f"Expected {name} to be an int, got {type(shape).__name__}")
+        if shape <= 0:
+            raise ValueError(f"Expected {name} > 0, got {shape}")
 
     device_mesh = init_device_mesh(
         device_type="cuda" if backend == "nccl" else "cpu",
@@ -282,8 +291,10 @@ def _create_megatron_fsdp_device_mesh(
     mesh_shape = (dp_size, cp_size, tp_size)
     mesh_names = (MeshAxisName.DP, MeshAxisName.CP, MeshAxisName.TP)
     for shape, name in zip(mesh_shape, mesh_names):
-        assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}"
-        assert shape > 0, f"Expected {name} > 0, got {shape}"
+        if not isinstance(shape, int):
+            raise TypeError(f"Expected {name} to be an int, got {type(shape).__name__}")
+        if shape <= 0:
+            raise ValueError(f"Expected {name} > 0, got {shape}")
 
     # Build mesh [dp, cp, tp]
     device_mesh = init_device_mesh(
@@ -323,8 +334,10 @@ def _create_moe_mesh(
     mesh_shape = (pp_size, ep_shard_size, ep_size)
     mesh_names = (MeshAxisName.PP, MeshAxisName.EP_SHARD, MeshAxisName.EP)
     for shape, name in zip(mesh_shape, mesh_names):
-        assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}"
-        assert shape > 0, f"Expected {name} > 0, got {shape}"
+        if not isinstance(shape, int):
+            raise TypeError(f"Expected {name} to be an int, got {type(shape).__name__}")
+        if shape <= 0:
+            raise ValueError(f"Expected {name} > 0, got {shape}")
 
     moe_mesh = init_device_mesh(
         device_type="cuda" if backend == "nccl" else "cpu",
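Note on the assert-to-raise conversion above: it matters beyond error-message quality, because CPython strips `assert` statements entirely when the interpreter runs with `-O`, so the old guards could silently pass invalid configurations. A minimal standalone sketch of the difference (hypothetical helper names, not code from this PR):

```python
# Standalone illustration: asserts are compiled away under `python -O`,
# while raised exceptions fire regardless of interpreter flags.


def check_with_assert(dp_size: int, dp_replicate_size: int) -> None:
    # Removed entirely in optimized mode (`python -O`).
    assert dp_size % dp_replicate_size == 0, "dp_size must be a multiple of dp_replicate_size"


def check_with_raise(dp_size: int, dp_replicate_size: int) -> None:
    # Always enforced, no matter how Python was invoked.
    if dp_size % dp_replicate_size != 0:
        raise ValueError(
            f"dp_size ({dp_size}) must be a multiple of dp_replicate_size ({dp_replicate_size})."
        )


if __name__ == "__main__":
    try:
        check_with_assert(8, 3)  # AssertionError normally; silently passes under -O
    except AssertionError as e:
        print(f"assert caught it: {e}")
    try:
        check_with_raise(8, 3)  # ValueError under any flags
    except ValueError as e:
        print(f"raise caught it: {e}")
```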
133 changes: 133 additions & 0 deletions nemo_automodel/recipes/_dist_setup.py
@@ -40,6 +40,129 @@
 }
 
 
+def validate_num_gpus(
+    *,
+    world_size: int,
+    tp_size: int,
+    pp_size: int,
+    cp_size: int,
+    ep_size: int,
+    dp_size: Optional[int],
+    dp_replicate_size: Optional[int],
+) -> None:
+    """Validate that parallelism dimensions are compatible with the number of GPUs.
+
+    This runs **before** device-mesh creation so that users see a single,
+    actionable error instead of a cryptic ``init_device_mesh`` crash.
+
+    Raises:
+        ValueError: With a message that explains the mismatch and suggests
+            concrete fixes.
+    """
+    tp = tp_size if tp_size and tp_size > 0 else 1
+    pp = pp_size if pp_size and pp_size > 0 else 1
+    cp = cp_size if cp_size and cp_size > 0 else 1
+    ep = ep_size if ep_size and ep_size > 0 else 1
+
+    if world_size <= 0:
+        raise ValueError(
+            f"num_gpus (world_size) must be a positive integer, got {world_size}.\n"
+            f" Set the WORLD_SIZE environment variable or use "
+            f"torchrun --nproc_per_node to configure the number of GPUs."
+        )
+
+    for name, val in [("tp_size", tp_size), ("pp_size", pp_size), ("cp_size", cp_size), ("ep_size", ep_size)]:
+        if val is not None and val < 0:
+            raise ValueError(f"{name} must be a non-negative integer, got {val}.")
+
+    # tp * pp * cp is the minimum number of GPUs required (dp >= 1 is implicit)
+    explicit_product = tp * pp * cp
+    if explicit_product > world_size:
+        raise ValueError(
+            f"Not enough GPUs: tp_size * pp_size * cp_size = "
+            f"{tp} * {pp} * {cp} = {explicit_product}, "
+            f"but only {world_size} GPU(s) available.\n"
+            f" The minimum number of GPUs required is "
+            f"tp_size * pp_size * cp_size = {explicit_product}.\n"
+            f" Either reduce your parallelism sizes or increase the number of GPUs."
+        )
+
+    if world_size % explicit_product != 0:
+        # Suggest the nearest valid GPU counts
+        lower = (world_size // explicit_product) * explicit_product
+        upper = lower + explicit_product
+        suggestions = [v for v in (lower, upper) if v > 0]
+        raise ValueError(
+            f"num_gpus ({world_size}) is not divisible by "
+            f"tp_size * pp_size * cp_size = {tp} * {pp} * {cp} = {explicit_product}.\n"
+            f" data-parallel degree (dp_size) is computed as: "
+            f"num_gpus / (tp_size * pp_size * cp_size), which must be a whole number.\n"
+            f" To fix, either:\n"
+            f" - change num_gpus to a multiple of {explicit_product} "
+            f"(nearest valid: {', '.join(map(str, suggestions))}), or\n"
+            f" - adjust tp_size, pp_size, or cp_size so their product divides {world_size}."
+        )
+
+    inferred_dp = world_size // explicit_product
+
+    # When dp_size is explicitly set, it must be consistent with world_size
+    if dp_size is not None and dp_size > 0:
+        expected_world = tp * pp * cp * dp_size
+        if expected_world != world_size:
+            raise ValueError(
+                f"Parallelism dimensions do not match the number of GPUs.\n"
+                f" tp_size * pp_size * cp_size * dp_size = "
+                f"{tp} * {pp} * {cp} * {dp_size} = {expected_world}, "
+                f"but num_gpus = {world_size}.\n"
+                f" To fix, either:\n"
+                f" - remove dp_size (set to null) to auto-infer it as {inferred_dp}, or\n"
+                f" - change num_gpus to {expected_world}, or\n"
+                f" - adjust dp_size to {inferred_dp}."
+            )
+        inferred_dp = dp_size
+
+    # HSDP: dp_replicate_size must evenly divide dp_size
+    if dp_replicate_size is not None and dp_replicate_size > 1:
+        if inferred_dp % dp_replicate_size != 0:
+            valid_values = [i for i in range(2, inferred_dp + 1) if inferred_dp % i == 0 and i < inferred_dp]
+            hint = f"valid dp_replicate_size values: {valid_values}" if valid_values else "increase dp_size first"
+            raise ValueError(
+                f"dp_replicate_size ({dp_replicate_size}) does not evenly divide "
+                f"dp_size ({inferred_dp}).\n"
+                f" For HSDP, dp_size must be a multiple of dp_replicate_size.\n"
+                f" To fix: {hint}."
+            )
+        if dp_replicate_size >= inferred_dp:
+            raise ValueError(
+                f"dp_replicate_size ({dp_replicate_size}) must be strictly less than "
+                f"dp_size ({inferred_dp}).\n"
+                f" Pure DDP replication is not supported with FSDP2; there must be "
+                f"at least 2 sharding groups.\n"
+                f" To fix: reduce dp_replicate_size or increase the number of GPUs."
+            )
+
+    # EP: (dp_size * cp_size) must be divisible by ep_size
+    if ep > 1:
+        dp_cp = inferred_dp * cp
+        if dp_cp < ep:
+            raise ValueError(
+                f"ep_size ({ep}) exceeds dp_size * cp_size = "
+                f"{inferred_dp} * {cp} = {dp_cp}.\n"
+                f" Expert-parallel degree cannot exceed the data-parallel * "
+                f"context-parallel degree.\n"
+                f" To fix: reduce ep_size to at most {dp_cp}, "
+                f"or increase the number of GPUs."
+            )
+        if dp_cp % ep != 0:
+            valid_ep = [i for i in range(2, dp_cp + 1) if dp_cp % i == 0]
+            raise ValueError(
+                f"(dp_size * cp_size) = {inferred_dp} * {cp} = {dp_cp} "
+                f"is not divisible by ep_size ({ep}).\n"
+                f" ep_size must evenly divide (dp_size * cp_size).\n"
+                f" Valid ep_size values for this configuration: {valid_ep}."
+            )
+
+
 def _validate_strategy_kwargs(
     strategy_name: str,
     strategy_cls: type,
@@ -161,6 +284,16 @@ def setup_distributed(cfg: Any, world_size: int) -> MeshContext:
     cfg_dict = cfg.distributed.to_dict() if not isinstance(cfg, dict) else cfg
     parsed = parse_distributed_section(cfg_dict)
 
+    validate_num_gpus(
+        world_size=world_size,
+        tp_size=parsed["tp_size"],
+        pp_size=parsed["pp_size"],
+        cp_size=parsed["cp_size"],
+        ep_size=parsed["ep_size"],
+        dp_size=parsed["dp_size"],
+        dp_replicate_size=parsed["dp_replicate_size"],
+    )
+
     device_mesh, moe_mesh = create_device_mesh(
         parsed["strategy_config"],
         dp_size=parsed["dp_size"],
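For reviewers who want to see the new validator end-to-end, here is a hedged usage sketch. The import path follows this diff's file location; the configurations themselves are invented for illustration:

```python
# Hypothetical usage of validate_num_gpus (the function added in this diff).
from nemo_automodel.recipes._dist_setup import validate_num_gpus

# Valid: 16 GPUs = tp(2) * pp(1) * cp(2) * dp(4), and ep(4) divides dp * cp = 8.
validate_num_gpus(
    world_size=16, tp_size=2, pp_size=1, cp_size=2,
    ep_size=4, dp_size=None, dp_replicate_size=None,
)

# Invalid: 16 is not divisible by tp * pp * cp = 2 * 1 * 3 = 6, so this raises
# ValueError and suggests the nearest valid GPU counts (12 or 18).
validate_num_gpus(
    world_size=16, tp_size=2, pp_size=1, cp_size=3,
    ep_size=1, dp_size=None, dp_replicate_size=None,
)
```

Because the check runs at the top of `setup_distributed`, a misconfigured job now fails with one of these messages before any process group or device mesh is created.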