Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/EleutherAI/gpt-neox into ad…
Browse files Browse the repository at this point in the history
…d-context-parallel-support
  • Loading branch information
Quentin-Anthony committed Jan 29, 2025
2 parents 4969683 + f7a5a6f commit 5f13813
Show file tree
Hide file tree
Showing 48 changed files with 2,386 additions and 276 deletions.
34 changes: 26 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
* Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, monitor experiments via [WandB](https://wandb.ai/site)/[Comet](https://www.comet.com/site/)/TensorBoard, and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

## News
**[10/9/2024]** We now support [Transformer Engine](https://github.com/NVIDIA/TransformerEngine) integration

**[9/9/2024]** We now support preference learning via [DPO](https://arxiv.org/abs/2305.18290), [KTO](https://arxiv.org/abs/2402.01306), and reward modeling

**[9/9/2024]** We now support integration with [Comet ML](https://www.comet.com/site/), a machine learning monitoring platform
Expand Down Expand Up @@ -60,6 +62,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA
* [Environment and Dependencies](#environment-and-dependencies)
+ [Host Setup](#host-setup)
+ [Flash Attention](#flash-attention)
+ [Transformer Engine](#transformer-engine)
+ [Multi-Node Launching](#multi-node-launching)
+ [Containerized Setup](#containerized-setup)
* [Usage](#usage)
Expand Down Expand Up @@ -100,7 +103,7 @@ Prior to 3/9/2023, GPT-NeoX relied on [DeeperSpeed](https://github.com/EleutherA

### Host Setup

First make sure you are in an environment with Python 3.8 with an appropriate version of PyTorch 1.8 or later installed. **Note:** Some of the libraries that GPT-NeoX depends on have not been updated to be compatible with Python 3.10+. Python 3.9 appears to work, but this codebase has been developed and tested for Python 3.8.
This codebase has primarily developed and tested for Python 3.8-3.10, and PyTorch 1.8-2.0. This is not a strict requirement, and other versions and combinations of libraries may work.

To install the remaining basic dependencies, run:

Expand Down Expand Up @@ -130,7 +133,20 @@ This will automatically adapts building process over different GPU vendors (AMD,

### Flash Attention

To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` and set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.
To use [Flash-Attention](https://github.com/HazyResearch/flash-attention), install the additional dependencies in `./requirements/requirements-flashattention.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). Then set the attention type in your configuration accordingly (see [configs](./configs/)). This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere GPUs (such as A100s); see the repository for more details.

### Transformer Engine

To use [Transformer Engine (TE)](https://github.com/NVIDIA/TransformerEngine), install the additional dependencies in `./requirements/requirements-transformer-engine.txt` or use a PyTorch NGC container with it pre-installed (note that functionality is not guaranteed using versions different from our requirements file). See [this config](https://github.com/EleutherAI/gpt-neox/configs/1-3B-transformer-engine.yml) for an example of using TE on a 1.3B model. This can provide significant speed-ups over regular attention on certain GPU architectures, including Ampere and Hopper GPUs; see the repository for more details.


TE provides very efficient kernels for both A100 and H100 GPUs. We've run some sample ablations on A100:



and H100:




### Multi-Node Launching
Expand Down Expand Up @@ -679,7 +695,9 @@ We support profiling with Nsight Systems, the PyTorch Profiler, and PyTorch Memo
## Nsight Systems Profiling
To use the Nsight Systems profiling, set config options `profile`, `profile_step_start`, and `profile_step_stop`. Launch training with:
To use the Nsight Systems profiling, set config options `profile`, `profile_step_start`, and `profile_step_stop` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config).
To populate nsys metrics, launch training with:
```
nsys profile -s none -t nvtx,cuda -o <path/to/profiling/output> --force-overwrite true \
Expand All @@ -689,22 +707,22 @@ $TRAIN_PATH/train.py --conf_dir configs <config files>
The generated output file can then by viewed with the Nsight Systems GUI:
![Alt text](images/nsight_profiling.png)
![nsight-prof](images/nsight_profiling.png)
## PyTorch Profiling
To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop`.
To use the built-in PyTorch profiler, set config options `profile`, `profile_step_start`, and `profile_step_stop` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config).
The PyTorch profiler will save traces to your `tensorboard` log directory. You can view these traces within
TensorBoard by following the steps [here](https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html).
![Alt text](images/pytorch_profiling.png)
![torch-prof](images/pytorch_profiling.png)
## PyTorch Memory Profiling
To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path`.
To use PyTorch Memory Profiling, set config options `memory_profiling` and `memory_profiling_path` (see [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/neox_arguments.md) for argument usage, and [here](https://github.com/EleutherAI/gpt-neox/blob/main/configs/prof.yml) for a sample config).
![Alt text](images/memory_profiling.png)
![mem-prof](images/memory_profiling.png)
View the generated profile with the [memory_viz.py](https://github.com/pytorch/pytorch/blob/main/torch/cuda/_memory_viz.py) script. Run with:
Expand Down
105 changes: 105 additions & 0 deletions configs/1-3B-transformer-engine.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# GPT-2 pretraining setup
{
# parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
# across the node boundaries )
"pipe_parallel_size": 1,
"model_parallel_size": 1,

# model settings
"num_layers": 24,
"hidden_size": 2048,
"num_attention_heads": 16,
"seq_length": 2048,
"max_position_embeddings": 2048,
"norm": "layernorm",
"pos_emb": "rotary",
"no_weight_tying": true,
"gpt_j_residual": false,
"output_layer_parallelism": "column",

# Transformer Engine settings
"te_columnparallel": false,
"te_rowparallel": false,
"te_layernorm_mlp": true,
"te_mha": true,
"te_fp8_format": "hybrid",
"te_fp8_wgrad": true,
"te_fp8_amax_history_len": 1,
"te_fp8_amax_compute_algo": "most_recent",
"te_fp8_margin": 0,
"te_fp8_mha": false,

# these should provide some speedup but takes a while to build, set to true if desired
"scaled_upper_triang_masked_softmax_fusion": false,
"bias_gelu_fusion": false,
"rope_fusion": false,
"layernorm_fusion": false,

# init methods
"init_method": "small_init",
"output_layer_init_method": "wang_init",

# optimizer settings
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.0002,
"betas": [0.9, 0.95],
"eps": 1.0e-8,
}
},
"min_lr": 0.00002,

# for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
"zero_optimization": {
"stage": 1,
"allgather_partitions": True,
"allgather_bucket_size": 500000000,
"overlap_comm": True,
"reduce_scatter": True,
"reduce_bucket_size": 500000000,
"contiguous_gradients": True,
},

# batch / data settings
"train_micro_batch_size_per_gpu": 4,
"data_impl": "mmap",

# activation checkpointing
"checkpoint_activations": true,
"checkpoint_num_layers": 1,
"partition_activations": true,
"synchronize_each_layer": true,

# regularization
"gradient_clipping": 1.0,
"weight_decay": 0.1,
"hidden_dropout": 0,
"attention_dropout": 0,

# precision settings
"fp16": {
"fp16": true,
"enabled": true,
"loss_scale": 0,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},

# misc. training settings
"train_iters": 320000,
"lr_decay_iters": 320000,
"distributed_backend": "nccl",
"lr_decay_style": "cosine",
"warmup": 0.01,
"checkpoint_factor": 10000,
"eval_interval": 1000,
"eval_iters": 10,

# logging
"log_interval": 100,
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
}
1 change: 1 addition & 0 deletions configs/eleutherai_cluster.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"tensorboard_dir": "/mnt/ssd-1/tensorboard",
"log_dir": "/mnt/ssd-1/logs",
"wandb_team": "eleutherai",
#"wandb_run_name": "experiment"
"wandb_project": "neox",
"wandb_group": "example"
}
2 changes: 2 additions & 0 deletions configs/llama/13B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# model settings
"num_layers": 40,
"hidden_size": 5120,
"intermediate_size": 40960,
"num_attention_heads": 40,
"seq_length": 2048,
"max_position_embeddings": 2048,
Expand All @@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
Expand Down
2 changes: 2 additions & 0 deletions configs/llama/30B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# model settings
"num_layers": 60,
"hidden_size": 6656,
"intermediate_size": 53248,
"num_attention_heads": 52,
"seq_length": 2048,
"max_position_embeddings": 2048,
Expand All @@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
Expand Down
2 changes: 2 additions & 0 deletions configs/llama/65B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# model settings
"num_layers": 80,
"hidden_size": 8192,
"intermediate_size": 65536,
"num_attention_heads": 64,
"seq_length": 2048,
"max_position_embeddings": 2048,
Expand All @@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
Expand Down
2 changes: 2 additions & 0 deletions configs/llama/7B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# model settings
"num_layers": 32,
"hidden_size": 4096,
"intermediate_size": 32768,
"num_attention_heads": 32,
"seq_length": 2048,
"max_position_embeddings": 2048,
Expand All @@ -16,6 +17,7 @@
"output_layer_parallelism": "column",
"norm": "rmsnorm",
"rms_norm_epsilon": 1.0e-6,
"use_bias_in_mlp": False,

"scaled_upper_triang_masked_softmax_fusion": true,
"bias_gelu_fusion": false,
Expand Down
2 changes: 1 addition & 1 deletion configs/llama/train_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@
"steps_per_print": 10,
"keep_last_n_checkpoints": 4,
"wall_clock_breakdown": true,
"mlp_multiple_of": 256,

}
1 change: 1 addition & 0 deletions configs/llama2/13B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# model settings
"num_layers": 40,
"hidden_size": 5120,
"intermediate_size": 41472,
"num_attention_heads": 40,
"seq_length": 4096,
"max_position_embeddings": 4096,
Expand Down
2 changes: 1 addition & 1 deletion configs/llama2/70B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# model settings
"num_layers": 80,
"hidden_size": 8192,
"intermediate_size": 28672,
"intermediate_size": 86016,
"num_attention_heads": 64,
"num_kv_heads": 8,
"seq_length": 4096,
Expand Down
1 change: 1 addition & 0 deletions configs/llama2/7B.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
# model settings
"num_layers": 32,
"hidden_size": 4096,
"intermediate_size": 32768,
"num_attention_heads": 32,
"seq_length": 4096,
"max_position_embeddings": 4096,
Expand Down
23 changes: 23 additions & 0 deletions configs/neox_arguments.md
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,29 @@ Model Arguments
- **dim_att**: int
Default = None
Total dimension of the attention mechanism for RWKV. If not set, defaults to hidden_size.
- **head_size**: int
Default = None
Size of each attention head for RWKV. Calculated as dim_att // num_attention_heads.
- **ffn_dim**: int
Default = None
Dimension of the feed-forward network for RWKV. If not set, calculated based on hidden_size and expansion_factor.
## NeoXArgsOptimizer
Optimizer Arguments
Expand Down
17 changes: 17 additions & 0 deletions configs/prof.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Sample profiling config
{
# Turns on nsys and pytorch profiling
"profile": true,

# pytorch profiler options
"profile_step_start": 10,
"profile_step_stop": 12,

# pytorch memory profiler options
"memory_profiling": true,
"memory_profiling_path": tensorboard,


# All trace files (pytorch, nsys, tensorboard, etc) will be written here
"tensorboard_dir": "tensorboard",
}
Loading

0 comments on commit 5f13813

Please sign in to comment.