Skip to content

Commit db82f8b

Browse files
authored
deprecate input_batch in model inputs (#1896)
Now that we are creating FlexAttention block masks outside the model, the `input_batch` field in model inputs is no longer needed. This PR also renames `extra_args` to be `extra_kwargs`, as it's a dictionary of kwargs, technically.
1 parent 5fb7cc2 commit db82f8b

File tree

8 files changed

+12
-40
lines changed

8 files changed

+12
-40
lines changed

torchtitan/components/validate.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,12 +137,9 @@ def validate(
137137
inputs,
138138
target=targets,
139139
losses=losses,
140-
input_batch=inputs,
141140
)
142141
else:
143-
self.pp_schedule.eval(
144-
target=targets, losses=losses, input_batch=inputs
145-
)
142+
self.pp_schedule.eval(target=targets, losses=losses)
146143

147144
# accumulate losses across pipeline microbatches
148145
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU

torchtitan/experiments/forge/example_train.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,10 @@ def forward_backward_step(
158158
parallel_dims = self.parallel_dims
159159

160160
inputs = input_dict["input"]
161-
extra_args = {}
161+
extra_kwargs = {}
162162

163163
if getattr(self.model_args, "use_flex_attn", False):
164-
extra_args["attention_masks"] = model_parts[0].get_attention_masks(
164+
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
165165
input_batch=inputs,
166166
tokenizer=self.tokenizer,
167167
)
@@ -187,17 +187,15 @@ def forward_backward_step(
187187
if self.pp_has_first_stage:
188188
self.pp_schedule.step(
189189
inputs,
190-
**extra_args,
190+
**extra_kwargs,
191191
target=targets,
192192
losses=losses,
193-
input_batch=inputs,
194193
)
195194
else:
196195
self.pp_schedule.step(
197-
**extra_args,
196+
**extra_kwargs,
198197
target=targets,
199198
losses=losses,
200-
input_batch=inputs,
201199
)
202200

203201
# accumulate losses across pipeline microbatches
@@ -215,7 +213,7 @@ def forward_backward_step(
215213
with self.train_context(optional_context_parallel_ctx):
216214
assert len(model_parts) == 1
217215
with self.maybe_enable_amp:
218-
pred = model_parts[0](inputs, **extra_args)
216+
pred = model_parts[0](inputs, **extra_kwargs)
219217
loss = self.loss_fn(pred, labels)
220218
# need to free to before bwd to avoid peaking memory
221219
del pred

torchtitan/experiments/vlm/model/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,6 @@ def forward(
9696
grid_thw: torch.Tensor,
9797
special_tokens: SpecialTokens,
9898
attention_masks: AttentionMasksType | None = None,
99-
input_batch: torch.Tensor | None = None,
10099
):
101100
# passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
102101
h_BSD = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,6 @@ def forward(
411411
self,
412412
tokens: torch.Tensor,
413413
attention_masks: AttentionMasksType | None = None,
414-
input_batch: torch.Tensor | None = None,
415414
):
416415
"""
417416
Forward pass for the Transformer model.
@@ -421,10 +420,6 @@ def forward(
421420
If pipeline parallelism is enabled, this will be the input token indices
422421
for the ranks on the first pipeline stage. This will be the activation of the
423422
previous pipeline stage if the current rank is not on the first stage.
424-
input_batch (torch.Tensor): The input batch read from the dataloader.
425-
This will always be the input batch regardless of the pipeline stage.
426-
This field is required for non-first PP stages to perform document
427-
masking attention (to analyze the boundary of the document).
428423
429424
Returns:
430425
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).

torchtitan/models/llama3/model/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,6 @@ def forward(
478478
self,
479479
tokens: torch.Tensor,
480480
attention_masks: AttentionMasksType | None = None,
481-
input_batch: torch.Tensor | None = None,
482481
):
483482
"""
484483
Perform a forward pass through the Transformer model.
@@ -488,10 +487,6 @@ def forward(
488487
If pipeline parallelism is enabled, this will be the input token indices
489488
for the ranks on the first pipeline stage. This will be the activation of the
490489
previous pipeline stage if the current rank is not on the first stage.
491-
input_batch (torch.Tensor): The input batch read from the dataloader.
492-
This will always be the input batch regardless of the pipeline stage.
493-
This field is required for non-first PP stages to perform document
494-
masking attention (to analyze the boundary of the document).
495490
496491
Returns:
497492
torch.Tensor: Output logits after applying the Transformer model.

torchtitan/models/llama4/model/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,6 @@ def forward(
539539
self,
540540
tokens: torch.Tensor,
541541
attention_masks: AttentionMasksType | None = None,
542-
input_batch: torch.Tensor | None = None,
543542
):
544543
"""
545544
Perform a forward pass through the Transformer model.
@@ -549,10 +548,6 @@ def forward(
549548
If pipeline parallelism is enabled, this will be the input token indices
550549
for the ranks on the first pipeline stage. This will be the activation of the
551550
previous pipeline stage if the current rank is not on the first stage.
552-
input_batch (torch.Tensor): The input batch read from the dataloader.
553-
This will always be the input batch regardless of the pipeline stage.
554-
This field is required for non-first PP stages to perform document
555-
masking attention (to analyze the boundary of the document).
556551
557552
Returns:
558553
torch.Tensor: Output logits after applying the Transformer model.

torchtitan/models/qwen3/model/model.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -470,7 +470,6 @@ def forward(
470470
self,
471471
tokens: torch.Tensor,
472472
attention_masks: AttentionMasksType | None = None,
473-
input_batch: torch.Tensor | None = None,
474473
):
475474
"""
476475
Perform a forward pass through the Transformer model.
@@ -480,10 +479,6 @@ def forward(
480479
If pipeline parallelism is enabled, this will be the input token indices
481480
for the ranks on the first pipeline stage. This will be the activation of the
482481
previous pipeline stage if the current rank is not on the first stage.
483-
input_batch (torch.Tensor): The input batch read from the dataloader.
484-
This will always be the input batch regardless of the pipeline stage.
485-
This field is required for non-first PP stages to perform document
486-
masking attention (to analyze the boundary of the document).
487482
488483
Returns:
489484
torch.Tensor: Output logits after applying the Transformer model.

torchtitan/train.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -422,11 +422,11 @@ def forward_backward_step(
422422
extra_inputs = {k: v for k, v in input_dict.items() if k != "input"}
423423
# For arguments, like attention_masks, we have to put them in a separate
424424
# dict as extra_inputs are not forwarded to other stages in PP, but
425-
# extra_args are.
426-
extra_args = {}
425+
# extra_kwargs are.
426+
extra_kwargs = {}
427427

428428
if getattr(self.model_args, "use_flex_attn", False):
429-
extra_args["attention_masks"] = model_parts[0].get_attention_masks(
429+
extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks(
430430
input_batch=inputs,
431431
tokenizer=self.tokenizer,
432432
extra_inputs=extra_inputs,
@@ -457,17 +457,15 @@ def forward_backward_step(
457457
self.pp_schedule.step(
458458
inputs,
459459
**extra_inputs,
460-
**extra_args,
460+
**extra_kwargs,
461461
target=targets,
462462
losses=losses,
463-
input_batch=inputs,
464463
)
465464
else:
466465
self.pp_schedule.step(
467-
**extra_args,
466+
**extra_kwargs,
468467
target=targets,
469468
losses=losses,
470-
input_batch=inputs,
471469
)
472470

473471
# accumulate losses across pipeline microbatches
@@ -485,7 +483,7 @@ def forward_backward_step(
485483
with self.train_context(optional_context_parallel_ctx):
486484
assert len(model_parts) == 1
487485
with self.maybe_enable_amp:
488-
pred = model_parts[0](inputs, **extra_inputs, **extra_args)
486+
pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs)
489487
loss = self.loss_fn(pred, labels)
490488
# need to free pred before bwd to avoid peaking memory
491489
del pred

0 commit comments

Comments
 (0)