From d7e58875a31f5cf79651ebd4af681090ab7c67ca Mon Sep 17 00:00:00 2001
From: XChen-Zero
Date: Mon, 15 Dec 2025 14:29:51 +0800
Subject: [PATCH] Fix: apply sequence packing loss wrapper in SFT validation

---
 roll/pipeline/sft/sft_worker.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/roll/pipeline/sft/sft_worker.py b/roll/pipeline/sft/sft_worker.py
index 8d63bd51..ae6bf63b 100644
--- a/roll/pipeline/sft/sft_worker.py
+++ b/roll/pipeline/sft/sft_worker.py
@@ -46,7 +46,13 @@ def val_step(self, data: DataProto):
         data = data.to(current_platform.device_type)
         data.meta_info["micro_batch_size"] = self.worker_config.infer_batch_size
         data = self.strategy.get_data_input(data)
-        metrics = self.strategy.forward_step(batch=data, forward_func=self.loss_func)
+
+        loss_func = self.loss_func
+        if self.worker_config.use_sequence_packing:
+            from roll.utils.sequence_packing import SequencePackingSFTLossWrapper
+            loss_func = SequencePackingSFTLossWrapper(self.strategy, loss_func)
+
+        metrics = self.strategy.forward_step(batch=data, forward_func=loss_func)
         output = DataProto(meta_info={"metrics": metrics}).to("cpu")
         return output
 
@@ -68,4 +74,4 @@ def do_checkpoint(self, global_step):
 
     def loss_func(self, data: DataProto, output_tensor: torch.Tensor):
         labels = data.batch["labels"]
-        return self.strategy.op_compute_language_loss(output_tensor, labels)
\ No newline at end of file
+        return self.strategy.op_compute_language_loss(output_tensor, labels)
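
Note (not part of the patch): the wrapper itself lives in roll/utils/sequence_packing.py and is not included in this diff. As a rough sketch of the pattern the fix relies on, a forward_func wrapper compatible with strategy.forward_step could look like the code below. The class body, and in particular the unpack_sequences call, are assumptions for illustration, not the actual API.

# Illustrative sketch only -- NOT the real SequencePackingSFTLossWrapper
# from roll/utils/sequence_packing.py (its implementation is not in this
# patch). It shows the wrapper pattern: keep the forward_func signature
# (data, output_tensor) while restoring per-sample structure before the
# wrapped loss runs.
import torch


class SequencePackingSFTLossWrapper:
    def __init__(self, strategy, loss_func):
        self.strategy = strategy    # training/inference strategy backend
        self.loss_func = loss_func  # the original SFTWorker.loss_func

    def __call__(self, data, output_tensor: torch.Tensor):
        # Hypothetical step: undo sequence packing so labels line up with
        # the model outputs per sample; in this sketch the unpacking logic
        # is assumed to live on the strategy object.
        data, output_tensor = self.strategy.unpack_sequences(data, output_tensor)
        return self.loss_func(data, output_tensor)

Wrapping loss_func in val_step this way mirrors the packed-sequence handling on the training path, so that when use_sequence_packing is enabled the validation loss is computed against the same unpacked layout as the training loss.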