
Commit f9a786f, committed Apr 4, 2024
1 parent d9bbe35

add µP

11 files changed: +771 / -4 lines
 

.gitignore (+1)
@@ -162,3 +162,4 @@ cython_debug/
 .vscode

 checkpoints/
+wandb/
examples/config_tiny_llama.yaml (+4 / -3)
@@ -76,7 +76,8 @@ optimizer:
   adam_eps: 1.0e-08
   clip_grad: 1.0
   learning_rate_scheduler:
-    learning_rate: 0.0003
+    # learning_rate: 0.0003
+    learning_rate: 0.01
     lr_decay_starting_step: null
     lr_decay_steps: 8
     lr_decay_style: cosine
@@ -102,7 +103,7 @@ tokens:
   batch_accumulation_per_replica: 1
   limit_test_batches: 0
   limit_val_batches: 0
-  micro_batch_size: 2
-  sequence_length: 32
+  micro_batch_size: 16
+  sequence_length: 1024
   train_steps: 20
   val_check_interval: -1
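The new batch settings change how many tokens each replica consumes per step. A quick check with the values from this file (batch_accumulation_per_replica is shown in the same hunk):

micro_batch_size = 16
sequence_length = 1024
batch_accumulation_per_replica = 1  # from the same config file

tokens_per_step_per_replica = micro_batch_size * sequence_length * batch_accumulation_per_replica
print(tokens_per_step_per_replica)  # 16384, up from 2 * 32 = 64 with the old values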
New config file: llama_width_1024 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_1024_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 1024
+    intermediate_size: 4096
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
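Every config in this width sweep (128 through 8192) pins std to the same value rather than a per-width 1/sqrt(width). A quick check of the two numbers referenced in the init_method comments:

import math

# Values referenced in the init_method comments of the width-sweep configs
print(1 / math.sqrt(2048))  # 0.022097086912079608, the std actually set
print(1 / math.sqrt(1024))  # 0.03125, the commented-out width-dependent alternative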
New config file: llama_width_128 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_128_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 128
+    intermediate_size: 512
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
New config file: llama_width_2048 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_2048_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 2048
+    intermediate_size: 8192
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
New config file: llama_width_256 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_256_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 256
+    intermediate_size: 1024
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
New config file: llama_width_4096 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_4096_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 4096
+    intermediate_size: 16384
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
New config file: llama_width_512 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_512_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 512
+    intermediate_size: 2048
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
New config file: llama_width_8192 (+106 lines)

@@ -0,0 +1,106 @@
+checkpoints:
+  checkpoint_interval: 1000
+  checkpoints_path: checkpoints
+  checkpoints_path_is_shared_file_system: false
+  resume_checkpoint_path: null
+  save_initial_state: false
+
+data_stages:
+- name: Stable Training Stage
+  start_training_step: 1
+  data:
+    dataset:
+      dataset_overwrite_cache: false
+      dataset_processing_num_proc_per_process: 1
+      hf_dataset_config_name: null
+      hf_dataset_or_datasets: Fiery101/radar_textbooks
+      hf_dataset_splits: train
+      text_column_name: text
+    num_loading_workers: 1
+    seed: 42
+
+general:
+  benchmark_csv_path: null
+  consumed_train_samples: null
+  ignore_sanity_checks: false
+  project: µTransfer_for_nanotron
+  run: llama_width_8192_config
+  seed: 42
+  step: null
+lighteval: null
+logging:
+  iteration_step_info_interval: 1
+  log_level: info
+  log_level_replica: info
+model:
+  ddp_bucket_cap_mb: 120
+  dtype: bfloat16
+  init_method:
+    # std: 0.025 # original
+    # std: 0.03125 # 1/sqrt(1024)=0.03125
+    std: 0.022097086912079608 # 1/sqrt(2048)=0.022097086912079608
+  make_vocab_size_divisible_by: 1
+  model_config:
+    bos_token_id: 1
+    eos_token_id: 2
+    hidden_act: silu
+    initializer_range: 0.02
+
+    # NOTE: 250m
+    # hidden_size: 1024
+    # intermediate_size: 4096
+    # num_hidden_layers: 10
+
+    hidden_size: 8192
+    intermediate_size: 32768
+    num_hidden_layers: 4
+
+    is_llama_config: true
+    max_position_embeddings: 1024
+    num_attention_heads: 32
+    num_key_value_heads: 4
+    pad_token_id: null
+    pretraining_tp: 1
+    rms_norm_eps: 1.0e-05
+    rope_scaling: null
+    # tie_word_embeddings: true
+    tie_word_embeddings: false # original use true
+    use_cache: true
+    vocab_size: 49152
+optimizer:
+  accumulate_grad_in_fp32: false
+  adam_beta1: 0.9
+  adam_beta2: 0.95
+  adam_eps: 1.0e-08
+  clip_grad: 1.0
+  learning_rate_scheduler:
+    learning_rate: 0.001 # note: 1/2 of pythia use this for a 400m model
+    lr_decay_starting_step: null
+    lr_decay_steps: null
+    lr_decay_style: cosine
+    lr_warmup_steps: 6 # 10% warm up of total training steps
+    lr_warmup_style: linear
+    min_decay_lr: 1.0e-05
+  torch_adam_is_fused: true
+  weight_decay: 0.1
+  zero_stage: 1
+parallelism:
+  dp: 2
+  pp: 1
+  pp_engine: 1f1b
+  tp: 4
+  tp_linear_async_communication: true
+  tp_mode: REDUCE_SCATTER
+profiler: null
+tokenizer:
+  tokenizer_max_length: null
+  tokenizer_name_or_path: lvwerra/the-tokenizer-v1
+  tokenizer_revision: null
+tokens:
+  batch_accumulation_per_replica: 1
+  limit_test_batches: 0
+  limit_val_batches: 0
+  micro_batch_size: 64
+  sequence_length: 512
+  train_steps: 30
+  val_check_interval: -1
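With all seven width configs in place, the sweep can be launched one width at a time. A sketch under stated assumptions: the commit does not show where the new YAML files live, so CONFIG_DIR and the file names below are hypothetical, and run_train.py is nanotron's usual training entry point (dp=2 * tp=4 * pp=1 in the configs means 8 processes per run):

import subprocess

# Hypothetical location and naming of the new configs; adjust to your checkout.
CONFIG_DIR = "examples/mup"
WIDTHS = [128, 256, 512, 1024, 2048, 4096, 8192]

for width in WIDTHS:
    cmd = [
        "torchrun", "--nproc_per_node=8",  # dp=2 * tp=4 * pp=1 => 8 processes
        "run_train.py",
        "--config-file", f"{CONFIG_DIR}/llama_width_{width}_config.yaml",
    ]
    print("launching:", " ".join(cmd))
    subprocess.run(cmd, check=True)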

src/nanotron/models/llama.py (+20 / -1)
@@ -213,6 +213,8 @@ def forward(
         # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not
         # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache.
         causal = False if q_sequence_mask.shape[1] == 1 else True
+        # NOTE: this scale is for µTransfer
+        softmax_scale = 1 / query_states.shape[-1]
         attn_output = flash_attn_varlen_func(
             q=query_states,
             k=key_states,
@@ -222,7 +224,7 @@
             max_seqlen_q=q_sequence_mask.shape[1],
             max_seqlen_k=kv_sequence_mask.shape[1],
             dropout_p=0.0,
-            softmax_scale=None,  # This already defaults to the scale I'm interested in
+            softmax_scale=softmax_scale,  # µTransfer: 1/d_head instead of flash-attn's default 1/sqrt(d_head)
             causal=causal,
             return_attn_probs=False,
         )
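The two hunks above switch flash-attn from its default 1/sqrt(d_head) softmax scale to the 1/d_head scale that µTransfer prescribes for attention logits. A minimal dense-attention sketch of the same scale change, for illustration only (this is not the flash_attn_varlen_func path used here, and the function name and shapes are ours):

import math
import torch

def attention_probs(q: torch.Tensor, k: torch.Tensor, mup: bool = True) -> torch.Tensor:
    """Toy dense attention showing the scale swap.

    q, k: [batch, heads, seq, head_dim].
    Standard attention uses 1/sqrt(head_dim); µTransfer uses 1/head_dim,
    which is what softmax_scale = 1 / query_states.shape[-1] passes to flash-attn.
    """
    head_dim = q.shape[-1]
    scale = 1.0 / head_dim if mup else 1.0 / math.sqrt(head_dim)
    return torch.softmax(q @ k.transpose(-2, -1) * scale, dim=-1)

q = torch.randn(1, 2, 8, 64)
k = torch.randn(1, 2, 8, 64)
probs_mup = attention_probs(q, k)             # µP 1/d scaling
probs_sp = attention_probs(q, k, mup=False)   # standard 1/sqrt(d) scaling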
@@ -774,6 +776,19 @@ def forward_with_hidden_states(
         for encoder_block in self.decoder:
             hidden_encoder_states = encoder_block(**hidden_encoder_states)

+        # hidden_states.shape = [seq_length/tp_rank, batch_size, hidden_dim]
+        mup_l1_norm = hidden_encoder_states["hidden_states"].mean(dim=[0, 1]).abs()  # [hidden_dim]
+        dist.all_reduce(
+            mup_l1_norm, op=dist.ReduceOp.SUM, group=self.parallel_context.tp_pg
+        )  # sum [hidden_dim] across tp ranks
+        mup_l1_norm = mup_l1_norm.mean()
+        dist.all_reduce(mup_l1_norm, op=dist.ReduceOp.AVG, group=self.parallel_context.dp_pg)
+
+        if dist.get_rank() == 0:
+            import wandb
+
+            wandb.log({"output_l1_norm": mup_l1_norm.cpu().detach().float().numpy(), "width": self.config.hidden_size})
+
         hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"]

         sharded_logits = self.lm_head(x=hidden_states)["logits"]
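The logging block above averages the last decoder output across tokens and batch, reduces it over tp/dp ranks, and logs it to wandb as output_l1_norm alongside the model width. One way those logs could be consumed for a µP-style coordinate check is sketched below; the helper and its input format are an assumption, not part of this commit (under a correct µP setup the curve should stay roughly flat as width grows rather than scaling with hidden_size):

from typing import Dict, List
import matplotlib.pyplot as plt

def plot_coord_check(l1_by_width: Dict[int, List[float]], step: int = -1) -> None:
    """Plot output_l1_norm at a given training step against model width.

    l1_by_width maps hidden_size -> list of logged output_l1_norm values per step,
    e.g. exported from the wandb runs written by the logging block above.
    """
    widths = sorted(l1_by_width)
    plt.plot(widths, [l1_by_width[w][step] for w in widths], marker="o")
    plt.xscale("log", base=2)
    plt.xlabel("hidden_size (width)")
    plt.ylabel("output_l1_norm")
    plt.title(f"µP coordinate check at step {step}")
    plt.show()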
@@ -839,6 +854,10 @@ def forward(
     ) -> Dict[str, torch.Tensor]:
         # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision.
         # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38
+
+        # mup_logit_l1_norm = sharded_logits.mean(dim=[0,1]).abs()
+        # dist.all_reduce(mup_logit_l1_norm, op=dist.ReduceOp.AVERAGE, group=self.tp_pg)
+
         loss = sharded_cross_entropy(
             sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float
         ).transpose(0, 1)

src/nanotron/scaling.py (+4)
@@ -172,3 +172,7 @@ def output_weight_hook(module, input, output):
         raise ValueError(f"Unknown linear type: {module.linear_type}")

     module.register_forward_hook(hook_func)
+
+
+def monitor_l1_norm_activations(model: nn.Module):
+    pass
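monitor_l1_norm_activations is only a stub in this commit. A possible hook-based version, mirroring the manual output_l1_norm logging added to llama.py, is sketched below; the body is an assumption, not what the repository implements:

from typing import Dict
import torch
from torch import nn

def monitor_l1_norm_activations_sketch(model: nn.Module) -> Dict[str, float]:
    """Hypothetical filling-in of the stub above.

    Registers a forward hook on every leaf module and records the mean absolute
    value of its output, keyed by module name. The returned dict is updated in
    place on every forward pass and can be logged (e.g. to wandb) per step.
    """
    l1_norms: Dict[str, float] = {}

    def make_hook(name: str):
        def hook(module: nn.Module, inputs, output):
            if isinstance(output, torch.Tensor):
                l1_norms[name] = output.detach().abs().mean().item()
        return hook

    for name, module in model.named_modules():
        if len(list(module.children())) == 0:  # leaf modules only
            module.register_forward_hook(make_hook(name))
    return l1_norms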
