System Info
transformers 5.13.0.dev0 (main, 9fd7b67; also reproduced on 36193bf)
torch 2.11.0, Python 3.13.5, macOS 15.6 arm64 (forced CPU via use_cpu=True; not device-specific — also reproduces on the default MPS device)
Who can help?
@muellerzr @SunMarc
Reproduction
import torch, datasets
from transformers import (AutoModelForCausalLM, AutoTokenizer, Trainer,
                          TrainingArguments, DataCollatorForLanguageModeling, TrainerCallback)

name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tok = AutoTokenizer.from_pretrained(name)
tok.pad_token = tok.eos_token
ds = datasets.load_dataset("Salesforce/wikitext", "wikitext-2-raw-v1", split="train[:200]")
ds = ds.filter(lambda e: len(e["text"].strip()) > 0)
ds = ds.map(lambda e: tok(e["text"], max_length=16, padding="max_length", truncation=True), batched=True)
coll = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)

def run(bs, gas):
    out = []  # grad_norm values collected from the trainer logs

    class CB(TrainerCallback):
        def on_log(self, a, s, c, logs=None, **k):
            if logs and "grad_norm" in logs:
                out.append(logs["grad_norm"])

    m = AutoModelForCausalLM.from_pretrained(name, dtype=torch.float32)
    args = TrainingArguments("/tmp/.ls_repro", per_device_train_batch_size=bs,
                             gradient_accumulation_steps=gas, label_smoothing_factor=0.1, max_steps=3,
                             logging_steps=1, learning_rate=1e-4, max_grad_norm=0.0, report_to=[], use_cpu=True)
    Trainer(m, args, train_dataset=ds, data_collator=coll, callbacks=[CB()]).train()
    return out

base = run(8, 1)  # effective batch 8, no accumulation
gas = run(1, 8)   # same effective batch 8, via 8 accumulation steps
print("base (bs=8, gas=1):", [round(x, 4) for x in base])
print("gas (bs=1, gas=8):", [round(x, 4) for x in gas])
print("ratio gas/base :", [round(g / b, 2) for b, g in zip(base, gas) if b])
Output on current main:
base (bs=8, gas=1): [0.3281, 0.3209, 0.3021]
gas (bs=1, gas=8): [2.9545, 3.0083, 2.6038]
ratio gas/base : [9.01, 9.37, 8.62]
Both runs have the same effective batch (8), so the grad norms should match. With label_smoothing_factor=0.1 the accumulation run is ~gradient_accumulation_steps× larger. Set label_smoothing_factor=0.0 and the ratio is ~1.0.
Expected behavior
Grad norms and logged loss should be independent of how the effective batch is split into accumulation steps — as they already are without label smoothing.
Root cause
With label smoothing the loss comes from LabelSmoother, not the model:
- _get_num_items_in_batch returns a non-None count because model_accepts_loss_kwargs is True, so training_step skips the 1 / gradient_accumulation_steps normalization (the if (not self.model_accepts_loss_kwargs or num_items_in_batch is None) and self.compute_loss_func is None: guard).
- But compute_loss calls self.label_smoother(outputs, labels, ...) without num_items_in_batch, and LabelSmoother.__call__ mean-reduces over the current micro-batch's active tokens (trainer_pt_utils.py).
- So neither path normalizes by the effective batch: the accumulated gradient becomes sum_i(S_i / N_i) instead of sum_i(S_i) / sum_i(N_i).
The grad-accum change that added num_items_in_batch (#34198) only touched trainer.py, not trainer_pt_utils.py, so LabelSmoother was never wired up. The same num_items_in_batch is already passed into the compute_loss_func branch right next to it.
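A rough numeric illustration of the two normalizations, reading S_i as the summed loss and N_i as the active-token count of micro-batch i (my reading of the notation). The numbers and function names below are made up for illustration and are not taken from the repro run:

```python
# Hypothetical per-micro-batch loss sums S_i and active-token counts N_i
# (illustrative values only, not measured from the repro).
loss_sums = [12.0, 9.0, 15.0, 10.0]
token_counts = [4, 3, 5, 4]


def accumulate_per_microbatch_mean(sums, counts):
    """What accumulation adds up when each micro-batch is mean-reduced: sum_i(S_i / N_i)."""
    return sum(s / n for s, n in zip(sums, counts))


def accumulate_global_mean(sums, counts):
    """Normalization over the whole effective batch: sum_i(S_i) / sum_i(N_i)."""
    return sum(sums) / sum(counts)


observed = accumulate_per_microbatch_mean(loss_sums, token_counts)
expected = accumulate_global_mean(loss_sums, token_counts)
print("per-micro-batch means, summed:", observed)
print("global mean:", expected)
print("ratio:", observed / expected)
```

When the per-token loss is similar across micro-batches, the ratio lands close to the number of micro-batches, which matches the ~gradient_accumulation_steps× inflation seen in the repro.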
Fix
Passing num_items_in_batch into LabelSmoother and using it as the denominator (falling back to the active-token count when None) restores the ratio:
gas (bs=1, gas=8): [0.3281, 0.3209, 0.3021] # ratio 1.0
This lowers the logged loss and grad magnitude by a factor of ~gradient_accumulation_steps for existing label-smoothing + accumulation runs (the old values were over-scaled). Single-device and default DDP token averaging become exact; legacy DataParallel keeps the same n_gpu approximation the compute_loss_func path already has.
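A minimal standalone sketch of the idea, not the actual patch: `smoothed_loss` is a hypothetical helper that mirrors the masking and reduction logic of LabelSmoother as I understand it, takes an optional `num_items_in_batch`, and skips the causal-LM label shift for brevity. The tensors are random and only there to compare the three reductions:

```python
import torch
import torch.nn.functional as F


def smoothed_loss(logits, labels, epsilon=0.1, ignore_index=-100, num_items_in_batch=None):
    """Label-smoothed NLL (hypothetical helper); divides by num_items_in_batch when given."""
    log_probs = -F.log_softmax(logits, dim=-1)      # (..., vocab)
    labels = labels.unsqueeze(-1)                   # (..., 1)
    padding_mask = labels.eq(ignore_index)
    safe_labels = labels.clamp(min=0)               # gather needs valid indices
    nll = log_probs.gather(dim=-1, index=safe_labels).masked_fill(padding_mask, 0.0)
    smooth = log_probs.sum(dim=-1, keepdim=True).masked_fill(padding_mask, 0.0)
    # Fall back to this micro-batch's active-token count when no global count is passed.
    denom = num_items_in_batch if num_items_in_batch is not None else (~padding_mask).sum()
    nll = nll.sum() / denom
    smooth = smooth.sum() / (denom * log_probs.shape[-1])
    return (1 - epsilon) * nll + epsilon * smooth


torch.manual_seed(0)
logits = torch.randn(8, 16, 32)           # 8 sequences, 16 tokens, vocab of 32
labels = torch.randint(0, 32, (8, 16))
labels[:, 12:] = -100                     # pretend the tail of each sequence is padding

full = smoothed_loss(logits, labels)      # one batch of 8, no accumulation
total_items = int((labels != -100).sum())

# 8 micro-batches of 1, each normalized by the global active-token count.
accumulated = sum(
    smoothed_loss(logits[i:i + 1], labels[i:i + 1], num_items_in_batch=total_items)
    for i in range(8)
)
# 8 micro-batches of 1, each mean-reduced on its own (the behavior described above).
legacy = sum(smoothed_loss(logits[i:i + 1], labels[i:i + 1]) for i in range(8))

print(float(full), float(accumulated), float(legacy))
```

With the global count as the denominator, the accumulated value matches the single-batch value. The per-micro-batch mean comes out 8× larger here because every micro-batch has the same number of active tokens.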
I have the fix and a red→green regression test ready — happy to open a PR.