|
18 | 18 | #ifndef THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
19 | 19 | #define THIRD_PARTY_GEMMA_CPP_CONFIGS_H_
|
20 | 20 |
|
21 |
| -#include <cstddef> |
| 21 | +#include <stddef.h> |
22 | 22 |
|
23 | 23 | namespace gcpp {
|
24 | 24 |
|
25 | 25 | static constexpr size_t kSeqLen = 7168;
|
26 | 26 |
|
27 | 27 | struct ConfigGemma7B {
|
28 |
| - // NOLINTBEGIN(google3-readability-class-member-naming) |
29 |
| - static constexpr int seq_len = kSeqLen; |
30 |
| - static constexpr int vocab_size = 256128; |
31 |
| - static constexpr int n_layers = 28; |
32 |
| - static constexpr int dim_model = 3072; |
33 |
| - static constexpr int dim_ffw_hidden = 16 * 3072 / 2; // = 24576 |
34 |
| - static constexpr int n_heads = 16; |
35 |
| - static constexpr int n_kv_heads = 16; // standard MHA, no GQA or MQA |
36 |
| - static constexpr int dim_qkv = 256; // query size == key size == value size |
37 |
| - static constexpr int top_k = 1; |
38 |
| - // NOLINTEND(google3-readability-class-member-naming) |
| 28 | + static constexpr int kSeqLen = gcpp::kSeqLen; |
| 29 | + static constexpr int kVocabSize = 256128; |
| 30 | + static constexpr int kLayers = 28; |
| 31 | + static constexpr int kModelDim = 3072; |
| 32 | + static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576 |
| 33 | + static constexpr int kHeads = 16; |
| 34 | + static constexpr int kKVHeads = 16; // standard MHA, no GQA or MQA |
| 35 | + static constexpr int kQKVDim = 256; // query size == key size == value size |
| 36 | + static constexpr int kTopK = 1; |
39 | 37 | };
|
40 | 38 |
|
41 | 39 | struct ConfigGemma2B {
|
42 |
| - // NOLINTBEGIN(google3-readability-class-member-naming) |
43 |
| - static constexpr int seq_len = kSeqLen; |
44 |
| - static constexpr int vocab_size = 256128; |
45 |
| - static constexpr int n_layers = 18; |
46 |
| - static constexpr int dim_model = 2048; |
47 |
| - static constexpr int dim_ffw_hidden = 16 * 2048 / 2; // = 16384 |
48 |
| - static constexpr int n_heads = 8; |
49 |
| - static constexpr int n_kv_heads = 8; // TODO(austinvhuang): add MQA support |
50 |
| - static constexpr int dim_qkv = 256; // query size == key size == value size |
51 |
| - static constexpr int top_k = 1; |
52 |
| - // NOLINTEND(google3-readability-class-member-naming) |
| 40 | + static constexpr int kSeqLen = gcpp::kSeqLen; |
| 41 | + static constexpr int kVocabSize = 256128; |
| 42 | + static constexpr int kLayers = 18; |
| 43 | + static constexpr int kModelDim = 2048; |
| 44 | + static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384 |
| 45 | + static constexpr int kHeads = 8; |
| 46 | + static constexpr int kKVHeads = 8; // TODO(austinvhuang): add MQA support |
| 47 | + static constexpr int kQKVDim = 256; // query size == key size == value size |
| 48 | + static constexpr int kTopK = 1; |
53 | 49 | };
|
54 | 50 |
|
55 | 51 | } // namespace gcpp
|
|
0 commit comments