
Commit 11f2d93

Supervised Fine-tuning for Hugging Face pretrained weights (deepspeedai#318)
* alpaca hf weight finetune: clean-ups, arg fix, refine weight converter (don't cat when dim=0), formatting updates
* add finetune script
* add condition for no padded token case
* add reference

Co-authored-by: Conglong Li <[email protected]>
1 parent f9323e3 commit 11f2d93

File tree

11 files changed (+960, -8 lines)

@@ -0,0 +1,24 @@
## Example of Finetuning LLAMA-7B from Hugging Face Weights

### Dataset
You can access the dataset from [here](https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json).

### Pre-trained Weights
The pre-trained weights can be found at [Hugging Face - LLAMA-7B](https://huggingface.co/huggyllama/llama-7b).
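If you don't already have these locally, the sketch below is one way to fetch them. It assumes `wget` and `git-lfs` are installed; the raw-file URL and the `/data/llama-7b` target directory are illustrative (the target matches the default `HF_LLAMA_PATH` in `finetune_llama.sh`).

```bash
# Download the Alpaca dataset (raw counterpart of the GitHub blob linked above).
wget https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json

# Download the LLAMA-7B weights from Hugging Face; git-lfs fetches the large shards.
git lfs install
git clone https://huggingface.co/huggyllama/llama-7b /data/llama-7b
```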
### Usage

#### 1. Converting Hugging Face Model Weights to a Megatron-DeepSpeed Checkpoint
```bash
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert
```
This command converts the Hugging Face model weights into a Megatron-DeepSpeed checkpoint and saves it. You can adjust the parallel configuration in the script, as sketched below.

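As a minimal sketch of what "parallel configuration" means here (the values shown are the script's defaults): the tensor- and pipeline-parallel degrees are plain variables near the top of `finetune_llama.sh`, and the converted checkpoint is saved under a directory name that encodes them, so set them before running `convert` to match the layout you plan to fine-tune with.

```bash
# Defaults from examples_deepspeed/finetune_hf_llama/finetune_llama.sh;
# edit these in the script before running the `convert` step.
TP=2   # tensor-model-parallel size
PP=2   # pipeline-model-parallel size
# The converted checkpoint is written to ./llama-7b-mega-ds-T${TP}P${PP}.
```
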
#### 2. Fine-tuning Process
```bash
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh
```
Execute this command to initiate the fine-tuning process. The task originates from [Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca.git).

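Putting the two steps together, a typical end-to-end run might look like the sketch below. The TensorBoard step is optional and assumes TensorBoard is installed; the script logs to `tensorboard_output` via `--tensorboard-dir`.

```bash
# 1) One-time conversion of the Hugging Face checkpoint.
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh convert

# 2) Launch fine-tuning from the converted checkpoint.
bash examples_deepspeed/finetune_hf_llama/finetune_llama.sh

# 3) Optionally monitor the training curves.
tensorboard --logdir tensorboard_output
```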
@@ -0,0 +1,11 @@
{
  "train_batch_size" : 256,
  "train_micro_batch_size_per_gpu": 16,
  "steps_per_print": 100,
  "zero_optimization": {
    "stage": 0
  },
  "bf16": {
    "enabled": true
  }
}
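For context, DeepSpeed requires `train_batch_size = train_micro_batch_size_per_gpu x gradient_accumulation_steps x data-parallel size`, where the data-parallel size is the total GPU count divided by TP x PP. The sketch below only checks that arithmetic; the 16-GPU total is an assumption for illustration, not something this example prescribes.

```bash
# Hypothetical sanity check of the batch-size arithmetic (NUM_GPUS=16 is assumed).
NUM_GPUS=16; TP=2; PP=2
MICRO=16; GLOBAL=256
DP=$((NUM_GPUS / (TP * PP)))      # data-parallel size: 4
GAS=$((GLOBAL / (MICRO * DP)))    # gradient accumulation steps: 4
echo "data_parallel=$DP grad_accum_steps=$GAS"
```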
@@ -0,0 +1,110 @@
DS_CONFIG=./examples_deepspeed/finetune_hf_llama/ds_config.json
DATASET_PATH=./alpaca_data.json
# dataset link: https://github.com/tatsu-lab/stanford_alpaca/blob/main/alpaca_data.json

HF_LLAMA_PATH=/data/llama-7b/
# weights link: https://huggingface.co/huggyllama/llama-7b

MICRO_BATCH_SIZE=16
GLOBAL_BATCH_SIZE=256
TP=2
PP=2
# These must align with the dimensions of the pretrained weights.
HIDDEN_SIZE=4096
FFN_HIDDEN_SIZE=11008
NUM_LAYERS=32
NUM_HEADS=32
SEQ_LENGTH=512
######################################

MEGA_DS_LLAMA_PATH=./"llama-7b-mega-ds-T${TP}P${PP}"

# The flags below are required for the LLaMA model, as per the LLaMA paper:
# --no-query-key-layer-scaling \
# --attention-dropout 0 \
# --hidden-dropout 0 \
# --use-rotary-position-embeddings \
# --untie-embeddings-and-output-weights \
# --swiglu \
# --normalization rmsnorm \
# --disable-bias-linear \
######################################
cat <<EOT > $DS_CONFIG
{
  "train_batch_size" : $GLOBAL_BATCH_SIZE,
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "steps_per_print": 100,
  "zero_optimization": {
    "stage": 0
  },
  "bf16": {
    "enabled": true
  }
}
EOT

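# NOTE: --hf-ckpt-num-shards must match the number of checkpoint shards in
# $HF_LLAMA_PATH; huggyllama/llama-7b currently ships its weights as two
# pytorch_model-*.bin shards, hence the value 2 below.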
convert_args="deepspeed tools/hf2megads_weight_converter.py \
--hf-ckpt-num-shards 2 \
--origin-hf-ckpt-dir $HF_LLAMA_PATH \
--save $MEGA_DS_LLAMA_PATH"

finetune_args="deepspeed finetune_llama.py \
--load $MEGA_DS_LLAMA_PATH"

comm_args="--tensor-model-parallel-size $TP \
--pipeline-model-parallel-size $PP \
--lr-warmup-iters 2000 \
--weight-decay 0.1 \
--clip-grad 1 \
--num-layers $NUM_LAYERS \
--hidden-size $HIDDEN_SIZE \
--num-attention-heads $NUM_HEADS \
--ffn-hidden-size $FFN_HIDDEN_SIZE \
--attention-dropout 0 \
--hidden-dropout 0 \
--no-query-key-layer-scaling \
--disable-bias-linear \
--normalization rmsnorm \
--use-rotary-position-embeddings \
--untie-embeddings-and-output-weights \
--swiglu \
--seq-length $SEQ_LENGTH \
--max-position-embeddings $SEQ_LENGTH \
--micro-batch-size $MICRO_BATCH_SIZE \
--global-batch-size $GLOBAL_BATCH_SIZE \
--train-iters 3500 \
--lr 2e-5 \
--tensorboard-dir tensorboard_output \
--lr-decay-iters 320000 \
--lr-decay-style cosine \
--log-interval 1 \
--eval-iters 100 \
--eval-interval 100 \
--data-path $DATASET_PATH \
--save-interval 1500 \
--split 100,0,0 \
--bf16 \
--zero-stage 0 \
--tokenizer-type HFTokenizer \
--tokenizer-model $HF_LLAMA_PATH \
--deepspeed_config ./examples_deepspeed/finetune_hf_llama/ds_config.json \
--deepspeed \
--distributed-backend nccl \
--num-workers 0 \
--no-masked-softmax-fusion \
--no-bias-gelu-fusion \
--no-bias-dropout-fusion \
--no-gradient-accumulation-fusion \
--repeated-dataloader"

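# Dispatch: pass "convert" as the first argument to run the HF -> Megatron-DeepSpeed
# weight conversion; any other (or no) argument launches fine-tuning. The shared
# Megatron/DeepSpeed arguments in $comm_args are appended to whichever task is selected.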
if [ "$1" = "convert" ]; then
    task_args="$convert_args"
else
    task_args="$finetune_args"
fi

full_cmd="$task_args $comm_args"

eval "$full_cmd"
