@@ -171,7 +171,8 @@ def local_load(name, **load_config):
171171 dataset_dict = train_dataset .train_test_split (
172172 test_size = validation_split_percentage / 100
173173 )
174- dataset_dict ["validation" ] = dataset_dict ["test" ]
174+ test_dataset = dataset_dict .pop ("test" )
175+ dataset_dict ["validation" ] = test_dataset
175176 return dataset_dict
176177
177178 return datasets .DatasetDict ({"train" : train_dataset })
@@ -188,7 +189,8 @@ def local_load(name, **load_config):
188189 dataset_dict = raw_dataset ["train" ].train_test_split (
189190 test_size = validation_split_percentage / 100
190191 )
191- dataset_dict ["validation" ] = dataset_dict ["test" ]
192+ test_dataset = dataset_dict .pop ("test" )
193+ dataset_dict ["validation" ] = test_dataset
192194 return dataset_dict
193195
194196 return raw_dataset
@@ -367,10 +369,20 @@ def train_func(config: Dict[str, Any]):
367369
368370 training_args , trainer = get_trainer (config , model , tokenizer , tokenized_dataset , data_collator )
369371
370- common .logger .info ("train start" )
371- trainer .train (resume_from_checkpoint = training_args .resume_from_checkpoint )
372- trainer .save_model ()
373- common .logger .info ("train finish" )
372+ if training_args .do_train :
373+ common .logger .info ("train start" )
374+ result = trainer .train (resume_from_checkpoint = training_args .resume_from_checkpoint )
375+ trainer .save_model ()
376+ metrics = result .metrics
377+ metrics ["throughput" ] = len (tokenized_dataset ['train' ]) / metrics ["train_runtime" ]
378+ trainer .log_metrics ("train" , metrics )
379+ common .logger .info ("train finish" )
380+
381+ if training_args .do_eval :
382+ common .logger .info ("eval start" )
383+ metrics = trainer .evaluate ()
384+ trainer .log_metrics ("eval" , metrics )
385+ common .logger .info ("eval finish" )
374386
375387
376388def get_finetune_config ():
0 commit comments