Skip to content
This repository was archived by the owner on Sep 23, 2025. It is now read-only.

Commit 89a5fc7

Browse files
author
Wu, Gangsheng
committed
print metrics by rank0 for train and evaluation in fine-tuning
1 parent 7b16ced commit 89a5fc7

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

llm_on_ray/finetune/finetune.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ def local_load(name, **load_config):
171171
dataset_dict = train_dataset.train_test_split(
172172
test_size=validation_split_percentage / 100
173173
)
174-
dataset_dict["validation"] = dataset_dict["test"]
174+
test_dataset = dataset_dict.pop("test")
175+
dataset_dict["validation"] = test_dataset
175176
return dataset_dict
176177

177178
return datasets.DatasetDict({"train": train_dataset})
@@ -188,7 +189,8 @@ def local_load(name, **load_config):
188189
dataset_dict = raw_dataset["train"].train_test_split(
189190
test_size=validation_split_percentage / 100
190191
)
191-
dataset_dict["validation"] = dataset_dict["test"]
192+
test_dataset = dataset_dict.pop("test")
193+
dataset_dict["validation"] = test_dataset
192194
return dataset_dict
193195

194196
return raw_dataset
@@ -367,10 +369,20 @@ def train_func(config: Dict[str, Any]):
367369

368370
training_args, trainer = get_trainer(config, model, tokenizer, tokenized_dataset, data_collator)
369371

370-
common.logger.info("train start")
371-
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
372-
trainer.save_model()
373-
common.logger.info("train finish")
372+
if training_args.do_train:
373+
common.logger.info("train start")
374+
result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
375+
trainer.save_model()
376+
metrics = result.metrics
377+
metrics["throughput"] = len(tokenized_dataset['train']) / metrics["train_runtime"]
378+
trainer.log_metrics("train", metrics)
379+
common.logger.info("train finish")
380+
381+
if training_args.do_eval:
382+
common.logger.info("eval start")
383+
metrics = trainer.evaluate()
384+
trainer.log_metrics("eval", metrics)
385+
common.logger.info("eval finish")
374386

375387

376388
def get_finetune_config():

0 commit comments

Comments
 (0)