@@ -327,6 +327,7 @@ def main_train(args: arg_util.Args):
     # build wandb logger
     if dist.is_master():
         wandb_utils.wandb.init(project=args.project_name, name=args.exp_name, config={})
+
     for ep in range(start_ep, args.ep):
         if ep % ep_lg == 0 or ep == start_ep:
             print(f'[PT info] from ep{start_ep} it{start_it}, acc_str: {acc_str}, diffs: {args.diffs}, =======> bed: {args.bed} <=======\n')
@@ -483,10 +484,15 @@ def train_one_ep(
         with maybe_record_function('before_train'):
             # [get data]
             inp, captions = data
-            tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length, padding='max_length', truncation=True, return_tensors='pt')  # todo: put this into dataset
+            tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length,
+                                    padding='max_length', truncation=True, return_tensors='pt')  # todo: put this into dataset
+            print("gongwb tokens:", tokens)
+
             input_ids = tokens.input_ids.cuda(non_blocking=True)
             mask = tokens.attention_mask.cuda(non_blocking=True)
+
             text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
+            print("gongwb text_features:", text_features)
 
             lens: List[int] = mask.sum(dim=-1).tolist()
             cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
@@ -521,7 +527,8 @@ def train_one_ep(
             step_cnt += int(stepping)
 
         with maybe_record_function('in_training'):
-            grad_norm_t, scale_log2_t = trainer.train_step(
+            #grad_norm_t, scale_log2_t =
+            trainer.train_step(
                 ep=ep, it=it, g_it=g_it, stepping=stepping, clip_decay_ratio=clip_decay_ratio,
                 metric_lg=me,
                 logging_params=stepping and step_cnt == 1 and (ep < 4 or ep in logging_params_milestone),
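For context on the caption-encoding hunk above: `text_tokenizer` and `text_encoder` appear to follow the Hugging Face `transformers` interface, so the step being reformatted (and instrumented with debug prints) can be exercised in isolation roughly as below. This is a minimal sketch, not this repo's code: the `google/flan-t5-base` checkpoint and the sample captions are placeholders, and the training loop additionally moves the tensors to GPU with `.cuda(non_blocking=True)`.

```python
# Standalone sketch of the caption -> text_features step from the diff above.
# Assumption: text_tokenizer / text_encoder are Hugging Face objects; the
# checkpoint name below is a placeholder, not necessarily what this repo uses.
import torch
from transformers import AutoTokenizer, T5EncoderModel

text_tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')            # placeholder checkpoint
text_encoder = T5EncoderModel.from_pretrained('google/flan-t5-base').eval()

captions = ['a photo of a cat', 'a red car parked on the street']                # placeholder captions
tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length,
                        padding='max_length', truncation=True, return_tensors='pt')

input_ids = tokens.input_ids          # .cuda(non_blocking=True) in the training loop
mask = tokens.attention_mask

with torch.no_grad():
    text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()

lens = mask.sum(dim=-1).tolist()      # per-caption token counts, as in the hunk
print(text_features.shape, lens)      # e.g. torch.Size([2, 512, 768]) for this placeholder model
```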