# Convert a keras_nlp Gemma-2 checkpoint into the bert4keras3 Gemma2 format:
# build a matching config.json, copy the weights layer by layer, save them as
# model.weights.h5, and compare the two models' outputs as a sanity check.
config = {
    "type_vocab_size": 0,  # Gemma has no token-type embeddings
}

import os
os.environ['KAGGLE_USERNAME'] = 'passlin'
os.environ['KAGGLE_KEY'] = '848acf174f9616e8eeae54691d758f93'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
batch_size = 14
epochs = 15
os.environ["KERAS_BACKEND"] = "torch"
import keras
import keras_nlp
from bert4keras3.models import build_transformer_model
from bert4keras3.snippets import sequence_padding
keras.config.set_dtype_policy("bfloat16")
model_name = "gemma2_9b_en"
import torch
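
# No gradients are needed while copying weights; from_preset() downloads the
# "gemma2_9b_en" preset (using the Kaggle credentials set above) if it is not
# already cached locally.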
with torch.no_grad():
    # Create the output directory for the converted config and weights.
    try:
        os.makedirs(model_name)
    except FileExistsError:
        pass
    gemma = keras_nlp.models.GemmaCausalLM.from_preset(model_name)
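    # Index every backbone layer that owns weights by its layer name, so the
    # decoder blocks can be looked up below as 'decoder_block_<i>'.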
    from keras import ops
    backbone = gemma.get_layer('gemma_backbone')
    layers_dict = {}
    for layer in backbone.layers:
        if layer.weights:
            layers_dict[layer.name] = layer
    gemma_config = backbone.get_config()
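    # Translate the keras_nlp backbone config into the key names expected by
    # bert4keras3. Gemma-2 uses grouped-query attention, so query heads and
    # key/value heads are recorded separately; keras_nlp's intermediate_dim
    # spans both branches of the gated feed-forward, hence the division by 2.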
    config["vocab_size"] = gemma_config['vocabulary_size']
    config["num_hidden_layers"] = gemma_config['num_layers']
    config["query_head"] = gemma_config['num_query_heads']
    config["num_attention_heads"] = gemma_config['num_key_value_heads']
    config["hidden_size"] = gemma_config['hidden_dim']
    config["intermediate_size"] = gemma_config['intermediate_dim'] // 2
    config["attention_head_size"] = gemma_config['head_dim']
    config["attention_probs_dropout_prob"] = gemma_config['dropout']
    config["dropout_rate"] = gemma_config['dropout']
    config["use_post_ffw_norm"] = backbone.use_post_ffw_norm
    config["use_post_attention_norm"] = backbone.use_post_attention_norm
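    # Gemma-2-specific attention settings (logit soft-capping, sliding-window
    # attention) are read off the first decoder block's attention layer.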
    config["logit_soft_cap"] = layers_dict['decoder_block_0'].attention.logit_soft_cap
    config["use_sliding_window_attention"] = layers_dict['decoder_block_0'].attention.use_sliding_window_attention
    config["sliding_window_size"] = layers_dict['decoder_block_0'].attention.sliding_window_size
    config["query_head_dim_normalize"] = layers_dict['decoder_block_0'].attention.query_head_dim_normalize
    hidden_dim = config["hidden_size"]
    import json
    with open(model_name + '/config.json', 'w') as f:
        json.dump(config, f, indent=4, ensure_ascii=False)
    gemma.eval()
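    # Build the bert4keras3 Gemma2 model (with a linear LM head) from the
    # config.json just written, and keep it in eval mode like the source model.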
    self = build_transformer_model(
        config_path=model_name + '/config.json',
        model='gemma2',
        return_keras_model=False,
        with_lm='linear',
    )
    MyGemma = self.model
    MyGemma.eval()
    gemma.summary()
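
    # Copy the weights over: token embedding and the final normalization layer
    # first, then each decoder block in turn.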
    def get_weights(layer, i):
        return layer.weights[i].value

    embedding_weights = [get_weights(layers_dict['token_embedding'], 0)]
    MyGemma.get_layer('Embedding-Token').set_weights(embedding_weights)

    fln_weights = [get_weights(layers_dict['final_normalization'], 0)]
    MyGemma.get_layer('Output-Norm').set_weights(fln_weights)
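
    # Each keras_nlp decoder block exposes its weights in a fixed order, which
    # the destination layer names below make explicit: [0] pre-attention norm,
    # [1] post-attention norm, [2:6] the four attention projection kernels,
    # [6] pre-FFN norm, [7] post-FFN norm, [8:] the feed-forward kernels.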
    from tqdm import tqdm
    for i in tqdm(range(gemma_config['num_layers'])):
        block = layers_dict['decoder_block_' + str(i)]
        attention_name = 'Transformer-%d-MultiHeadSelfAttention' % i
        feed_forward_name = 'Transformer-%d-FeedForward' % i

        MyGemma.get_layer('%s-Norm' % attention_name).set_weights([block.weights[0]])
        MyGemma.get_layer('%s-Norm-post' % attention_name).set_weights([block.weights[1]])
        MyGemma.get_layer(attention_name).set_weights(block.weights[2:6])

        MyGemma.get_layer('%s-Norm' % feed_forward_name).set_weights([block.weights[6]])
        MyGemma.get_layer('%s-Norm-post' % feed_forward_name).set_weights([block.weights[7]])
        MyGemma.get_layer(feed_forward_name).set_weights(block.weights[8:])

    import numpy as np
    MyGemma.save_weights(model_name + '/model.weights.h5')
    print('saving')
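    # Sanity check: run the original keras_nlp model (with an all-ones padding
    # mask) and the converted model on the same random token ids, then report
    # the mean and max absolute difference between their outputs.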
    x = np.random.randint(1, 100000, [3, 128])
    x1 = gemma([np.ones_like(x), x])
    x2 = MyGemma(x)
    print('-' * 20)
    print(keras.ops.mean(keras.ops.abs(x1 - x2)))
    print(keras.ops.max(keras.ops.abs(x1 - x2)))
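
# The converted checkpoint can later be rebuilt without keras_nlp, e.g.
# (sketch, reusing the bert4keras3 call above):
#   gemma2 = build_transformer_model(config_path=model_name + '/config.json',
#                                    model='gemma2', return_keras_model=False,
#                                    with_lm='linear').model
#   gemma2.load_weights(model_name + '/model.weights.h5')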