diff --git a/roll/pipeline/sft/sft_worker.py b/roll/pipeline/sft/sft_worker.py
index 8d63bd51..ae6bf63b 100644
--- a/roll/pipeline/sft/sft_worker.py
+++ b/roll/pipeline/sft/sft_worker.py
@@ -46,7 +46,13 @@ def val_step(self, data: DataProto):
         data = data.to(current_platform.device_type)
         data.meta_info["micro_batch_size"] = self.worker_config.infer_batch_size
         data = self.strategy.get_data_input(data)
-        metrics = self.strategy.forward_step(batch=data, forward_func=self.loss_func)
+
+        loss_func = self.loss_func
+        if self.worker_config.use_sequence_packing:
+            from roll.utils.sequence_packing import SequencePackingSFTLossWrapper
+            loss_func = SequencePackingSFTLossWrapper(self.strategy, loss_func)
+
+        metrics = self.strategy.forward_step(batch=data, forward_func=loss_func)
         output = DataProto(meta_info={"metrics": metrics}).to("cpu")
         return output
 
@@ -68,4 +74,4 @@ def do_checkpoint(self, global_step):
 
     def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
         labels = data.batch["labels"]
-        return self.strategy.op_compute_language_loss(output_tensor, labels)
\ No newline at end of file
+        return self.strategy.op_compute_language_loss(output_tensor, labels)
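Note for reviewers: the change wraps the loss function rather than branching inside `val_step`, so the strategy's `forward_step` stays agnostic of how batches are packed. Below is a minimal, self-contained sketch of how such a loss wrapper can work. The class name `PackedLossWrapperSketch`, the `cu_seqlens` field, and the plain-dict batch are illustrative assumptions for this sketch, not the actual interface of `roll.utils.sequence_packing.SequencePackingSFTLossWrapper`.

```python
import torch
import torch.nn.functional as F


class PackedLossWrapperSketch:
    """Wraps a (batch, logits) -> loss function for packed batches.

    Assumes the batch carries `cu_seqlens`: cumulative sequence lengths
    marking where each original sample starts inside the single packed
    row, in the style of FlashAttention's varlen interface. This field
    name is an assumption for illustration, not ROLL's actual schema.
    """

    def __init__(self, loss_func):
        self.loss_func = loss_func

    def __call__(self, batch: dict, logits: torch.Tensor) -> torch.Tensor:
        labels = batch["labels"].clone()      # (1, total_tokens)
        cu_seqlens = batch["cu_seqlens"]      # (num_sequences + 1,)

        # Mask the last token of every packed sample: its "next token"
        # belongs to a different sample and must not contribute to loss.
        labels[0, cu_seqlens[1:] - 1] = -100

        # Delegate to the wrapped loss over the boundary-masked labels.
        return self.loss_func({"labels": labels}, logits)


def shifted_ce_loss(batch: dict, logits: torch.Tensor) -> torch.Tensor:
    """Standard next-token cross-entropy, ignoring -100 labels."""
    shift_logits = logits[:, :-1].reshape(-1, logits.size(-1))
    shift_labels = batch["labels"][:, 1:].reshape(-1)
    return F.cross_entropy(shift_logits, shift_labels, ignore_index=-100)


# Toy usage: two samples of lengths 3 and 4 packed into one 7-token row.
logits = torch.randn(1, 7, 32)
batch = {
    "labels": torch.randint(0, 32, (1, 7)),
    "cu_seqlens": torch.tensor([0, 3, 7]),
}
print(PackedLossWrapperSketch(shifted_ce_loss)(batch, logits))
```

Because the wrapped callable has the same `(data, output_tensor)` shape as the original `loss_func`, it can be passed straight through as `forward_func`; only the loss computation changes, which is what makes the wrapper pattern in the diff attractive.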