From 7d6a08d91026e1bd64b94316eef336027af4db7d Mon Sep 17 00:00:00 2001 From: JimmyZhang12 <67203904+JimmyZhang12@users.noreply.github.com> Date: Mon, 16 Sep 2024 12:46:59 -0400 Subject: [PATCH] [NeMo-UX] Add token drop callback and optimize mixtral configs (#10361) * add token drop plugin Signed-off-by: Jimmy Zhang * add checks Signed-off-by: Jimmy Zhang * add expert parallel configs Signed-off-by: Jimmy Zhang * Apply isort and black reformatting Signed-off-by: JimmyZhang12 * amend comment Signed-off-by: Jimmy Zhang * Apply isort and black reformatting Signed-off-by: JimmyZhang12 * add comm overlap Signed-off-by: Jimmy Zhang * fix rebase errors Signed-off-by: Jimmy Zhang * Apply isort and black reformatting Signed-off-by: JimmyZhang12 * fix typo Signed-off-by: Jimmy Zhang * add test configs Signed-off-by: Jimmy Zhang * fix Signed-off-by: Jimmy Zhang --------- Signed-off-by: Jimmy Zhang Signed-off-by: JimmyZhang12 Co-authored-by: Jimmy Zhang Co-authored-by: JimmyZhang12 Co-authored-by: Pablo Garay Co-authored-by: Alexandros Koumparoulis <153118171+akoumpa@users.noreply.github.com> Signed-off-by: Alexandros Koumparoulis --- nemo/collections/llm/gpt/model/mixtral.py | 3 +- nemo/collections/llm/recipes/mixtral_8x22b.py | 147 +++++++++++++++-- nemo/collections/llm/recipes/mixtral_8x3b.py | 148 ++++++++++++++++-- nemo/collections/llm/recipes/mixtral_8x7b.py | 107 +++++++++++-- .../pytorch/callbacks/moe_token_drop.py | 55 +++++++ .../llm/recipes/test_mixtral_8x22b.py | 118 ++++++++++++++ .../llm/recipes/test_mixtral_8x3b.py | 110 +++++++++++++ .../llm/recipes/test_mixtral_8x7b.py | 112 +++++++++++++ 8 files changed, 757 insertions(+), 43 deletions(-) create mode 100644 nemo/lightning/pytorch/callbacks/moe_token_drop.py create mode 100644 tests/collections/llm/recipes/test_mixtral_8x22b.py create mode 100644 tests/collections/llm/recipes/test_mixtral_8x3b.py create mode 100644 tests/collections/llm/recipes/test_mixtral_8x7b.py diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py index bc255ae8fb87..bb3dc0068ca3 100644 --- a/nemo/collections/llm/gpt/model/mixtral.py +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -57,11 +57,10 @@ class MixtralConfig(GPTConfig): # MoE num_moe_experts: int = 8 moe_aux_loss_coeff: float = 0.01 - moe_expert_capacity_factor: float = 1.0 - moe_pad_expert_input_to_capacity: bool = True moe_router_topk: int = 2 moe_router_pre_softmax: bool = True moe_token_dispatcher_type: str = "alltoall" + moe_router_load_balancing_type: str = 'aux_loss' init_method_std: float = 0.02 layernorm_epsilon: float = 1e-5 diff --git a/nemo/collections/llm/recipes/mixtral_8x22b.py b/nemo/collections/llm/recipes/mixtral_8x22b.py index 209a5926a008..b9e61d2c4a7d 100644 --- a/nemo/collections/llm/recipes/mixtral_8x22b.py +++ b/nemo/collections/llm/recipes/mixtral_8x22b.py @@ -14,6 +14,8 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin from nemo.collections.llm.utils import Config, Partial +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback from nemo.utils.exp_manager import TimingCallback NAME = "mixtral_8x22b" @@ -24,19 +26,51 @@ def model() -> Config[pl.LightningModule]: def trainer( - tensor_parallelism: int, - pipeline_parallelism: int, - pipeline_parallelism_type: 
Optional[torch.dtype], - virtual_pipeline_parallelism: Optional[int], - context_parallelism: int, - sequence_parallelism: bool, - expert_parallelism: int, - num_nodes: int = 1, + tensor_parallelism: int = 2, + pipeline_parallelism: int = 4, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 14, + context_parallelism: int = 2, + sequence_parallelism: bool = True, + expert_parallelism: int = 8, + num_nodes: int = 16, num_gpus_per_node: int = 8, max_steps: int = 1168251, - callbacks: Optional[list[Config[Callback]]] = None, -) -> Config[nl.Trainer]: - strategy = Config( + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Mixtral 8x22B model. + + This function sets up the distributed training strategy optimized for the large Mixtral 8x22B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + expert_parallelism (int): Degree of expert parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=mixtral_8x22b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=16, num_gpus_per_node=8) + >>> print(trainer_config) + + Note: + This configuration uses extensive parallelism to handle the large model size efficiently. + """ + strategy = run.Config( nl.MegatronStrategy, tensor_model_parallel_size=tensor_parallelism, pipeline_model_parallel_size=pipeline_parallelism, @@ -71,9 +105,34 @@ def trainer( def pretrain_recipe( - name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain + name: str, ckpt_dir: str, num_nodes: int = 16, num_gpus_per_node: int = 8, fn: Callable = pretrain ) -> Partial: - return Partial( + """ + Create a pre-training recipe for Mixtral 8x22B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. 
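Editor's note: for reference, the 8x22B trainer defaults above (TP=2, PP=4, CP=2, EP=8, sequence parallelism on, 16 nodes x 8 GPUs) can be sanity-checked with a small helper. The helper below is illustrative and not part of this patch; the rule that expert parallelism is carved out of the data-parallel group is an assumption about how Megatron-Core forms its process groups.

# Hypothetical helper (not part of this patch) to sanity-check the default
# parallelism layout used by mixtral_8x22b.trainer().
def check_moe_parallelism(tp, pp, cp, ep, num_nodes, gpus_per_node, sp=False):
    world = num_nodes * gpus_per_node
    model_parallel = tp * pp * cp
    assert world % model_parallel == 0, "world size must be divisible by TP * PP * CP"
    dp = world // model_parallel
    # Assumption: Megatron-Core builds expert-parallel groups inside the
    # data-parallel group, so DP must be divisible by EP.
    assert dp % ep == 0, "data-parallel size must be divisible by expert parallelism"
    if sp:
        assert tp > 1, "sequence parallelism only has an effect with TP > 1"
    return dp

# Defaults from mixtral_8x22b.trainer(): 16 nodes x 8 GPUs = 128 ranks,
# TP=2 * PP=4 * CP=2 = 16 ranks per model replica, leaving DP=8, which EP=8 divides evenly.
print(check_moe_parallelism(tp=2, pp=4, cp=2, ep=8, num_nodes=16, gpus_per_node=8, sp=True))  # -> 8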
+ + Examples: + CLI usage: + $ nemo llm pretrain --factory mixtral_8x22b + $ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=16, name='my_mixtral_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=16) + >>> print(recipe) + """ + return run.Partial( fn, model=model(), trainer=trainer( @@ -95,8 +154,66 @@ def pretrain_recipe( ) -def hf_resume() -> Config[nl.AutoResume]: - return Config( +@run.cli.factory(target=pretrain, name=NAME + "_performance") +def pretrain_recipe_performance( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Mixtral 8x22B model. + + This recipe enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory "mixtral_8x22b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8) + >>> print(recipe) + + Note: + Use this recipe with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn) + recipe.trainer.callbacks.extend( + [ + run.Config(MegatronTokenDropCallback), + run.Config(MegatronCommOverlapCallback), + ] + ) + + return recipe + + +def hf_resume() -> run.Config[nl.AutoResume]: + """ + Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x22B model. + + This function sets up the configuration to resume training from a pre-trained + Hugging Face model checkpoint. + + More info about the model can be found at: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1 + + Returns: + run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. + + Note: + This is particularly useful for fine-tuning scenarios where you want to + start from the pre-trained Mixtral 8x22B model. 
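Editor's note: a short usage sketch for the performance factory added above. The loop relies on the same run.Config introspection (__fn_or_cls__) used by the new tests in this patch; the expectation that the base recipe already carries a TimingCallback is an assumption about the unchanged part of pretrain_recipe.

# Sketch: build the performance-optimized recipe and list the callbacks it carries.
from nemo.collections.llm.recipes import mixtral_8x22b

recipe = mixtral_8x22b.pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8)
for cb in recipe.trainer.callbacks:
    # Expected to print TimingCallback (from the base recipe, assumption) plus
    # MegatronTokenDropCallback and MegatronCommOverlapCallback appended above.
    print(cb.__fn_or_cls__.__name__)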
+ """ + return run.Config( nl.AutoResume, restore_config=Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x22B-v0.1"), ) diff --git a/nemo/collections/llm/recipes/mixtral_8x3b.py b/nemo/collections/llm/recipes/mixtral_8x3b.py index 7dc8170e13e3..6cea2ab30725 100644 --- a/nemo/collections/llm/recipes/mixtral_8x3b.py +++ b/nemo/collections/llm/recipes/mixtral_8x3b.py @@ -14,6 +14,9 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin from nemo.collections.llm.utils import Config, Partial +from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback from nemo.utils.exp_manager import TimingCallback NAME = "mixtral_8x3b" @@ -24,19 +27,48 @@ def model() -> Config[pl.LightningModule]: def trainer( - tensor_parallelism: int, - pipeline_parallelism: int, - pipeline_parallelism_type: Optional[torch.dtype], - virtual_pipeline_parallelism: Optional[int], - context_parallelism: int, - sequence_parallelism: bool, - expert_parallelism: int, - num_nodes: int = 1, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 1, + pipeline_parallelism_type: Optional[torch.dtype] = None, + virtual_pipeline_parallelism: Optional[int] = None, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + expert_parallelism: int = 4, + num_nodes: int = 2, num_gpus_per_node: int = 8, max_steps: int = 1168251, - callbacks: Optional[list[Config[Callback]]] = None, -) -> Config[nl.Trainer]: - strategy = Config( + callbacks: Optional[list[run.Config[Callback]]] = None, +) -> run.Config[nl.Trainer]: + """ + Configure the NeMo Lightning Trainer for Mixtral 8x3B model. + + This function sets up the distributed training strategy optimized for the Mixtral 8x3B model. + + Args: + tensor_parallelism (int): Degree of tensor model parallelism. + pipeline_parallelism (int): Degree of pipeline model parallelism. + pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism. + virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism. + context_parallelism (int): Degree of context parallelism. + sequence_parallelism (bool): Whether to use sequence parallelism. + expert_parallelism (int): Degree of expert parallelism. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + max_steps (int): Maximum number of training steps. + callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations. + + Returns: + run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer. + + Examples: + CLI usage: + $ nemo llm pretrain trainer=mixtral_8x3b ... + + Python API usage: + >>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8) + >>> print(trainer_config) + """ + strategy = run.Config( nl.MegatronStrategy, tensor_model_parallel_size=tensor_parallelism, pipeline_model_parallel_size=pipeline_parallelism, @@ -71,9 +103,34 @@ def trainer( def pretrain_recipe( - name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain + name: str, ckpt_dir: str, num_nodes: int = 2, num_gpus_per_node: int = 8, fn: Callable = pretrain ) -> Partial: - return Partial( + """ + Create a pre-training recipe for Mixtral 8x3B model. 
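Editor's note: to show how the hf_resume() config above is meant to be consumed, here is a hedged sketch. The ckpt_dir value is a placeholder, and the attribute assignment mirrors the usage example given in the mixtral_8x3b hf_resume docstring later in this patch.

# Sketch: start pre-training from the Hugging Face Mixtral-8x22B checkpoint.
from nemo.collections.llm.recipes import mixtral_8x22b

recipe = mixtral_8x22b.pretrain_recipe(
    name="mixtral_8x22b_from_hf",
    ckpt_dir="/results/mixtral_8x22b",  # placeholder path
    num_nodes=16,
    num_gpus_per_node=8,
)
# AutoResume config restoring from hf://mistralai/Mixtral-8x22B-v0.1
recipe.resume = mixtral_8x22b.hf_resume()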
+ + This function sets up a complete configuration for pre-training, including + model, trainer, and data settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): Function to use for pre-training (default: nemo.collections.llm.api.pretrain). + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory mixtral_8x3b + $ nemo llm pretrain --factory "mixtral_8x3b(num_nodes=2, name='my_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="mixtral_8x3b_pretrain", num_nodes=2) + >>> print(recipe) + """ + return run.Partial( fn, model=model(), trainer=trainer( @@ -95,8 +152,69 @@ def pretrain_recipe( ) -def hf_resume() -> Config[nl.AutoResume]: - return Config( +@run.cli.factory(target=pretrain, name=NAME + "_performance") +def pretrain_recipe_performance( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Mixtral 8x3B model. + + This recipe enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=2, name='perf_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe_performance(name="mixtral_8x3b", num_nodes=4) + >>> print(recipe) + + Note: + Use this recipe with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn) + + recipe.trainer.callbacks.extend( + [ + run.Config(MegatronTokenDropCallback), + run.Config(MegatronCommOverlapCallback), + ] + ) + + return recipe + + +def hf_resume() -> run.Config[nl.AutoResume]: + """ + Configure the Hugging Face model resuming for Mixtral 8x3B model. + + This function sets up the configuration for resuming training from a Hugging Face model. + + Returns: + run.Config[nl.AutoResume]: Configuration for resuming from a Hugging Face model. 
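Editor's note: a hedged sketch of customizing the token-drop behaviour that pretrain_recipe_performance appends for the 8x3B recipe. Overriding attributes on the run.Config entries is assumed to work the same way as the attribute access used in the new recipe tests; the 1.25 value is only an example.

# Sketch: tweak the token-drop settings appended by pretrain_recipe_performance.
from nemo.collections.llm.recipes import mixtral_8x3b

recipe = mixtral_8x3b.pretrain_recipe_performance(name="mixtral_8x3b_perf", num_nodes=2)
for cb in recipe.trainer.callbacks:
    if cb.__fn_or_cls__.__name__ == "MegatronTokenDropCallback":
        cb.moe_expert_capacity_factor = 1.25        # allow 25% head-room per expert
        cb.moe_pad_expert_input_to_capacity = True  # keep fixed-size expert batches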
+ + Examples: + CLI usage: + $ nemo llm finetune --factory "mixtral_8x3b(resume=hf_resume())" + + Python API usage: + >>> recipe = finetune_recipe(name="mixtral_8x3b_finetune", num_nodes=2) + >>> recipe.resume = hf_resume() + >>> print(recipe) + """ + return run.Config( nl.AutoResume, restore_config=Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x7B-v0.1"), ) diff --git a/nemo/collections/llm/recipes/mixtral_8x7b.py b/nemo/collections/llm/recipes/mixtral_8x7b.py index bacbfcab4e2d..4b0e4f3dc5fd 100644 --- a/nemo/collections/llm/recipes/mixtral_8x7b.py +++ b/nemo/collections/llm/recipes/mixtral_8x7b.py @@ -14,6 +14,8 @@ from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin from nemo.collections.llm.utils import Config, Partial +from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback +from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback from nemo.utils.exp_manager import TimingCallback NAME = "mixtral_8x7b" @@ -24,14 +26,14 @@ def model() -> Config[pl.LightningModule]: def trainer( - tensor_parallelism: int, - pipeline_parallelism: int, - pipeline_parallelism_type: Optional[torch.dtype], - virtual_pipeline_parallelism: Optional[int], - context_parallelism: int, - sequence_parallelism: bool, - expert_parallelism: int, - num_nodes: int = 1, + tensor_parallelism: int = 1, + pipeline_parallelism: int = 4, + pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16, + virtual_pipeline_parallelism: Optional[int] = 8, + context_parallelism: int = 1, + sequence_parallelism: bool = False, + expert_parallelism: int = 8, + num_nodes: int = 8, num_gpus_per_node: int = 8, max_steps: int = 1168251, callbacks: Optional[list[Config[Callback]]] = None, @@ -73,7 +75,32 @@ def trainer( def pretrain_recipe( name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain ) -> Partial: - return Partial( + """ + Create a pre-training recipe for Mixtral 8x7B model. + + This function sets up a complete configuration for pre-training, including + model, trainer, data, logging, optimization, and resumption settings. + + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory mixtral_8x7b + $ nemo llm pretrain --factory "mixtral_8x7b(num_nodes=8, name='my_mixtral_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe(name="mixtral_8x7b_pretrain", num_nodes=8) + >>> print(recipe) + """ + return run.Partial( fn, model=model(), trainer=trainer( @@ -95,8 +122,66 @@ def pretrain_recipe( ) -def hf_resume() -> Config[nl.AutoResume]: - return Config( +@run.cli.factory(target=pretrain, name=NAME + "_performance") +def pretrain_recipe_performance( + dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain +) -> run.Partial: + """ + Create a performance-optimized pre-training recipe for Mixtral 8x7B model. + + This recipe enables performance optimizations that may not be suitable for all use cases. + It builds upon the standard pre-training recipe and adds additional performance enhancements. 
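Editor's note: the 8x7B defaults above pair pipeline_parallelism=4 with virtual_pipeline_parallelism=8; the quick arithmetic check below (illustrative only) shows why those numbers fit the 32-layer MixtralConfig8x7B that the new tests assert.

# Interleaved (virtual) pipeline check for the mixtral_8x7b defaults.
num_layers = 32   # MixtralConfig8x7B.num_layers, asserted in test_mixtral_8x7b.py
pp, vp = 4, 8     # trainer() defaults added in this patch
# Assumption: Megatron's interleaved schedule wants the layer count to split
# evenly into pp * vp chunks.
assert num_layers % (pp * vp) == 0
print(num_layers // (pp * vp))  # -> 1 layer per virtual chunk
# The 8x22b recipe follows the same pattern: 56 layers with PP=4, VP=14.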
+ + Args: + dir (Optional[str]): Directory for saving logs and checkpoints. + name (str): Name of the pre-training run. + num_nodes (int): Number of compute nodes to use. + num_gpus_per_node (int): Number of GPUs per node. + fn (Callable): The pre-training function to use. + + Returns: + run.Partial: Partial configuration for performance-optimized pre-training. + + Examples: + CLI usage: + $ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')" + + Python API usage: + >>> recipe = pretrain_recipe_performance(name="mixtral_8x3b_perf", num_nodes=8) + >>> print(recipe) + + Note: + Use this recipe with caution and only when you need maximum performance. + It may not be suitable for all hardware configurations or use cases. + """ + recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn) + recipe.trainer.callbacks.extend( + [ + run.Config(MegatronTokenDropCallback), + run.Config(MegatronCommOverlapCallback), + ] + ) + + return recipe + + +def hf_resume() -> run.Config[nl.AutoResume]: + """ + Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x7B model. + + This function sets up the configuration to resume training from a pre-trained + Hugging Face model checkpoint. + + More info about the model can be found at: https://huggingface.co/mistralai/Mixtral-8x7B-v0.1 + + Returns: + run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint. + + Note: + This is particularly useful for fine-tuning scenarios where you want to + start from the pre-trained Mixtral 8x7B model. + """ + return run.Config( nl.AutoResume, restore_config=Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x7B-v0.1"), ) diff --git a/nemo/lightning/pytorch/callbacks/moe_token_drop.py b/nemo/lightning/pytorch/callbacks/moe_token_drop.py new file mode 100644 index 000000000000..fc2aea84f3c1 --- /dev/null +++ b/nemo/lightning/pytorch/callbacks/moe_token_drop.py @@ -0,0 +1,55 @@ +import pytorch_lightning as pl +from megatron.core import ModelParallelConfig +from pytorch_lightning.callbacks.callback import Callback + +from nemo.lightning.pytorch.strategies.megatron_strategy import MegatronStrategy + + +class MegatronTokenDropCallback(Callback): + """ + A PyTorch Lightning callback to enable token drop for MOEs. Token drop improves performance by better + balancing work across experts, but may affect convergence. 
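Editor's note: ahead of the full class definition that follows, a hedged usage sketch. It wires the new callback into the 8x7B trainer config through run.Config, mirroring what pretrain_recipe_performance does, but with the capacity settings spelled out; the values shown are the callback's own defaults.

# Sketch: attach the token-drop callback to the standard 8x7b trainer config.
import nemo_run as run

from nemo.collections.llm.recipes import mixtral_8x7b
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback

trainer_cfg = mixtral_8x7b.trainer(
    callbacks=[
        run.Config(
            MegatronTokenDropCallback,
            moe_expert_capacity_factor=1.0,         # capacity relative to perfectly balanced routing
            moe_pad_expert_input_to_capacity=True,  # pad each expert's input up to that capacity
        )
    ],
)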
+ + Args: + moe_expert_capacity_factor (float): The capacity factor for all experts + moe_pad_expert_input_to_capacity (bool): Pad the input for each expert to the expert capacity length + + Example: >>> callback = MegatronTokenDropCallback() + >>> trainer = Trainer(callbacks=[callback]) + """ + + def __init__( + self, + moe_expert_capacity_factor: float = 1.0, + moe_pad_expert_input_to_capacity: bool = True, + ): + + if moe_expert_capacity_factor < 0: + moe_expert_capacity_factor = None + self.moe_expert_capacity_factor = moe_expert_capacity_factor + self.moe_pad_expert_input_to_capacity = moe_pad_expert_input_to_capacity + + def _set_cfgs(self, cfg): + cfg.moe_expert_capacity_factor = self.moe_expert_capacity_factor + cfg.moe_pad_expert_input_to_capacity = self.moe_pad_expert_input_to_capacity + + def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + assert isinstance(trainer.strategy, MegatronStrategy), "MegatronTokenDrop requires MegatronStrategy" + if hasattr(trainer.model, "config") and isinstance(trainer.model.config, ModelParallelConfig): + assert trainer.model.config.moe_token_dispatcher_type in [ + "alltoall", + "alltoall_seq", + ], 'moe_expert_capacity_factor only works with alltoall token dispatcher' + assert trainer.model.config.moe_router_load_balancing_type in [ + "aux_loss", + "none", + ], 'moe_expert_capacity_factor only works with aux_loss or none load balancing' + + if self.moe_pad_expert_input_to_capacity: + if self.moe_expert_capacity_factor is None: + raise ValueError('moe_expert_capacity_factor must be set to use moe_pad_expert_input_to_capacity') + + self._set_cfgs(trainer.model.config) + if hasattr(trainer.model, '__io__'): + self._set_cfgs(trainer.model.__io__.config) diff --git a/tests/collections/llm/recipes/test_mixtral_8x22b.py b/tests/collections/llm/recipes/test_mixtral_8x22b.py new file mode 100644 index 000000000000..3f855721e14f --- /dev/null +++ b/tests/collections/llm/recipes/test_mixtral_8x22b.py @@ -0,0 +1,118 @@ +import nemo_run as run +import pytest +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x22B, MixtralModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes import mixtral_8x22b +from nemo.lightning import AutoResume, Trainer + + +class TestMixtral8x22B: + @pytest.fixture(scope="class") + def recipe_module(self): + return mixtral_8x22b + + def test_model(self, recipe_module): + model_config = recipe_module.model() + assert isinstance(model_config, run.Config) + assert model_config.__fn_or_cls__ == MixtralModel + assert isinstance(model_config.config, run.Config) + assert model_config.config.__fn_or_cls__ == MixtralConfig8x22B + + def test_trainer(self, recipe_module): + trainer_config = recipe_module.trainer() + assert isinstance(trainer_config, run.Config) + assert trainer_config.__fn_or_cls__ == Trainer + assert trainer_config.accelerator == "gpu" + assert trainer_config.devices == 8 + assert trainer_config.num_nodes == 16 + + # Check strategy configuration + assert isinstance(trainer_config.strategy, run.Config) + assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy" + assert trainer_config.strategy.tensor_model_parallel_size == 2 + assert
trainer_config.strategy.pipeline_model_parallel_size == 4 + assert trainer_config.strategy.pipeline_dtype == torch.bfloat16 + assert trainer_config.strategy.virtual_pipeline_model_parallel_size == 14 + assert trainer_config.strategy.context_parallel_size == 2 + assert trainer_config.strategy.sequence_parallel is True + assert trainer_config.strategy.expert_model_parallel_size == 8 + + # Check DDP configuration + assert isinstance(trainer_config.strategy.ddp, run.Config) + assert trainer_config.strategy.ddp.__fn_or_cls__ == DistributedDataParallelConfig + assert trainer_config.strategy.ddp.check_for_nan_in_grad is True + assert trainer_config.strategy.ddp.grad_reduce_in_fp32 is True + + def test_pretrain_recipe(self, recipe_module): + recipe = recipe_module.pretrain_recipe() + assert isinstance(recipe, run.Partial) + assert recipe.__fn_or_cls__ == pretrain + assert isinstance(recipe.model, run.Config) + assert recipe.model.__fn_or_cls__ == MixtralModel + assert isinstance(recipe.trainer, run.Config) + assert recipe.trainer.__fn_or_cls__ == Trainer + assert isinstance(recipe.data, run.Config) + assert recipe.data.__fn_or_cls__ == MockDataModule + assert recipe.data.seq_length == 8192 + assert recipe.data.global_batch_size == 512 + assert recipe.data.micro_batch_size == 1 + + def test_finetune_recipe(self, recipe_module): + recipe = recipe_module.finetune_recipe() + assert isinstance(recipe, run.Partial) + assert recipe.__fn_or_cls__ == finetune + assert isinstance(recipe.model, run.Config) + assert recipe.model.__fn_or_cls__ == MixtralModel + assert isinstance(recipe.trainer, run.Config) + assert recipe.trainer.__fn_or_cls__ == Trainer + assert isinstance(recipe.data, run.Config) + assert recipe.data.__fn_or_cls__ == SquadDataModule + assert recipe.data.seq_length == 8192 + assert recipe.data.global_batch_size == 512 + assert recipe.data.micro_batch_size == 1 + assert isinstance(recipe.peft, run.Config) + assert recipe.peft.__fn_or_cls__ == LoRA + assert recipe.peft.target_modules == ['linear_qkv', 'linear_proj'] + assert recipe.peft.dim == 32 + + @pytest.mark.parametrize("num_nodes,num_gpus_per_node", [(8, 8), (16, 4), (32, 2)]) + def test_pretrain_recipe_with_different_configurations(self, recipe_module, num_nodes, num_gpus_per_node): + recipe = recipe_module.pretrain_recipe(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) + assert recipe.trainer.num_nodes == num_nodes + assert recipe.trainer.devices == num_gpus_per_node + + def test_hf_resume(self, recipe_module): + resume_config = recipe_module.hf_resume() + assert isinstance(resume_config, run.Config) + assert resume_config.__fn_or_cls__ == AutoResume + assert isinstance(resume_config.restore_config, run.Config) + assert resume_config.restore_config.path == "hf://mistralai/Mixtral-8x22B-v0.1" + + def test_trainer_parallelism_options(self, recipe_module): + trainer_config = recipe_module.trainer( + tensor_parallelism=4, + pipeline_parallelism=4, + context_parallelism=2, + sequence_parallelism=False, + expert_parallelism=2, + ) + assert trainer_config.strategy.tensor_model_parallel_size == 4 + assert trainer_config.strategy.pipeline_model_parallel_size == 4 + assert trainer_config.strategy.context_parallel_size == 2 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.expert_model_parallel_size == 2 + + def test_model_config_parameters(self, recipe_module): + model_config = recipe_module.model() + mixtral_config = model_config.config + assert mixtral_config.num_layers == 56 + assert 
mixtral_config.hidden_size == 6144 + assert mixtral_config.num_attention_heads == 48 + assert mixtral_config.seq_length == 4096 + assert mixtral_config.num_moe_experts == 8 diff --git a/tests/collections/llm/recipes/test_mixtral_8x3b.py b/tests/collections/llm/recipes/test_mixtral_8x3b.py new file mode 100644 index 000000000000..238fec74e0e1 --- /dev/null +++ b/tests/collections/llm/recipes/test_mixtral_8x3b.py @@ -0,0 +1,110 @@ +import nemo_run as run +import pytest + +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x3B, MixtralModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes import mixtral_8x3b +from nemo.lightning import AutoResume, Trainer + + +class TestMixtral8x3B: + @pytest.fixture(scope="class") + def recipe_module(self): + return mixtral_8x3b + + def test_model(self, recipe_module): + model_config = recipe_module.model() + assert isinstance(model_config, run.Config) + assert model_config.__fn_or_cls__ == MixtralModel + assert isinstance(model_config.config, run.Config) + assert model_config.config.__fn_or_cls__ == MixtralConfig8x3B + + def test_trainer(self, recipe_module): + trainer_config = recipe_module.trainer() + assert isinstance(trainer_config, run.Config) + assert trainer_config.__fn_or_cls__ == Trainer + assert trainer_config.accelerator == "gpu" + assert trainer_config.devices == 8 + assert trainer_config.num_nodes == 2 + + # Check strategy configuration + assert isinstance(trainer_config.strategy, run.Config) + assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy" + assert trainer_config.strategy.tensor_model_parallel_size == 1 + assert trainer_config.strategy.pipeline_model_parallel_size == 1 + assert trainer_config.strategy.pipeline_dtype is None + assert trainer_config.strategy.virtual_pipeline_model_parallel_size is None + assert trainer_config.strategy.context_parallel_size == 1 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.expert_model_parallel_size == 4 + + def test_pretrain_recipe(self, recipe_module): + recipe = recipe_module.pretrain_recipe() + assert isinstance(recipe, run.Partial) + assert recipe.__fn_or_cls__ == pretrain + assert isinstance(recipe.model, run.Config) + assert recipe.model.__fn_or_cls__ == MixtralModel + assert isinstance(recipe.trainer, run.Config) + assert recipe.trainer.__fn_or_cls__ == Trainer + assert isinstance(recipe.data, run.Config) + assert recipe.data.__fn_or_cls__ == MockDataModule + assert recipe.data.seq_length == 8192 + assert recipe.data.global_batch_size == 512 + assert recipe.data.micro_batch_size == 1 + + def test_finetune_recipe(self, recipe_module): + recipe = recipe_module.finetune_recipe() + assert isinstance(recipe, run.Partial) + assert recipe.__fn_or_cls__ == finetune + assert isinstance(recipe.model, run.Config) + assert recipe.model.__fn_or_cls__ == MixtralModel + assert isinstance(recipe.trainer, run.Config) + assert recipe.trainer.__fn_or_cls__ == Trainer + assert isinstance(recipe.data, run.Config) + assert recipe.data.__fn_or_cls__ == SquadDataModule + assert recipe.data.seq_length == 8192 + assert recipe.data.global_batch_size == 512 + assert recipe.data.micro_batch_size == 1 + assert isinstance(recipe.peft, run.Config) + assert recipe.peft.__fn_or_cls__ == LoRA + assert recipe.peft.target_modules == 
['linear_qkv', 'linear_proj'] + assert recipe.peft.dim == 32 + + @pytest.mark.parametrize("num_nodes,num_gpus_per_node", [(1, 8), (2, 4), (4, 2)]) + def test_pretrain_recipe_with_different_configurations(self, recipe_module, num_nodes, num_gpus_per_node): + recipe = recipe_module.pretrain_recipe(num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node) + assert recipe.trainer.num_nodes == num_nodes + assert recipe.trainer.devices == num_gpus_per_node + + def test_hf_resume(self, recipe_module): + resume_config = recipe_module.hf_resume() + assert isinstance(resume_config, run.Config) + assert resume_config.__fn_or_cls__ == AutoResume + assert isinstance(resume_config.restore_config, run.Config) + assert resume_config.restore_config.path == "hf://mistralai/Mixtral-8x7B-v0.1" + + def test_trainer_parallelism_options(self, recipe_module): + trainer_config = recipe_module.trainer( + tensor_parallelism=8, + pipeline_parallelism=2, + context_parallelism=4, + sequence_parallelism=False, + expert_parallelism=2, + ) + assert trainer_config.strategy.tensor_model_parallel_size == 8 + assert trainer_config.strategy.pipeline_model_parallel_size == 2 + assert trainer_config.strategy.context_parallel_size == 4 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.expert_model_parallel_size == 2 + + def test_model_config_parameters(self, recipe_module): + model_config = recipe_module.model() + mixtral_config = model_config.config + assert mixtral_config.num_layers == 32 + assert mixtral_config.hidden_size == 2560 + assert mixtral_config.num_attention_heads == 32 + assert mixtral_config.seq_length == 4096 + assert mixtral_config.num_moe_experts == 8 diff --git a/tests/collections/llm/recipes/test_mixtral_8x7b.py b/tests/collections/llm/recipes/test_mixtral_8x7b.py new file mode 100644 index 000000000000..75003891930d --- /dev/null +++ b/tests/collections/llm/recipes/test_mixtral_8x7b.py @@ -0,0 +1,112 @@ +import nemo_run as run +import pytest +import torch +from megatron.core.distributed import DistributedDataParallelConfig + +from nemo.collections.llm.api import finetune, pretrain +from nemo.collections.llm.gpt.data.mock import MockDataModule +from nemo.collections.llm.gpt.data.squad import SquadDataModule +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig8x7B, MixtralModel +from nemo.collections.llm.peft.lora import LoRA +from nemo.collections.llm.recipes import mixtral_8x7b +from nemo.lightning import AutoResume, Trainer + + +class TestMixtral8x7B: + @pytest.fixture(scope="class") + def recipe_module(self): + return mixtral_8x7b + + def test_model(self, recipe_module): + model_config = recipe_module.model() + assert isinstance(model_config, run.Config) + assert model_config.__fn_or_cls__ == MixtralModel + assert isinstance(model_config.config, run.Config) + assert model_config.config.__fn_or_cls__ == MixtralConfig8x7B + + def test_trainer(self, recipe_module): + trainer_config = recipe_module.trainer() + assert isinstance(trainer_config, run.Config) + assert trainer_config.__fn_or_cls__ == Trainer + assert trainer_config.accelerator == "gpu" + assert trainer_config.devices == 8 + assert trainer_config.num_nodes == 8 + + # Check strategy configuration + assert isinstance(trainer_config.strategy, run.Config) + assert trainer_config.strategy.__fn_or_cls__.__name__ == "MegatronStrategy" + assert trainer_config.strategy.tensor_model_parallel_size == 1 + assert trainer_config.strategy.pipeline_model_parallel_size == 4 + assert 
trainer_config.strategy.pipeline_dtype == torch.bfloat16 + assert trainer_config.strategy.virtual_pipeline_model_parallel_size == 8 + assert trainer_config.strategy.context_parallel_size == 1 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.expert_model_parallel_size == 8 + + # Check DDP configuration + assert isinstance(trainer_config.strategy.ddp, run.Config) + assert trainer_config.strategy.ddp.__fn_or_cls__ == DistributedDataParallelConfig + assert trainer_config.strategy.ddp.check_for_nan_in_grad is True + assert trainer_config.strategy.ddp.grad_reduce_in_fp32 is True + + def test_pretrain_recipe(self, recipe_module): + recipe = recipe_module.pretrain_recipe() + assert isinstance(recipe, run.Partial) + assert recipe.__fn_or_cls__ == pretrain + assert isinstance(recipe.model, run.Config) + assert recipe.model.__fn_or_cls__ == MixtralModel + assert isinstance(recipe.trainer, run.Config) + assert recipe.trainer.__fn_or_cls__ == Trainer + assert isinstance(recipe.data, run.Config) + assert recipe.data.__fn_or_cls__ == MockDataModule + assert recipe.data.seq_length == 8192 + assert recipe.data.global_batch_size == 512 + assert recipe.data.micro_batch_size == 1 + + def test_finetune_recipe(self, recipe_module): + recipe = recipe_module.finetune_recipe() + assert isinstance(recipe, run.Partial) + assert recipe.__fn_or_cls__ == finetune + assert isinstance(recipe.model, run.Config) + assert recipe.model.__fn_or_cls__ == MixtralModel + assert isinstance(recipe.trainer, run.Config) + assert recipe.trainer.__fn_or_cls__ == Trainer + assert isinstance(recipe.data, run.Config) + assert recipe.data.__fn_or_cls__ == SquadDataModule + assert recipe.data.seq_length == 8192 + assert recipe.data.global_batch_size == 512 + assert recipe.data.micro_batch_size == 1 + assert isinstance(recipe.peft, run.Config) + assert recipe.peft.__fn_or_cls__ == LoRA + assert recipe.peft.target_modules == ['linear_qkv', 'linear_proj'] + assert recipe.peft.dim == 32 + + def test_hf_resume(self, recipe_module): + resume_config = recipe_module.hf_resume() + assert isinstance(resume_config, run.Config) + assert resume_config.__fn_or_cls__ == AutoResume + assert isinstance(resume_config.restore_config, run.Config) + assert resume_config.restore_config.path == "hf://mistralai/Mixtral-8x7B-v0.1" + + def test_trainer_parallelism_options(self, recipe_module): + trainer_config = recipe_module.trainer( + tensor_parallelism=4, + pipeline_parallelism=4, + context_parallelism=2, + sequence_parallelism=False, + expert_parallelism=2, + ) + assert trainer_config.strategy.tensor_model_parallel_size == 4 + assert trainer_config.strategy.pipeline_model_parallel_size == 4 + assert trainer_config.strategy.context_parallel_size == 2 + assert trainer_config.strategy.sequence_parallel is False + assert trainer_config.strategy.expert_model_parallel_size == 2 + + def test_model_config_parameters(self, recipe_module): + model_config = recipe_module.model() + mixtral_config = model_config.config + assert mixtral_config.num_layers == 32 + assert mixtral_config.hidden_size == 4096 + assert mixtral_config.num_attention_heads == 32 + assert mixtral_config.seq_length == 4096 + assert mixtral_config.num_moe_experts == 8
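Editor's note: the new test modules above cover the base and fine-tuning recipes; a natural follow-up (illustrative only, not included in this patch) is a parametrized check that the performance factories actually append the two new callbacks.

# Hypothetical extra test, written in the same style as the new recipe tests.
import pytest

from nemo.collections.llm.recipes import mixtral_8x3b, mixtral_8x7b, mixtral_8x22b


@pytest.mark.parametrize("recipe_module", [mixtral_8x3b, mixtral_8x7b, mixtral_8x22b])
def test_performance_recipe_appends_moe_callbacks(recipe_module):
    recipe = recipe_module.pretrain_recipe_performance(name="perf", num_nodes=2)
    callback_names = [cb.__fn_or_cls__.__name__ for cb in recipe.trainer.callbacks]
    assert "MegatronTokenDropCallback" in callback_names
    assert "MegatronCommOverlapCallback" in callback_names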