Commit aaa5352

creative chaos edition
1 parent 871ce88

File tree: 2 files changed (+80, -48 lines)


main.py (+6, -1)
@@ -1,6 +1,6 @@
 import pathlib
 import typing
-
+import tensorflow as tf
 import argh
 import yaml

@@ -65,6 +65,11 @@ def inference(generated_tokens: int = 20, temp: float = 0.2, config_path: str =


 if __name__ == '__main__':
+    tf.debugging.experimental.enable_dump_debug_info(
+        "/tmp/tfdbg2_logdir",
+        tensor_debug_mode="FULL_HEALTH",
+        circular_buffer_size=-1)
+
     parser = argh.ArghParser()
     parser.add_commands([preprocess, train, inference])
     parser.dispatch()
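
Note on the added debugger call: tf.debugging.experimental.enable_dump_debug_info turns on TensorFlow Debugger V2 dumps for everything executed afterwards, and the dump directory is read back with TensorBoard's Debugger V2 plugin. A minimal standalone sketch of the same pattern (illustrative, not part of the commit):

import tensorflow as tf

# Record per-tensor health statistics (NaN/Inf checks, value ranges) for all ops.
tf.debugging.experimental.enable_dump_debug_info(
    "/tmp/tfdbg2_logdir",              # dump root, inspected later in TensorBoard
    tensor_debug_mode="FULL_HEALTH",   # full per-tensor health summary
    circular_buffer_size=-1)           # -1 disables the ring buffer and keeps all events

# ... run training or inference here ...
# Inspect afterwards with:  tensorboard --logdir /tmp/tfdbg2_logdir  (Debugger V2 tab)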

src/model.py (+74, -47)
@@ -8,14 +8,14 @@
 from tensorflow.python.ops.numpy_ops import np_config
 np_config.enable_numpy_behavior()
 from src.dataclass import Context
-
+tf.compat.v1.enable_eager_execution()
 def split_norm(inp: tf.Tensor) -> tf.Tensor:
     scale0, scale1, shift = tf.split(inp,3, 1)
-    return tf.norm(tf.add(tf.matmul(scale0 , scale1) , shift))
+    return tf.norm(tf.add(tf.multiply(scale0 , scale1) , shift))

 def norm(out: tf.Tensor) -> tf.Tensor:
     out = out - out.mean(1, keepdim=True)
-    return tf.divide(out, (tf.add(tf.math.pow(tf.matmul (tf.norm(out, (2, 1)) , out.size(1)) , -0.5 ), 1e-5)))
+    return tf.divide(out, (tf.add(tf.math.pow(tf.multiply (tf.norm(out, (2, 1)) , out.size(1)) , -0.5 ), 1e-5)))


 def conv(inp: tf.Tensor, weight: tf.Tensor, groups: int, use_pad: bool) -> tf.Tensor:
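
The substantive change in this hunk replaces tf.matmul with tf.multiply in split_norm and norm, i.e. a matrix product becomes an elementwise product. A small illustrative comparison (the constants are made up for the example):

import tensorflow as tf

a = tf.constant([[1., 2.], [3., 4.]])
b = tf.constant([[10., 20.], [30., 40.]])

print(tf.multiply(a, b))  # elementwise: [[10., 40.], [90., 160.]]
print(tf.matmul(a, b))    # matrix product: [[70., 100.], [150., 220.]]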
@@ -49,7 +49,7 @@ def orthonormal(inp: typing.Union[tf.Tensor, tf.Variable, typing.List[int]], gai
     a = g1.normal(flat_shape)
     u, _, v = tf.linalg.svd(a, full_matrices=False)

-    inp = tf.math.multiply((u if u.shape == flat_shape else v),gain)
+    inp = (u if u.shape == flat_shape else v)*gain
     if isinstance(original_input, list):
         return tf.Variable(inp)
     return original_input
@@ -61,13 +61,12 @@ def moe(inp: tf.Tensor, w: typing.List[tf.Variable],
     gates = tf.nn.softmax(out, dim=1)
     one_hot = tf.one_hot(tf.argmax(out, dim=1), out.shape[1])
     gumbel = one_hot.transpose(1, 2) - gates.detach() + gates
-    one_hot = one_hot.to(dtype=tf.bool)
     inp_t = inp.transpose(1, 2)
     batch, features, sequence = inp.size()
     out = tf.zeros((batch * sequence, w[0].size(1)), dtype=inp.dtype)
     for expert, g, param in zip(one_hot.unbind(-1), gumbel.unbind(1), w):
         tmp =tf.boolean_mask(inp_t * g.unsqueeze(2), expert.unsqueeze(2)).view(-1, features).mm(param)
-        out = out.masked_scatter(expert.view(-1, 1), tmp)
+        out = out.boolean_mask(expert.view(-1, 1), tmp)
     loss = tf.math.reduce_sum(tf.math.reduce_mean(gates, dim=(0, 2)) * tf.math.reduce_mean(one_hot.float(), dim=(0, 1)))
     return loss, out.view(batch, sequence, -1).transpose(1, 2)

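The new out.boolean_mask(...) call mirrors the removed PyTorch masked_scatter, but tf.Tensor does not expose boolean_mask as a method (tf.boolean_mask is a module-level function that selects, rather than writes, elements). A hedged sketch of a scatter-by-mask in plain TensorFlow, assuming out is [N, F], mask is a boolean [N] vector, and tmp holds one row per True entry:

# Illustrative only; not the commit's code.
idx = tf.where(mask)                              # indices of True entries, shape [num_true, 1]
out = tf.tensor_scatter_nd_update(out, idx, tmp)  # write tmp's rows into the masked positions
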
@@ -133,56 +132,65 @@ def conv_weight(in_features: int, out_features: int, kernel_size: int, groups: i
     local_conv.build(in_features)
     return orthonormal( local_conv.kernel, 1 / std)

-class Trainer(tf.keras.Model):
+class Trainer(object):
     def __init__(self,model):
         super(Trainer, self).__init__()

         self.model = model
         self.optimizer = tf.keras.optimizers.Adam()

+    def softargmax(self,x, beta=1e10):
+        x_range = tf.range(x.shape.as_list()[-1], dtype=x.dtype)
+        return tf.reduce_sum(tf.nn.softmax(x * beta) * x_range, axis=-1)
+    @tf.function
+    def _forward_backward(self, src_arr: tf.Tensor, tgt_arr: tf.Tensor) -> tf.Tensor:
+
+        with tf.GradientTape() as tape:
+            tape.watch(self.model.trainable_variables)
+            loss = 0
+            for (s, t), _ in zip(src_arr, tgt_arr):
+                src = s.squeeze(0)
+                tgt = t.squeeze(0)
+                model_out = self.model(np.array(src))
+                local_tgt = []
+                for row in tgt:
+                    lc = [0.0] * 256
+                    lc[np.argmax(row).numpy()] = 1.0
+                    local_tgt.append(lc)
+
+                # loss += tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)(tf.reshape(model_out[:,1],(131072,1)),tf.reshape(local_tgt,(131072,1)))
+                loss += tf.keras.losses.CategoricalCrossentropy()(model_out[:,1],local_tgt)

-    def _to_device_detach(self, inp: tf.Tensor) -> tf.Tensor:
-        return inp.to(device=self.ctx.model.device, non_blocking=True).detach()
-
-    def _forward_backward(self, src: tf.Tensor, tgt: tf.Tensor) -> tf.Tensor:
-        src = src.cpu().detach().numpy()
-        tgt = tgt.cpu().detach().numpy()/128.0
-
-        with tf.GradientTape(persistent=True) as tape:
-            model_out = self.model(src)
-            model_out = model_out/128.0
-            #print('model_out')
-            #print(model_out[:,:,1].shape)
-            #print(tgt.shape)
-            loss = tf.keras.losses.binary_crossentropy(model_out[:,:,1], tgt)
         gradients = tape.gradient(loss, self.model.trainable_variables)
-        #print(loss)
-        #print(gradients)
+        print(loss)
+
+        print('-------------------------------')
+        #print(model_out)
+        print('gradients',gradients)

-        #model_out = self.model(src)
-        #model_out = np.array(model_out)/128.0
-        #print(tgt.shape)
-        #print(model_out.shape)
-        #loss = tf.keras.losses.binary_crossentropy(model_out, tgt)
-        return loss
+        return gradients

     def _clip_gradient(self,gradients):
         for p in gradients:
+            print(p)
             if type(p) == tf.IndexedSlices:
-                p_v = p.values
-                for row in p_v:
+                p = p.values
+            print(p)
+            for row in p:
                 g_norm = tf.clip_by_value(row,clip_value_min=self.model.ctx.optimizer.agc.zero_division_eps,clip_value_max=1000)
                 p_norm = tf.clip_by_value(row,clip_value_min=self.model.ctx.optimizer.agc.eps,clip_value_max=1000)
                 grad_scale = tf.clip_by_value((p_norm / g_norm * self.model.ctx.optimizer.agc.gradient_clipping),clip_value_min=-1000,clip_value_max=1)
                 row = row* grad_scale
-                #print(row)
+                print('row')
+                print(row)

     def accumulated_step(self, dataloader) -> tf.Tensor:
-        with tf.GradientTape(persistent=True) as tape:
-            loss = sum(self._forward_backward(s.squeeze(0), t.squeeze(0)) for (s, t), _ in
-                       zip(dataloader, range(self.model.ctx.optimizer.gradient_accumulation_steps)))

-        gradients = tape.gradient(loss, self.model.trainable_variables)
+        gradients = self._forward_backward(dataloader, range(self.model.ctx.optimizer.gradient_accumulation_steps))
+        # add sum into the self.__forward_backward gradient decent
+        #sum(self._forward_backward(s.squeeze(0), t.squeeze(0)) for (s, t), _ in zip(dataloader, range(self.model.ctx.optimizer.gradient_accumulation_steps)))
+
+        print('gradients', gradients)
         #print( "Gradients")
         #print( gradients)

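The rewritten _forward_backward builds one-hot targets, scores them with CategoricalCrossentropy inside a GradientTape, and now returns gradients rather than a loss. For reference, a minimal sketch of the tape-plus-Keras-loss step this follows (model, optimizer, x and y_onehot are placeholder names, not identifiers from this repository); note that Keras loss objects are called as loss_fn(y_true, y_pred), targets first:

import tensorflow as tf

def train_step(model, optimizer, x, y_onehot):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)                     # forward pass under the tape
        loss = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True)(y_onehot, logits)              # targets first, predictions second
    grads = tape.gradient(loss, model.trainable_variables)   # backward pass
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
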
@@ -207,7 +215,7 @@ def accumulated_step(self, dataloader) -> tf.Tensor:

             self.gradients_vars[i] = p
         self._clip_gradient(gradients)
-        return loss
+        return gradients

     def zero_grad(self):
         for p in self.model.parameters():
@@ -235,7 +243,7 @@ def __init__(self, beta: float):

     def forward(self, inp: tf.Tensor):
         return tf.matmul(inp , self.beta)
-
+'''
 class LinearAttention(tf.keras.Model):
     def __init__(self, ctx: Context):
         super(LinearAttention, self).__init__()
@@ -248,25 +256,44 @@ def __init__(self, ctx: Context):
         pos_embd = tf.range(0, ctx.model.sequence_length)
         #self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device))

-        cell = LinearAttentionCell(self, ctx, 1)
-        self.stem = revlib.ReversibleSequential(*[c
-                                                  for i in range(1, 1 + ctx.model.depth)
-                                                  for c in [cell.momentum((1 - ctx.model.momentumnet_beta) /
-                                                                          ctx.model.momentumnet_beta ** i),
-                                                            MomentumNetSide(ctx.model.momentumnet_beta ** i)]],
-                                                 target_device=ctx.model.device)
+        self.cell = LinearAttentionCell(self, ctx, 1)
+
         local_conv1d = tf.keras.layers.Conv1D(filters=ctx.dataset.classes, kernel_size=(1,))
         self.local_output = local_conv1d

     def call(self, inp: tf.Tensor,traing=None,mask=None):
-        return self.local_output(self.embedding(inp).transpose())
+        return self.local_output(self.cell(self.embedding(inp).transpose()))

     def reset_cache(self):
         for mod in self.stem.modules():
             if isinstance(mod, LinearAttentionCell):
                 mod.reset_cache()
+'''
+class LinearAttention(tf.keras.Model):
+    def __init__(self, ctx: Context):
+        super(LinearAttention, self).__init__()
+        self.ctx = ctx
+        self.embedding =tf.keras.layers.Embedding(ctx.dataset.classes, ctx.model.features * 2)
+        self.embedding.build(ctx.dataset.classes)
+
+        orthonormal(self.embedding.embeddings, ctx.model.input_embedding_std * 2 ** -0.5)
+
+        pos_embd = tf.range(0, ctx.model.sequence_length)
+        #self.register_buffer("divisor", pos_embd.unsqueeze(0).to(torch.float).to(ctx.model.device))
+
+        self.cell = LinearAttentionCell(self, ctx, 1)

+        local_conv1d = tf.keras.layers.Dense(256)#tf.keras.layers.Conv1D(filters=ctx.dataset.classes, kernel_size=(1,))#tf.keras.layers.Dense(256)#
+        self.local_output = local_conv1d#local_conv1d

+    def call(self, inp: tf.Tensor,traing=None,mask=None):
+        #return self.embedding(inp).transpose()
+        return self.local_output(self.cell(self.embedding(inp).transpose()))
+
+    def reset_cache(self):
+        for mod in self.stem.modules():
+            if isinstance(mod, LinearAttentionCell):
+                mod.reset_cache()
 class ParameterStore(object):
     """
     Something (likely deepspeed) changes all parameters in a ParameterList to [1] even though standalone parameters
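
The reinstated LinearAttention swaps the Conv1D(kernel_size=(1,)) output head for Dense(256); on a (batch, sequence, features) tensor both act as pointwise projections of the feature axis, so the output shapes agree. A quick illustrative check (shapes chosen arbitrarily):

import tensorflow as tf

x = tf.random.normal([2, 16, 64])                           # (batch, sequence, features)
dense = tf.keras.layers.Dense(256)
conv = tf.keras.layers.Conv1D(filters=256, kernel_size=1)

print(dense(x).shape, conv(x).shape)                        # both (2, 16, 256)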
@@ -326,7 +353,7 @@ def forward(self, inp: tf.Tensor) -> tf.Tensor:
             div = self.divisor()
         elif self.caching:
             self.idx += inp.size(2)
-            div = tf.Tensor([self.idx]).to(inp.device)
+            div = tf.Tensor([self.idx])
         else:
             self.idx = inp.size(2)
             div = tf.range(self.idx, device=inp.device).view(1, 1, -1) + 1
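
The last change only drops the torch-style .to(inp.device) call; tf.Tensor is still being called like a constructor, which is not how concrete tensors are normally created. A hedged one-line alternative using the standard constructor (an assumption about intent, not code from the commit):

div = tf.constant([self.idx])  # tf.constant builds a concrete tensor from a Python value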
