Commit a94818f

Add files via upload
1 parent f2ec866 commit a94818f

File tree

1 file changed: +99 -0 lines changed


convert-weights/convert_gemma2.py

# Base config for the bert4keras3 model; the remaining fields are filled in
# from the keras_nlp Gemma2 backbone below.
config = {
    "type_vocab_size": 0,
}

import os

# Kaggle credentials and device/backend setup for downloading the preset.
os.environ['KAGGLE_USERNAME'] = 'passlin'
os.environ['KAGGLE_KEY'] = '848acf174f9616e8eeae54691d758f93'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
batch_size = 14  # unused in this conversion script
epochs = 15      # unused in this conversion script
os.environ["KERAS_BACKEND"] = "torch"

import keras
import keras_nlp
from bert4keras3.models import build_transformer_model
from bert4keras3.snippets import sequence_padding

keras.config.set_dtype_policy("bfloat16")
model_name = "gemma2_9b_en"

import torch

with torch.no_grad():
    os.makedirs(model_name, exist_ok=True)
    # Load the reference keras_nlp Gemma2 model from the Kaggle preset.
    gemma = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
    from keras import ops

    # Collect every backbone layer that owns weights, keyed by layer name.
    backbone = gemma.get_layer('gemma_backbone')
    layers_dict = {}
    for layer in backbone.layers:
        if layer.weights:
            layers_dict[layer.name] = layer

    # Translate the keras_nlp backbone config into the bert4keras3 config keys.
    gemma_config = backbone.get_config()
    config["vocab_size"] = gemma_config['vocabulary_size']
    config["num_hidden_layers"] = gemma_config['num_layers']
    config["query_head"] = gemma_config['num_query_heads']
    config["num_attention_heads"] = gemma_config['num_key_value_heads']
    config["hidden_size"] = gemma_config['hidden_dim']
    config["intermediate_size"] = gemma_config['intermediate_dim'] // 2
    config["attention_head_size"] = gemma_config['head_dim']
    config["attention_probs_dropout_prob"] = gemma_config['dropout']
    config["dropout_rate"] = gemma_config['dropout']
    config["use_post_ffw_norm"] = backbone.use_post_ffw_norm
    config["use_post_attention_norm"] = backbone.use_post_attention_norm
    # Gemma2-specific attention settings, read from the first decoder block.
    config["logit_soft_cap"] = layers_dict['decoder_block_0'].attention.logit_soft_cap
    config["use_sliding_window_attention"] = layers_dict['decoder_block_0'].attention.use_sliding_window_attention
    config["sliding_window_size"] = layers_dict['decoder_block_0'].attention.sliding_window_size
    config["query_head_dim_normalize"] = layers_dict['decoder_block_0'].attention.query_head_dim_normalize
    hidden_dim = config["hidden_size"]

    import json
    with open(model_name + '/config.json', 'w') as f:
        json.dump(config, f, indent=4, ensure_ascii=False)

    gemma.eval()
    self = build_transformer_model(
        config_path=model_name + '/config.json',
        model='gemma2',
        return_keras_model=False,
        with_lm='linear',
    )
    MyGemma = self.model
    MyGemma.eval()
    gemma.summary()

    def get_weights(layer, i):
        return layer.weights[i].value

    # Token embedding and final RMSNorm.
    embedding_weights = [get_weights(layers_dict['token_embedding'], 0)]
    MyGemma.get_layer('Embedding-Token').set_weights(embedding_weights)

    fln_weights = [get_weights(layers_dict['final_normalization'], 0)]
    MyGemma.get_layer('Output-Norm').set_weights(fln_weights)

    # Copy each decoder block: pre/post attention norms, the attention
    # projections, then pre/post FFN norms and the feed-forward weights.
    from tqdm import tqdm
    for i in tqdm(range(gemma_config['num_layers'])):
        block = layers_dict['decoder_block_' + str(i)]
        attention_name = 'Transformer-%d-MultiHeadSelfAttention' % i
        feed_forward_name = 'Transformer-%d-FeedForward' % i

        MyGemma.get_layer('%s-Norm' % attention_name).set_weights([block.weights[0]])
        MyGemma.get_layer('%s-Norm-post' % attention_name).set_weights([block.weights[1]])
        MyGemma.get_layer(attention_name).set_weights(block.weights[2:6])

        MyGemma.get_layer('%s-Norm' % feed_forward_name).set_weights([block.weights[6]])
        MyGemma.get_layer('%s-Norm-post' % feed_forward_name).set_weights([block.weights[7]])
        MyGemma.get_layer(feed_forward_name).set_weights(block.weights[8:])

    import numpy as np

    # Save the converted weights, then compare both models on random token ids.
    MyGemma.save_weights(model_name + '/model.weights.h5')
    print('saving')
    x = np.random.randint(1, 100000, [3, 128])
    x1 = gemma([np.ones_like(x), x])
    x2 = MyGemma(x)
    print('-' * 20)
    print(keras.ops.mean(keras.ops.abs(x1 - x2)))
    print(keras.ops.max(keras.ops.abs(x1 - x2)))
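
For reference, a minimal sketch of how the artifacts this script writes (gemma2_9b_en/config.json and gemma2_9b_en/model.weights.h5) might be loaded back for inference. The build_transformer_model call mirrors the usage in the script above; the use of Keras Model.load_weights on the saved .h5 file and the placeholder token ids are assumptions, not something the commit itself demonstrates.

# Loading sketch (assumption: same bert4keras3 version and torch backend as above).
import os
os.environ["KERAS_BACKEND"] = "torch"

import numpy as np
import keras
from bert4keras3.models import build_transformer_model

keras.config.set_dtype_policy("bfloat16")
model_name = "gemma2_9b_en"  # directory written by the conversion script

# Rebuild the Gemma2 graph from the exported config, then load the converted weights
# (assumed to round-trip via keras Model.load_weights).
wrapper = build_transformer_model(
    config_path=model_name + '/config.json',
    model='gemma2',
    return_keras_model=False,
    with_lm='linear',
)
MyGemma = wrapper.model
MyGemma.load_weights(model_name + '/model.weights.h5')

# Forward pass on placeholder token ids, mirroring the check at the end of the script.
x = np.random.randint(1, 100000, [3, 128])
logits = MyGemma(x)
print(keras.ops.shape(logits))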
