Commit

Add option for tied embeddings (#38)
* Add tied embedding option

* update readme with new tied embedding flag
Quentin-Anthony authored Apr 19, 2024
1 parent 33a1a82 commit 86ff050
Showing 2 changed files with 13 additions and 4 deletions.
calc/README.md (7 changes: 5 additions & 2 deletions)
@@ -56,13 +56,14 @@ options:
 ```
 Example with Fairseq-MoE 15B: python calc_transformer_params.py -l 12 -hs 768 --moe -e 512
 Example with GPT-3 175B: python calc_transformer_params.py -l 96 -hs 12288
-usage: calc_transformer_params.py [-h] [--vocab-size VOCAB_SIZE] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS] [--expert-interval EXPERT_INTERVAL]
-                                  [--topk TOPK] [--ffn-expansion-factor FFN_EXPANSION_FACTOR]
+usage: calc_transformer_params.py [-h] [--vocab-size VOCAB_SIZE] [--tied-embeddings] [--hidden-size HIDDEN_SIZE] [--sequence-length SEQUENCE_LENGTH] [--num-layers NUM_LAYERS] [--moe] [--num-experts NUM_EXPERTS]
+                                  [--expert-interval EXPERT_INTERVAL] [--topk TOPK] [--ffn-expansion-factor FFN_EXPANSION_FACTOR] [--kv-size-ratio KV_SIZE_RATIO]
 options:
   -h, --help            show this help message and exit
   --vocab-size VOCAB_SIZE, -v VOCAB_SIZE
                         Size of the vocab
+  --tied-embeddings     Whether embeddings are tied (shared between input and output)
   --hidden-size HIDDEN_SIZE, -hs HIDDEN_SIZE
                         Dimension of the model's hidden size
   --sequence-length SEQUENCE_LENGTH, -s SEQUENCE_LENGTH
@@ -77,6 +78,8 @@ options:
   --topk TOPK, -t TOPK  Top k routing for MoE
   --ffn-expansion-factor FFN_EXPANSION_FACTOR, -ff FFN_EXPANSION_FACTOR
                         How much the MLP hidden size expands
+  --kv-size-ratio KV_SIZE_RATIO, -kv KV_SIZE_RATIO
+                        What fraction of num. query heads is num. key/value heads
 ```
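For context (not part of the commit itself), the new flag slots into the README's existing example invocations; a hypothetical run for the GPT-3-sized example with shared input/output embeddings would be:

```
python calc_transformer_params.py -l 96 -hs 12288 --tied-embeddings
```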
calc/calc_transformer_params.py (10 changes: 8 additions & 2 deletions)
@@ -19,6 +19,9 @@ def config_parser():
                         type=int,
                         default=51200,
                         help='Size of the vocab')
+    parser.add_argument("--tied-embeddings",
+                        action="store_true",
+                        help='Whether embeddings are tied (shared between input and output)')
     parser.add_argument("--hidden-size", "-hs",
                         type=int,
                         default=6144,
@@ -58,8 +61,11 @@ def config_parser():

 # calculates the params of a model given their hparams
 def calc_params(args):
-    # Assumes that the embedding and unembedding are tied
-    embedding_params = args.hidden_size * args.vocab_size
+    # Calculate embedding and unembedding params. If tied, re-use the same params
+    if args.tied_embeddings:
+        embedding_params = args.hidden_size * args.vocab_size
+    else:
+        embedding_params = 2 * args.hidden_size * args.vocab_size
     position_embedding_params = args.hidden_size * args.sequence_length
     # Each QKVO matrix is (hxh)
     # Unless using GQA/MQA which makes K/V smaller
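To make the arithmetic of the change concrete, a minimal standalone sketch of the embedding term follows (not part of the commit; the vocab and hidden sizes are assumed example values, not the script's defaults):

```
# Sketch of the embedding-parameter term with and without tied embeddings.
# vocab_size and hidden_size are illustrative assumptions.
vocab_size = 50257
hidden_size = 12288

tied = hidden_size * vocab_size        # one shared embedding/unembedding matrix
untied = 2 * hidden_size * vocab_size  # separate input embedding and output unembedding

print(f"tied:   {tied:,} params")    # 617,558,016
print(f"untied: {untied:,} params")  # 1,235,116,032
```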
