 from tensorflow.python.ops.numpy_ops import np_config
 np_config.enable_numpy_behavior()
 from src.dataclass import Context
-
+tf.compat.v1.enable_eager_execution()
 def split_norm(inp: tf.Tensor) -> tf.Tensor:
     scale0, scale1, shift = tf.split(inp, 3, 1)
-    return tf.norm(tf.add(tf.matmul(scale0, scale1), shift))
+    return tf.norm(tf.add(tf.multiply(scale0, scale1), shift))

 def norm(out: tf.Tensor) -> tf.Tensor:
     out = out - out.mean(1, keepdim=True)
-    return tf.divide(out, (tf.add(tf.math.pow(tf.matmul(tf.norm(out, (2, 1)), out.size(1)), -0.5), 1e-5)))
+    return tf.divide(out, (tf.add(tf.math.pow(tf.multiply(tf.norm(out, (2, 1)), out.size(1)), -0.5), 1e-5)))


 def conv(inp: tf.Tensor, weight: tf.Tensor, groups: int, use_pad: bool) -> tf.Tensor:
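For reference, a minimal standalone sketch (not part of this commit; shapes are made up) of why the scale/shift computation wants an elementwise product rather than a matrix product:

    import tensorflow as tf

    # Hypothetical (batch, features) chunks standing in for the three tf.split outputs.
    scale0 = tf.random.normal((2, 4))
    scale1 = tf.random.normal((2, 4))
    shift = tf.random.normal((2, 4))

    out = tf.multiply(scale0, scale1) + shift  # per-element scaling, result stays (2, 4)
    # tf.matmul(scale0, scale1) would contract the feature axis instead (and here would
    # fail outright, since the inner dimensions 4 and 2 do not match).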
@@ -49,7 +49,7 @@ def orthonormal(inp: typing.Union[tf.Tensor, tf.Variable, typing.List[int]], gai
     a = g1.normal(flat_shape)
     u, _, v = tf.linalg.svd(a, full_matrices=False)

-    inp = tf.math.multiply((u if u.shape == flat_shape else v), gain)
+    inp = (u if u.shape == flat_shape else v) * gain
     if isinstance(original_input, list):
         return tf.Variable(inp)
     return original_input
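For context, a minimal sketch of SVD-based orthogonal initialization in the same spirit as the orthonormal helper above; the function name and shapes are hypothetical, and note that tf.linalg.svd returns (s, u, v), unlike torch.svd's (u, s, v):

    import tensorflow as tf

    def orthogonal_init(rows: int, cols: int, gain: float = 1.0) -> tf.Variable:
        # Gain-scaled orthogonal matrix built from the SVD of a random normal matrix.
        a = tf.random.normal((rows, cols))
        s, u, v = tf.linalg.svd(a, full_matrices=False)  # TensorFlow order: singular values first
        q = u if tuple(u.shape) == (rows, cols) else tf.transpose(v)
        return tf.Variable(q * gain)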
@@ -61,13 +61,12 @@ def moe(inp: tf.Tensor, w: typing.List[tf.Variable],
     gates = tf.nn.softmax(out, dim=1)
     one_hot = tf.one_hot(tf.argmax(out, dim=1), out.shape[1])
     gumbel = one_hot.transpose(1, 2) - gates.detach() + gates
-    one_hot = one_hot.to(dtype=tf.bool)
     inp_t = inp.transpose(1, 2)
     batch, features, sequence = inp.size()
     out = tf.zeros((batch * sequence, w[0].size(1)), dtype=inp.dtype)
     for expert, g, param in zip(one_hot.unbind(-1), gumbel.unbind(1), w):
         tmp = tf.boolean_mask(inp_t * g.unsqueeze(2), expert.unsqueeze(2)).view(-1, features).mm(param)
-        out = out.masked_scatter(expert.view(-1, 1), tmp)
+        out = out.boolean_mask(expert.view(-1, 1), tmp)
     loss = tf.math.reduce_sum(tf.math.reduce_mean(gates, dim=(0, 2)) * tf.math.reduce_mean(one_hot.float(), dim=(0, 1)))
     return loss, out.view(batch, sequence, -1).transpose(1, 2)
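The gumbel line above is the straight-through trick: the forward pass uses the hard one-hot routing while gradients flow through the soft gates. A minimal TensorFlow sketch (sizes are made up; tf.stop_gradient plays the role of .detach()):

    import tensorflow as tf

    logits = tf.Variable(tf.random.normal((4, 3)))  # (batch, experts), hypothetical sizes
    with tf.GradientTape() as tape:
        gates = tf.nn.softmax(logits, axis=-1)
        hard = tf.one_hot(tf.argmax(logits, axis=-1), depth=3, dtype=gates.dtype)
        routed = hard - tf.stop_gradient(gates) + gates  # value == hard, gradient == d(gates)
        loss = tf.reduce_sum(routed * tf.random.normal((4, 3)))
    print(tape.gradient(loss, logits))  # non-zero: gradients pass through the softmax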
@@ -133,56 +132,65 @@ def conv_weight(in_features: int, out_features: int, kernel_size: int, groups: i
     local_conv.build(in_features)
     return orthonormal(local_conv.kernel, 1 / std)

-class Trainer(tf.keras.Model):
+class Trainer(object):
     def __init__(self, model):
         super(Trainer, self).__init__()

         self.model = model
         self.optimizer = tf.keras.optimizers.Adam()

+    def softargmax(self, x, beta=1e10):
+        x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
+        return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=-1)
+
+    @tf.function
+    def _forward_backward(self, src_arr: tf.Tensor, tgt_arr: tf.Tensor) -> tf.Tensor:
+
+        with tf.GradientTape() as tape:
+            tape.watch(self.model.trainable_variables)
+            loss = 0
+            for (s, t), _ in zip(src_arr, tgt_arr):
+                src = s.squeeze(0)
+                tgt = t.squeeze(0)
+                model_out = self.model(np.array(src))
+                local_tgt = []
+                for row in tgt:
+                    lc = [0.0] * 256
+                    lc[np.argmax(row).numpy()] = 1.0
+                    local_tgt.append(lc)
+
+                # loss += tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(tf.reshape(model_out[:,1],(131072,1)),tf.reshape(local_tgt,(131072,1)))
+                loss += tf.keras.losses.CategoricalCrossentropy()(model_out[:, 1], local_tgt)

-    def _to_device_detach(self, inp: tf.Tensor) -> tf.Tensor:
-        return inp.to(device=self.ctx.model.device, non_blocking=True).detach()
-
-    def _forward_backward(self, src: tf.Tensor, tgt: tf.Tensor) -> tf.Tensor:
-        src = src.cpu().detach().numpy()
-        tgt = tgt.cpu().detach().numpy() / 128.0
-
-        with tf.GradientTape(persistent=True) as tape:
-            model_out = self.model(src)
-            model_out = model_out / 128.0
-            #print('model_out')
-            #print(model_out[:,:,1].shape)
-            #print(tgt.shape)
-            loss = tf.keras.losses.binary_crossentropy(model_out[:, :, 1], tgt)
         gradients = tape.gradient(loss, self.model.trainable_variables)
-        #print(loss)
-        #print(gradients)
+        print(loss)
+
+        print('-------------------------------')
+        #print(model_out)
+        print('gradients', gradients)

-        #model_out = self.model(src)
-        #model_out = np.array(model_out)/128.0
-        #print(tgt.shape)
-        #print(model_out.shape)
-        #loss = tf.keras.losses.binary_crossentropy(model_out, tgt)
-        return loss
+        return gradients

     def _clip_gradient(self, gradients):
         for p in gradients:
+            print(p)
             if type(p) == tf.IndexedSlices:
-                p_v = p.values
-                for row in p_v:
+                p = p.values
+                print(p)
+                for row in p:
                     g_norm = tf.clip_by_value(row, clip_value_min=self.model.ctx.optimizer.agc.zero_division_eps, clip_value_max=1000)
                     p_norm = tf.clip_by_value(row, clip_value_min=self.model.ctx.optimizer.agc.eps, clip_value_max=1000)
                     grad_scale = tf.clip_by_value((p_norm / g_norm * self.model.ctx.optimizer.agc.gradient_clipping), clip_value_min=-1000, clip_value_max=1)
                     row = row * grad_scale
-                    #print(row)
+                    print('row')
+                    print(row)

     def accumulated_step(self, dataloader) -> tf.Tensor:
-        with tf.GradientTape(persistent=True) as tape:
-            loss = sum(self._forward_backward(s.squeeze(0), t.squeeze(0)) for (s, t), _ in
-                       zip(dataloader, range(self.model.ctx.optimizer.gradient_accumulation_steps)))

-        gradients = tape.gradient(loss, self.model.trainable_variables)
+        gradients = self._forward_backward(dataloader, range(self.model.ctx.optimizer.gradient_accumulation_steps))
+        # add sum into the self.__forward_backward gradient descent
+        #sum(self._forward_backward(s.squeeze(0), t.squeeze(0)) for (s, t), _ in zip(dataloader, range(self.model.ctx.optimizer.gradient_accumulation_steps)))
+
+        print('gradients', gradients)
         #print( "Gradients")
         #print( gradients)
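The rewritten _forward_backward accumulates a loss under a single GradientTape and hands back the gradients. A minimal self-contained sketch of that pattern with stand-in model and data (not the repo's Context or model; note that Keras losses expect (y_true, y_pred) in that order):

    import tensorflow as tf

    model = tf.keras.Sequential([tf.keras.layers.Dense(256, activation="softmax")])
    x = tf.random.normal((8, 16))
    y = tf.one_hot(tf.random.uniform((8,), maxval=256, dtype=tf.int32), depth=256)

    with tf.GradientTape() as tape:
        probs = model(x)  # already probabilities, so from_logits stays False
        loss = tf.keras.losses.CategoricalCrossentropy()(y, probs)  # (y_true, y_pred)
    gradients = tape.gradient(loss, model.trainable_variables)
    # `gradients` lines up one-to-one with model.trainable_variables, ready for apply_gradients.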
@@ -207,7 +215,7 @@ def accumulated_step(self, dataloader) -> tf.Tensor:

             self.gradients_vars[i] = p
         self._clip_gradient(gradients)
-        return loss
+        return gradients

     def zero_grad(self):
         for p in self.model.parameters():
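For comparison with the per-row clip_by_value scheme in _clip_gradient, a minimal sketch of TensorFlow's built-in global-norm clipping (the gradient list is made up):

    import tensorflow as tf

    grads = [tf.random.normal((3, 3)), tf.random.normal((3,))]  # hypothetical gradient tensors
    clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)
    # Direction is preserved; only the combined magnitude is capped at clip_norm.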
@@ -235,7 +243,7 @@ def __init__(self, beta: float):

     def forward(self, inp: tf.Tensor):
         return tf.matmul(inp, self.beta)
-
+'''
 class LinearAttention(tf.keras.Model):
     def __init__(self, ctx: Context):
         super(LinearAttention, self).__init__()
@@ -248,25 +256,44 @@ def __init__(self, ctx: Context):
         pos_embd = tf.range(0, ctx.model.sequence_length)
         #self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device))

-        cell = LinearAttentionCell(self, ctx, 1)
-        self.stem = revlib.ReversibleSequential(*[c
-                                                  for i in range(1, 1 + ctx.model.depth)
-                                                  for c in [cell.momentum((1 - ctx.model.momentumnet_beta) /
-                                                                          ctx.model.momentumnet_beta ** i),
-                                                            MomentumNetSide(ctx.model.momentumnet_beta ** i)]],
-                                                target_device=ctx.model.device)
+        self.cell = LinearAttentionCell(self, ctx, 1)
+
         local_conv1d = tf.keras.layers.Conv1D(filters=ctx.dataset.classes, kernel_size=(1,))
         self.local_output = local_conv1d

     def call(self, inp: tf.Tensor,traing=None,mask=None):
-        return self.local_output(self.embedding(inp).transpose())
+        return self.local_output(self.cell(self.embedding(inp).transpose()))

     def reset_cache(self):
         for mod in self.stem.modules():
             if isinstance(mod, LinearAttentionCell):
                 mod.reset_cache()
+'''
+class LinearAttention(tf.keras.Model):
+    def __init__(self, ctx: Context):
+        super(LinearAttention, self).__init__()
+        self.ctx = ctx
+        self.embedding = tf.keras.layers.Embedding(ctx.dataset.classes, ctx.model.features * 2)
+        self.embedding.build(ctx.dataset.classes)
+
+        orthonormal(self.embedding.embeddings, ctx.model.input_embedding_std * 2 ** -0.5)
+
+        pos_embd = tf.range(0, ctx.model.sequence_length)
+        #self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device))
+
+        self.cell = LinearAttentionCell(self, ctx, 1)

+        local_conv1d = tf.keras.layers.Dense(256)  #tf.keras.layers.Conv1D(filters=ctx.dataset.classes, kernel_size=(1,))#tf.keras.layers.Dense(256)#
+        self.local_output = local_conv1d  #local_conv1d

+    def call(self, inp: tf.Tensor,traing=None,mask=None):
+        #return self.embedding(inp).transpose()
+        return self.local_output(self.cell(self.embedding(inp).transpose()))
+
+    def reset_cache(self):
+        for mod in self.stem.modules():
+            if isinstance(mod, LinearAttentionCell):
+                mod.reset_cache()

 class ParameterStore(object):
     """
     Something (likely deepspeed) changes all parameters in a ParameterList to [1] even though standalone parameters
@@ -326,7 +353,7 @@ def forward(self, inp: tf.Tensor) -> tf.Tensor:
             div = self.divisor()
         elif self.caching:
             self.idx += inp.size(2)
-            div = tf.Tensor([self.idx]).to(inp.device)
+            div = tf.Tensor([self.idx])
         else:
             self.idx = inp.size(2)
             div = tf.range(self.idx, device=inp.device).view(1, 1, -1) + 1
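As an aside, tf.Tensor is not meant to be constructed directly in TF2; a hedged sketch of the two divisor branches above using public ops (the index value is made up):

    import tensorflow as tf

    idx = 7
    div_cached = tf.constant([idx], dtype=tf.float32)  # cached branch: single running index
    div_fresh = tf.reshape(tf.range(idx, dtype=tf.float32), (1, 1, -1)) + 1  # fresh branch: positions 1..idx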