This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit fe598cb

Cherry pick fix transformers prediction (#717)
* Fix for prediction step when teacher model has more inputs than student.
* Updated signature of prediction_step method.
1 parent 32e6d84 commit fe598cb

File tree

1 file changed: 19 additions, 0 deletions
  • src/sparseml/transformers/sparsification


src/sparseml/transformers/sparsification/trainer.py

Lines changed: 19 additions & 0 deletions

```diff
@@ -365,6 +365,25 @@ def compute_loss(
 
         return (loss, student_outputs) if return_outputs else loss
 
+    def prediction_step(
+        self,
+        model: Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        prediction_loss_only: bool,
+        ignore_keys: Optional[List[str]] = None,
+    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """
+        Wraps the prediction step from the original trainer to remove any input entry
+        that should not be passed to model.
+        This situation may arise when distillation is used and the teacher model
+        contains more inputs than the student model.
+        """
+        self._check_super_defined("prediction_step")
+
+        inputs = {k: inputs[k] for k in inputs if k in self._model_signature_columns}
+
+        return super().prediction_step(model, inputs, prediction_loss_only, ignore_keys)
+
     def save_model(
         self,
         output_dir: Optional[str] = None,
```
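The core of the fix is the dict comprehension that drops any batch entry the student model's forward signature does not accept, so teacher-only inputs never reach the student during evaluation. A minimal standalone sketch of that filtering idea, using Python's inspect module (the TinyStudent class and signature_columns helper are hypothetical illustrations, not part of SparseML):

```python
import inspect


class TinyStudent:
    """Hypothetical student model whose forward only accepts two inputs."""

    def forward(self, input_ids, attention_mask=None):
        return sum(input_ids)


def signature_columns(model):
    # Parameter names that the model's forward() actually accepts
    return set(inspect.signature(model.forward).parameters)


# A distillation batch may carry extra entries consumed only by the teacher
batch = {
    "input_ids": [1, 2, 3],
    "attention_mask": [1, 1, 1],
    "teacher_layer_ids": [0, 5, 11],  # hypothetical teacher-only input
}

cols = signature_columns(TinyStudent())
filtered = {k: batch[k] for k in batch if k in cols}
print(sorted(filtered))  # prints ['attention_mask', 'input_ids']
```

SparseML's trainer caches these accepted names in `_model_signature_columns`; the comprehension in the diff is then exactly this kind of key filter applied to the evaluation inputs.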

0 commit comments
