Skip to content

UpcycleLm script#987

Open
blahBlahhhJ wants to merge 3 commits intomainfrom
jasonw/upcycle
Open

UpcycleLm script#987
blahBlahhhJ wants to merge 3 commits intomainfrom
jasonw/upcycle

Conversation

@blahBlahhhJ
Copy link
Contributor

The script upcycles dense model(s) into a sparse (moe) model by replicating/concatenating mlp into moe layer and use the average of the rest.

An example config can be:

checkpoint_paths: ["tmp_ckpts/dense/step-50"]
data: !include data/openwebtext_source.yaml
dense_model: 
  type: llama
  hidden_dim: 32
  intermediate_dim: 128
  num_heads: 4
  num_kv_heads: 4
  num_layers: 2
sparse_model: 
  type: mixtral
  hidden_dim: 32
  intermediate_dim: 64
  num_heads: 4
  num_kv_heads: 4
  num_layers: 2
trainer:
  checkpointer:
    base_path: tmp_ckpts/upcycled/

@@ -0,0 +1,155 @@
import dataclasses
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc comment at the top about what this is for

Comment on lines +92 to +95
if avg_model is None:
avg_model = jax.tree.map(lambda x: x / len(config.checkpoint_paths), dense_model)
else:
avg_model = jax.tree.map(lambda x, avg: avg + x / len(config.checkpoint_paths), dense_model, avg_model)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably better to put each of these in jit

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants

Comments