[NeMo-UX] Add token drop callback and optimize mixtral configs (#10361)
* add token drop plugin

Signed-off-by: Jimmy Zhang <[email protected]>

* add checks

Signed-off-by: Jimmy Zhang <[email protected]>

* add expert parallel configs

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* amend comment

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* add comm overlap

Signed-off-by: Jimmy Zhang <[email protected]>

* fix rebase errors

Signed-off-by: Jimmy Zhang <[email protected]>

* Apply isort and black reformatting

Signed-off-by: JimmyZhang12 <[email protected]>

* fix typo

Signed-off-by: Jimmy Zhang <[email protected]>

* add test configs

Signed-off-by: Jimmy Zhang <[email protected]>

* fix

Signed-off-by: Jimmy Zhang <[email protected]>

---------

Signed-off-by: Jimmy Zhang <[email protected]>
Signed-off-by: JimmyZhang12 <[email protected]>
Co-authored-by: Jimmy Zhang <[email protected]>
Co-authored-by: JimmyZhang12 <[email protected]>
Co-authored-by: Pablo Garay <[email protected]>
Co-authored-by: Alexandros Koumparoulis <[email protected]>
5 people committed Sep 18, 2024
1 parent 6f615d3 commit 8815b16
Showing 8 changed files with 757 additions and 43 deletions.
3 changes: 1 addition & 2 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -57,11 +57,10 @@ class MixtralConfig(GPTConfig):
# MoE
num_moe_experts: int = 8
moe_aux_loss_coeff: float = 0.01
moe_expert_capacity_factor: float = 1.0
moe_pad_expert_input_to_capacity: bool = True
moe_router_topk: int = 2
moe_router_pre_softmax: bool = True
moe_token_dispatcher_type: str = "alltoall"
moe_router_load_balancing_type: str = 'aux_loss'

init_method_std: float = 0.02
layernorm_epsilon: float = 1e-5
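
A minimal sketch (not part of this diff) of how token dropping is presumably enabled after this change: the two capacity-related fields dropped from MixtralConfig are now expected to be supplied by the new MegatronTokenDropCallback instead of being hard-coded in the model config. The parameter names below are assumptions inferred from the removed fields and should be checked against nemo/lightning/pytorch/callbacks/moe_token_drop.py.

# Sketch only: enable MoE token dropping through the new callback rather than
# hard-coding capacity settings in MixtralConfig.
# The parameter names are assumptions inferred from the removed config fields.
import nemo_run as run
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback

token_drop = run.Config(
    MegatronTokenDropCallback,
    moe_expert_capacity_factor=1.0,         # assumed parameter; previously a MixtralConfig field
    moe_pad_expert_input_to_capacity=True,  # assumed parameter; previously a MixtralConfig field
)
# The performance recipes below append this callback (with its defaults) to trainer.callbacks.
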
147 changes: 132 additions & 15 deletions nemo/collections/llm/recipes/mixtral_8x22b.py
@@ -14,6 +14,8 @@
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin
from nemo.collections.llm.utils import Config, Partial
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x22b"
@@ -24,19 +26,51 @@ def model() -> Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int,
pipeline_parallelism: int,
pipeline_parallelism_type: Optional[torch.dtype],
virtual_pipeline_parallelism: Optional[int],
context_parallelism: int,
sequence_parallelism: bool,
expert_parallelism: int,
num_nodes: int = 1,
tensor_parallelism: int = 2,
pipeline_parallelism: int = 4,
pipeline_parallelism_type: Optional[torch.dtype] = torch.bfloat16,
virtual_pipeline_parallelism: Optional[int] = 14,
context_parallelism: int = 2,
sequence_parallelism: bool = True,
expert_parallelism: int = 8,
num_nodes: int = 16,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[Config[Callback]]] = None,
) -> Config[nl.Trainer]:
strategy = Config(
callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
"""
Configure the NeMo Lightning Trainer for Mixtral 8x22B model.
This function sets up the distributed training strategy optimized for the large Mixtral 8x22B model.
Args:
tensor_parallelism (int): Degree of tensor model parallelism.
pipeline_parallelism (int): Degree of pipeline model parallelism.
pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
context_parallelism (int): Degree of context parallelism.
sequence_parallelism (bool): Whether to use sequence parallelism.
expert_parallelism (int): Degree of expert parallelism.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
max_steps (int): Maximum number of training steps.
callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.
Returns:
run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.
Examples:
CLI usage:
$ nemo llm pretrain trainer=mixtral_8x22b ...
Python API usage:
>>> trainer_config = trainer(num_nodes=16, num_gpus_per_node=8)
>>> print(trainer_config)
Note:
This configuration uses extensive parallelism to handle the large model size efficiently.
"""
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tensor_parallelism,
pipeline_model_parallel_size=pipeline_parallelism,
@@ -71,9 +105,34 @@ def trainer(


def pretrain_recipe(
name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain
name: str, ckpt_dir: str, num_nodes: int = 16, num_gpus_per_node: int = 8, fn: Callable = pretrain
) -> Partial:
return Partial(
"""
Create a pre-training recipe for Mixtral 8x22B model.
This function sets up a complete configuration for pre-training, including
model, trainer, data, logging, optimization, and resumption settings.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
Returns:
run.Partial: Partial configuration for pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory mixtral_8x22b
$ nemo llm pretrain --factory "mixtral_8x22b(num_nodes=16, name='my_mixtral_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe(name="mixtral_pretrain", num_nodes=16)
>>> print(recipe)
"""
return run.Partial(
fn,
model=model(),
trainer=trainer(
@@ -95,8 +154,66 @@ def pretrain_recipe(
)


def hf_resume() -> Config[nl.AutoResume]:
return Config(
@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 8, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Mixtral 8x22B model.
This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
Returns:
run.Partial: Partial configuration for performance-optimized pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory "mixtral_8x22b.pretrain_recipe_performance(num_nodes=8, name='perf_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8)
>>> print(recipe)
Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)
recipe.trainer.callbacks.extend(
[
run.Config(MegatronTokenDropCallback),
run.Config(MegatronCommOverlapCallback),
]
)

return recipe
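
For context, a minimal sketch of building and launching this performance recipe from Python, mirroring the Python API usage in the docstring above. nemo_run's LocalExecutor stands in as an illustrative single-node launcher; real 8-node runs would use a cluster executor instead.

# Sketch only: build the performance recipe and launch it with nemo_run.
# LocalExecutor is an illustrative stand-in; production multi-node runs would
# use a Slurm (or similar) executor.
import nemo_run as run
from nemo.collections.llm.recipes import mixtral_8x22b

recipe = mixtral_8x22b.pretrain_recipe_performance(name="mixtral_8x22b_perf", num_nodes=8)
# The recipe differs from the base pretrain_recipe only by the two callbacks
# appended above: MegatronTokenDropCallback and MegatronCommOverlapCallback.
run.run(recipe, executor=run.LocalExecutor())
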


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for Mixtral 8x22B model.
This function sets up the configuration to resume training from a pre-trained
Hugging Face model checkpoint.
More info about the model can be found at: https://huggingface.co/mistralai/Mixtral-8x22B-v0.1
Returns:
run.Config[nl.AutoResume]: Configuration for resuming from HuggingFace checkpoint.
Note:
This is particularly useful for fine-tuning scenarios where you want to
start from the pre-trained Mixtral 8x22B model.
"""
return run.Config(
nl.AutoResume,
restore_config=Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x22B-v0.1"),
)
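
A short sketch of the intended use of hf_resume(): attach it to a recipe so training starts from the Hugging Face weights. The finetune_recipe factory and its signature are assumptions about this module at this commit (the mixtral_8x3b docstring further down shows the same attachment pattern explicitly).

# Sketch only: start Mixtral 8x22B fine-tuning from the pre-trained Hugging Face weights.
from nemo.collections.llm.recipes import mixtral_8x22b

recipe = mixtral_8x22b.finetune_recipe(name="mixtral_8x22b_sft", num_nodes=16)  # assumed factory/signature
recipe.resume = mixtral_8x22b.hf_resume()  # restores from hf://mistralai/Mixtral-8x22B-v0.1
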
148 changes: 133 additions & 15 deletions nemo/collections/llm/recipes/mixtral_8x3b.py
@@ -14,6 +14,9 @@
from nemo.collections.llm.recipes.optim.adam import distributed_fused_adam_with_cosine_annealing
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed_plugin
from nemo.collections.llm.utils import Config, Partial
from nemo.collections.llm.recipes.precision.mixed_precision import bf16_mixed
from nemo.lightning.pytorch.callbacks.megatron_comm_overlap import MegatronCommOverlapCallback
from nemo.lightning.pytorch.callbacks.moe_token_drop import MegatronTokenDropCallback
from nemo.utils.exp_manager import TimingCallback

NAME = "mixtral_8x3b"
@@ -24,19 +27,48 @@ def model() -> Config[pl.LightningModule]:


def trainer(
tensor_parallelism: int,
pipeline_parallelism: int,
pipeline_parallelism_type: Optional[torch.dtype],
virtual_pipeline_parallelism: Optional[int],
context_parallelism: int,
sequence_parallelism: bool,
expert_parallelism: int,
num_nodes: int = 1,
tensor_parallelism: int = 1,
pipeline_parallelism: int = 1,
pipeline_parallelism_type: Optional[torch.dtype] = None,
virtual_pipeline_parallelism: Optional[int] = None,
context_parallelism: int = 1,
sequence_parallelism: bool = False,
expert_parallelism: int = 4,
num_nodes: int = 2,
num_gpus_per_node: int = 8,
max_steps: int = 1168251,
callbacks: Optional[list[Config[Callback]]] = None,
) -> Config[nl.Trainer]:
strategy = Config(
callbacks: Optional[list[run.Config[Callback]]] = None,
) -> run.Config[nl.Trainer]:
"""
Configure the NeMo Lightning Trainer for Mixtral 8x3B model.
This function sets up the distributed training strategy optimized for the Mixtral 8x3B model.
Args:
tensor_parallelism (int): Degree of tensor model parallelism.
pipeline_parallelism (int): Degree of pipeline model parallelism.
pipeline_parallelism_type (Optional[torch.dtype]): Data type for pipeline parallelism.
virtual_pipeline_parallelism (Optional[int]): Size of virtual pipeline parallelism.
context_parallelism (int): Degree of context parallelism.
sequence_parallelism (bool): Whether to use sequence parallelism.
expert_parallelism (int): Degree of expert parallelism.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
max_steps (int): Maximum number of training steps.
callbacks (Optional[list[run.Config[Callback]]]): List of callback configurations.
Returns:
run.Config[nl.Trainer]: Configuration for the NeMo Lightning Trainer.
Examples:
CLI usage:
$ nemo llm pretrain trainer=mixtral_8x3b ...
Python API usage:
>>> trainer_config = trainer(num_nodes=2, num_gpus_per_node=8)
>>> print(trainer_config)
"""
strategy = run.Config(
nl.MegatronStrategy,
tensor_model_parallel_size=tensor_parallelism,
pipeline_model_parallel_size=pipeline_parallelism,
@@ -71,9 +103,34 @@ def trainer(


def pretrain_recipe(
name: str, ckpt_dir: str, num_nodes: int, num_gpus_per_node: int, fn: Callable = pretrain
name: str, ckpt_dir: str, num_nodes: int = 2, num_gpus_per_node: int = 8, fn: Callable = pretrain
) -> Partial:
return Partial(
"""
Create a pre-training recipe for Mixtral 8x3B model.
This function sets up a complete configuration for pre-training, including
model, trainer, and data settings.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): Function to use for pre-training (default: nemo.collections.llm.api.pretrain).
Returns:
run.Partial: Partial configuration for pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory mixtral_8x3b
$ nemo llm pretrain --factory "mixtral_8x3b(num_nodes=2, name='my_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe(name="mixtral_8x3b_pretrain", num_nodes=2)
>>> print(recipe)
"""
return run.Partial(
fn,
model=model(),
trainer=trainer(
@@ -95,8 +152,69 @@ def pretrain_recipe(
)


def hf_resume() -> Config[nl.AutoResume]:
return Config(
@run.cli.factory(target=pretrain, name=NAME + "_performance")
def pretrain_recipe_performance(
dir: Optional[str] = None, name: str = "default", num_nodes: int = 2, num_gpus_per_node: int = 8, fn=pretrain
) -> run.Partial:
"""
Create a performance-optimized pre-training recipe for Mixtral 8x3B model.
This recipe enables performance optimizations that may not be suitable for all use cases.
It builds upon the standard pre-training recipe and adds additional performance enhancements.
Args:
dir (Optional[str]): Directory for saving logs and checkpoints.
name (str): Name of the pre-training run.
num_nodes (int): Number of compute nodes to use.
num_gpus_per_node (int): Number of GPUs per node.
fn (Callable): The pre-training function to use.
Returns:
run.Partial: Partial configuration for performance-optimized pre-training.
Examples:
CLI usage:
$ nemo llm pretrain --factory "mixtral_8x3b.pretrain_recipe_performance(num_nodes=2, name='perf_pretrain')"
Python API usage:
>>> recipe = pretrain_recipe_performance(name="mixtral_8x3b", num_nodes=4)
>>> print(recipe)
Note:
Use this recipe with caution and only when you need maximum performance.
It may not be suitable for all hardware configurations or use cases.
"""
recipe = pretrain_recipe(name=name, dir=dir, num_nodes=num_nodes, num_gpus_per_node=num_gpus_per_node, fn=fn)

recipe.trainer.callbacks.extend(
[
run.Config(MegatronTokenDropCallback),
run.Config(MegatronCommOverlapCallback),
]
)

return recipe
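
Because the callbacks are attached as run.Config objects, their settings can still be adjusted on the returned recipe before launching. A sketch, assuming the token-drop callback exposes a moe_expert_capacity_factor parameter (inferred from the MixtralConfig fields removed in this PR) and remains the second-to-last callback appended above.

# Sketch only: override a setting on a callback appended by pretrain_recipe_performance.
from nemo.collections.llm.recipes import mixtral_8x3b

recipe = mixtral_8x3b.pretrain_recipe_performance(name="mixtral_8x3b_perf", num_nodes=2)

token_drop_cfg = recipe.trainer.callbacks[-2]     # MegatronTokenDropCallback config (assumed position)
token_drop_cfg.moe_expert_capacity_factor = 1.25  # assumed parameter name; 1.25x head-room per expert
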


def hf_resume() -> run.Config[nl.AutoResume]:
"""
Configure automatic resumption from a Hugging Face checkpoint for the Mixtral 8x3B model.
This function sets up the configuration for resuming training from a Hugging Face model.
Returns:
run.Config[nl.AutoResume]: Configuration for resuming from a Hugging Face model.
Examples:
CLI usage:
$ nemo llm finetune --factory "mixtral_8x3b(resume=hf_resume())"
Python API usage:
>>> recipe = finetune_recipe(name="mixtral_8x3b_finetune", num_nodes=2)
>>> recipe.resume = hf_resume()
>>> print(recipe)
"""
return run.Config(
nl.AutoResume,
restore_config=Config(nl.RestoreConfig, path="hf://mistralai/Mixtral-8x7B-v0.1"),
)
(Diffs for the remaining 5 changed files are not shown.)
