
muP implementation #637

Open · wants to merge 151 commits into base: main

Commits (151)
52f3d53
mup olmo
AkshitaB Feb 21, 2024
4d715ad
wip
AkshitaB Mar 20, 2024
43d674d
Merge branch 'main' into akshitab-scale
AkshitaB Mar 20, 2024
675290f
MuOLMo, coord check (wip)
AkshitaB Mar 21, 2024
e4a80bc
coord check runs
AkshitaB Apr 10, 2024
2354425
remove extra args, add readme
AkshitaB Apr 10, 2024
07df33e
requirements
AkshitaB Apr 10, 2024
2dd327b
Update README.md
AkshitaB Apr 10, 2024
74f56a9
Update README.md
AkshitaB Apr 10, 2024
325cad1
coord checks
AkshitaB Apr 10, 2024
a31b698
legend for coord checks
AkshitaB Apr 10, 2024
d1533b1
updates to mup implementation
AkshitaB Apr 11, 2024
c4dcbda
coord checks
AkshitaB Apr 11, 2024
ff3778a
Merge branch 'akshitab-scale' of https://github.com/allenai/OLMo into…
AkshitaB Apr 11, 2024
0f39e6e
coord checks
AkshitaB Apr 11, 2024
c180a8e
smaller model
AkshitaB Apr 11, 2024
d498385
fixes for network output
AkshitaB Apr 11, 2024
e566b72
smaller bs, larger model
AkshitaB Apr 11, 2024
12016bd
lr=0.005, bs=2
AkshitaB Apr 11, 2024
a5807a1
confirm normal init
AkshitaB Apr 11, 2024
4285275
scale by d
AkshitaB Apr 11, 2024
85b23fd
scale by d
AkshitaB Apr 11, 2024
fe5c71d
remove unnecessary code
AkshitaB Apr 11, 2024
c875dcc
extra cleanup
AkshitaB Apr 11, 2024
27a4eef
more width
AkshitaB Apr 11, 2024
1a35e1c
Update README.md
AkshitaB Apr 11, 2024
9aa8e66
Update README.md
AkshitaB May 7, 2024
ccbcd08
don't use readout
AkshitaB May 7, 2024
0ea01e3
init scale
AkshitaB May 7, 2024
29eaf76
rope
AkshitaB May 7, 2024
d4ab24f
readme
AkshitaB May 7, 2024
520fd72
remove plots
AkshitaB May 8, 2024
ea58502
formatting
AkshitaB May 8, 2024
d3abad1
get_batch_loss func
AkshitaB May 8, 2024
d61de42
wip: train, eval funcs
AkshitaB May 8, 2024
acd5032
gitignore
AkshitaB May 8, 2024
505f46e
bug fix
AkshitaB May 9, 2024
5a73a1a
coords file
AkshitaB May 9, 2024
cd326aa
easy file name
AkshitaB May 10, 2024
efcd37b
script for examining coords
AkshitaB May 10, 2024
f7f7f34
Merge branch 'main' into akshitab-scale
AkshitaB Jun 23, 2024
a9ea54b
add mup to main model
AkshitaB Jun 24, 2024
06101ad
tests
AkshitaB Jun 24, 2024
6b65600
easier testing
AkshitaB Jun 24, 2024
d088220
simplified coord check
AkshitaB Jun 24, 2024
c2595ea
30 iter coord checks
AkshitaB Jun 24, 2024
b055d41
rename
AkshitaB Jun 24, 2024
424a361
refactor, add tests
AkshitaB Jun 24, 2024
effe212
fix save_base_shapes
AkshitaB Jun 24, 2024
67158cb
Revert "fix save_base_shapes"
AkshitaB Jun 24, 2024
5c02f3f
Revert "refactor, add tests"
AkshitaB Jun 24, 2024
f6722b5
rename output files
AkshitaB Jun 24, 2024
f0775f8
remove extra code
AkshitaB Jun 24, 2024
febca85
function for save_base_shapes
AkshitaB Jun 24, 2024
41d293d
make plot optional, for testing purposes
AkshitaB Jun 24, 2024
d060fb5
make widths configurable
AkshitaB Jun 24, 2024
e7b6f66
split out the args
AkshitaB Jun 24, 2024
14db801
rename plotdir
AkshitaB Jun 24, 2024
e001879
add tests
AkshitaB Jun 24, 2024
eb2f3ee
rename to coordinates
AkshitaB Jun 24, 2024
6bbca2d
remove dict_in_out
AkshitaB Jun 24, 2024
5213a21
remove unused args
AkshitaB Jun 25, 2024
258d8d3
always compute loss
AkshitaB Jun 25, 2024
c44c8f1
add bigger config
AkshitaB Jun 25, 2024
282fbe6
updated outputs
AkshitaB Jun 25, 2024
81b331f
reorganize
AkshitaB Jun 25, 2024
954f761
update paths in readme
AkshitaB Jun 25, 2024
48fa137
add scaffolding for scaling laws
AkshitaB Jun 25, 2024
1d5eaf2
Merge branch 'main' into akshitab-scale
AkshitaB Jun 28, 2024
ea900ee
fix lint, etc
AkshitaB Jun 28, 2024
1302614
correct mup version
AkshitaB Jun 28, 2024
7fb6839
minor updates
AkshitaB Jun 29, 2024
96a61b1
fix pyproject for testing
AkshitaB Jun 29, 2024
59e9e67
update changelog
AkshitaB Jun 30, 2024
e3b7253
ensure config runs
AkshitaB Jun 30, 2024
26d1abc
orig_params with fsdp
AkshitaB Jul 1, 2024
dc53949
Merge branch 'akshitab-scale' of https://github.com/allenai/OLMo into…
AkshitaB Jul 1, 2024
288e4a2
remove unnecessary code
AkshitaB Jul 1, 2024
e8ea0fa
Merge branch 'akshitab-scale' of https://github.com/allenai/OLMo into…
AkshitaB Jul 1, 2024
f70ffa4
updated coord checks
AkshitaB Jul 1, 2024
b0dfe0a
simplified base config
AkshitaB Jul 1, 2024
561b813
use mup
AkshitaB Jul 1, 2024
16c935e
scripts for running check
AkshitaB Jul 1, 2024
e9193cb
rename runs
AkshitaB Jul 1, 2024
823ec37
fixes
AkshitaB Jul 1, 2024
1b83d56
no need with 1 node
AkshitaB Jul 1, 2024
fb76dd0
no weka
AkshitaB Jul 1, 2024
137deaf
install on the fly
AkshitaB Jul 1, 2024
35c9a88
bug fix
AkshitaB Jul 1, 2024
9ff73f6
debug
AkshitaB Jul 1, 2024
fe9fb4d
bug fix again
AkshitaB Jul 1, 2024
078c459
full runs
AkshitaB Jul 1, 2024
c7f6820
no warmup
AkshitaB Jul 1, 2024
016c3c5
linear warmup, same base shape
AkshitaB Jul 1, 2024
50073c0
use correct optimizer
AkshitaB Jul 1, 2024
3ef14dc
fix mup optimizer
AkshitaB Jul 1, 2024
9299f82
fix
AkshitaB Jul 1, 2024
1f15eeb
lower priority
AkshitaB Jul 1, 2024
96aed44
base shapes
AkshitaB Jul 1, 2024
15a2390
fix
AkshitaB Jul 1, 2024
ff8520b
save progress
AkshitaB Jul 1, 2024
019daf7
run on jupiter
AkshitaB Jul 1, 2024
59ef245
more cluster options
AkshitaB Jul 1, 2024
77d6d6d
priority
AkshitaB Jul 1, 2024
7cbf520
Revert "priority"
AkshitaB Jul 1, 2024
9b02f7d
ensure that correct lr is used
AkshitaB Jul 1, 2024
23a6986
simplify config further
AkshitaB Jul 1, 2024
f3218c1
update base shapes
AkshitaB Jul 1, 2024
e27d806
stability at init
AkshitaB Jul 1, 2024
09925d1
run sp baseline
AkshitaB Jul 2, 2024
a4a5b6e
fix readout init
AkshitaB Jul 2, 2024
5dc7210
Merge branch 'akshitab-scale' of https://github.com/allenai/OLMo into…
AkshitaB Jul 2, 2024
d4b395b
updated paths
AkshitaB Jul 2, 2024
1adea44
sp scripts
AkshitaB Jul 2, 2024
5cf94df
run for remaining
AkshitaB Jul 2, 2024
10d93cf
base config for experiments
AkshitaB Jul 9, 2024
9251161
Merge branch 'main' into akshitab-scale
AkshitaB Jul 18, 2024
4fe7b3f
Merge branch 'main' into akshitab-scale
AkshitaB Jul 18, 2024
a750ca8
black
AkshitaB Jul 18, 2024
61ffe77
7B width should be multiple of 64
AkshitaB Jul 18, 2024
91d32fb
move general functions to utils
AkshitaB Jul 18, 2024
23bbed8
mup ladder
AkshitaB Jul 18, 2024
73a6346
new scales
AkshitaB Jul 18, 2024
0b1994d
remove cluster
AkshitaB Jul 18, 2024
10dc91d
remove cluster
AkshitaB Jul 18, 2024
51810df
install mup
AkshitaB Jul 18, 2024
60b4bf5
fix device
AkshitaB Jul 18, 2024
a3b1e81
flash attn
AkshitaB Jul 18, 2024
7e7d596
debug
AkshitaB Jul 18, 2024
42c1d14
more debug
AkshitaB Jul 18, 2024
702afc8
test with more layers
AkshitaB Jul 19, 2024
b600ba9
less layers
AkshitaB Jul 19, 2024
9f539b4
update shape
AkshitaB Jul 19, 2024
c56aff1
fsdp
AkshitaB Jul 19, 2024
a19f5f6
device batch size
AkshitaB Jul 19, 2024
a5c6747
hp sweep
AkshitaB Jul 19, 2024
014606b
make it divisible
AkshitaB Jul 19, 2024
9a3c2df
run with my key
AkshitaB Jul 19, 2024
4d15c81
don't rescale when loading from checkpoint
AkshitaB Jul 19, 2024
908463c
more checkpoints for debugging
AkshitaB Jul 19, 2024
ba2c6bf
checkpoint less often
AkshitaB Jul 19, 2024
de2cf97
run with standard parametrization
AkshitaB Jul 22, 2024
d80702e
no weka, expand cluster options
AkshitaB Jul 22, 2024
b5d040b
increase priority
AkshitaB Jul 22, 2024
4f2dc47
revert
AkshitaB Jul 22, 2024
dacaa95
increase timeout
AkshitaB Jul 22, 2024
32bf956
make mup optional
AkshitaB Jul 22, 2024
447ab47
fix
AkshitaB Jul 22, 2024
b3d871c
bug fix
AkshitaB Jul 22, 2024
8cc17c9
test with cosine
AkshitaB Jul 22, 2024
57ef19f
fix
AkshitaB Jul 22, 2024
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -33,6 +33,10 @@ doc/_build/
.DS_Store

# mup artifacts
#*.bsh

# python
*.pyc
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added config options for `model.norm_after`, `model.scale_emb_init`, and `auxiliary_loss_multiplier` (used with zloss).
- Added scripts for running experiments on qk_norm, norm reordering, and zloss.
- Added `mup` implementation for OLMo.

### Changed
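For context on what the new `mup` option changes, the core of muP (maximal update parametrization) for Adam-family optimizers is a set of width-dependent scaling rules for hidden layers: learning rate scales like 1/width, initialization standard deviation like 1/sqrt(width), and the readout logits are multiplied by base_width/width. The sketch below is illustrative only; the function name and structure are hypothetical and not taken from the OLMo codebase.

```python
# Hypothetical sketch of muP scaling rules for hidden layers under an
# Adam-style optimizer. Not the OLMo implementation; names are illustrative.
import math

def mup_hidden_hparams(base_lr: float, base_std: float,
                       base_width: int, width: int) -> dict:
    """Transfer hidden-layer hyperparameters tuned at base_width to width.

    Under muP with Adam:
      - learning rate shrinks like 1/width,
      - init std shrinks like 1/sqrt(width),
      - readout (output) logits are scaled by base_width/width.
    """
    ratio = base_width / width
    return {
        "learning_rate": base_lr * ratio,         # Adam LR ~ 1/fan_in
        "init_std": base_std * math.sqrt(ratio),  # init variance ~ 1/fan_in
        "output_multiplier": ratio,               # MuReadout-style scaling
    }

# Example: transfer hyperparameters tuned at d_model=128 to d_model=4096.
hp = mup_hidden_hparams(base_lr=1.0e-3, base_std=0.02,
                        base_width=128, width=4096)
print(hp["learning_rate"])  # 3.125e-05
```

This is why the configs below fix a small base width (`d_model: 128`) and let the calling script toggle `use_mup`: hyperparameters found at the base width are meant to transfer to wider models without re-tuning.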
134 changes: 134 additions & 0 deletions configs/mup/base-olmo-cosine.yaml
@@ -0,0 +1,134 @@
run_name: base-olmo
seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: olmo-mup

model:
  use_mup: false  # set in the calling script
  mup_query_zero_init: true
  d_model: 128
  n_heads: 2
  n_layers: 2
  mlp_ratio: 1
  weight_tying: false
  alibi: false
  rope: true
  flash_attention: false
  attention_dropout: 0.0
  attention_layer_norm: false
  clip_qkv: null
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 0
  pad_token_id: 1
  init_device: cuda
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

ddp:
  grad_sync_mode: batch
  find_unused_params: false

compile: null

optimizer:
  name: adamw
  learning_rate: 1.0e-3
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: cosine_with_warmup
  # t_warmup: 3 * ${model.d_model} / 128  # assuming current model size (128 = 13M)
  alpha_f: 0.01
  warmup_min_lr: 0.0

tokenizer:
  identifier: tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json
  truncate_direction: right

save_folder: workspace/${run_name}  # doesn't matter since we'll upload to S3
remote_save_folder: s3://ai2-llm/checkpoints/olmo-mup/${run_name}
save_overwrite: false

# Unsharded checkpoints (for ddp)
save_interval_unsharded: 5000
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
stop_at: 10000
global_train_batch_size: 1024
device_train_microbatch_size: 16

precision: amp_bf16
distributed_strategy: ddp

gen1_gc_interval: 1

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 20

eval_interval: 5000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  - label: all-small-ppl-validation
    data:
      num_workers: 0
      drop_last: true
      datasets:
        wikitext_103-validation:
          - s3://ai2-llm/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy

  ##########################
  # Downstream evaluations #
  ##########################

  - label: hellaswag
    type: downstream

data:
  pad_direction: right
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
  paths:
    ######### NON WEB DATA #########
    # ~> WIKIPEDIA & WIKIBOOKS (3.689 GT), repeated twice to up-sample
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-0-00000.npy
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-1-00000.npy
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-0-00000.npy
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-1-00000.npy
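Many commits in this PR run "coord checks" against configs like the one above. The idea, sketched below in a dependency-free form, is that under muP the typical size of activation coordinates should stay O(1) as width grows. This toy version simulates a single linear layer y = Wx with W entries drawn with variance 1/fan_in and checks that the mean coordinate magnitude is roughly constant across widths; the actual check in the PR trains small OLMo models at several widths and plots per-layer activation scales over training steps.

```python
# Toy coordinate check: with 1/sqrt(fan_in) initialization, the typical
# coordinate size of y = Wx stays O(1) regardless of width. This is a
# simplified illustration, not the PR's coord-check script.
import random
import statistics

def mean_coord_size(width: int, seed: int = 0) -> float:
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    std = width ** -0.5  # init std ~ 1/sqrt(fan_in) for a square layer
    y = []
    for _ in range(width):
        row = [rng.gauss(0.0, std) for _ in range(width)]
        y.append(sum(w * xi for w, xi in zip(row, x)))
    return statistics.mean(abs(v) for v in y)

for width in (64, 256, 1024):
    # Mean |y_i| concentrates near sqrt(2/pi) ~ 0.8 at every width,
    # instead of growing or shrinking with width.
    print(width, round(mean_coord_size(width), 3))
```

A parametrization that fails this check (e.g. fixed init std independent of width) shows activation scales that drift systematically as width increases, which is exactly what the coord-check plots in the commit history are meant to catch.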
134 changes: 134 additions & 0 deletions configs/mup/base-olmo.yaml
@@ -0,0 +1,134 @@
run_name: base-olmo
seed: 6198
dry_run: false

wandb:
  name: ${run_name}
  project: olmo-mup

model:
  use_mup: false  # set in the calling script
  mup_query_zero_init: true
  d_model: 128
  n_heads: 2
  n_layers: 2
  mlp_ratio: 1
  weight_tying: false
  alibi: false
  rope: true
  flash_attention: false
  attention_dropout: 0.0
  attention_layer_norm: false
  clip_qkv: null
  include_bias: false
  block_type: sequential
  layer_norm_type: rms
  layer_norm_with_affine: true
  layer_norm_eps: 1e-6
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 4096
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 0
  pad_token_id: 1
  init_device: cuda
  init_fn: normal
  init_std: 0.02
  init_cutoff_factor: 3

ddp:
  grad_sync_mode: batch
  find_unused_params: false

compile: null

optimizer:
  name: adamw
  learning_rate: 1.0e-3
  weight_decay: 0.1
  eps: 1e-8
  decay_norm_and_bias: true
  decay_embeddings: false
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: constant  # linear_with_warmup
  # t_warmup: 10
  alpha_f: 0.1
  # warmup_min_lr: 0

tokenizer:
  identifier: tokenizers/allenai_gpt-neox-olmo-dolma-v1_5.json
  truncate_direction: right

save_folder: workspace/${run_name}  # doesn't matter since we'll upload to S3
remote_save_folder: s3://ai2-llm/checkpoints/olmo-mup/${run_name}
save_overwrite: false

# Unsharded checkpoints (for ddp)
save_interval_unsharded: 5000
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 1ep
stop_at: 10000
global_train_batch_size: 1024
device_train_microbatch_size: 16

precision: amp_bf16
distributed_strategy: ddp

gen1_gc_interval: 1

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 20

eval_interval: 5000
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
  - label: all-small-ppl-validation
    data:
      num_workers: 0
      drop_last: true
      datasets:
        wikitext_103-validation:
          - s3://ai2-llm/eval-data/perplexity/v3_small_gptneox20b/wikitext_103/val/part-0-00000.npy

  ##########################
  # Downstream evaluations #
  ##########################

  - label: hellaswag
    type: downstream

data:
  pad_direction: right
  num_workers: 32
  drop_last: true
  pin_memory: true
  prefetch_factor: 8
  persistent_workers: true
  timeout: 0
  instance_filter:
    repetition_max_period: 13
    repetition_min_period: 1
    repetition_max_count: 32
  paths:
    ######### NON WEB DATA #########
    # ~> WIKIPEDIA & WIKIBOOKS (3.689 GT), repeated twice to up-sample
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-0-00000.npy
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-1-00000.npy
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-0-00000.npy
    - s3://ai2-llm/preprocessed/olmo-mix/v1_6-decontaminated/wiki/gpt-neox-olmo-dolma-v1_5/part-1-00000.npy