Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions torchtitan/models/llama3/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.
The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
and the first seqlen elements will be sliced, but dim must match x.
The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).
Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
Expand All @@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
"""
ndim = x.ndim
assert ndim > 1
batch_size = x.shape[0]
seqlen = x.shape[1]
freqs_cis = freqs_cis[0:seqlen]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


Expand Down Expand Up @@ -474,9 +473,18 @@ def get_attention_masks(
and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
)

def get_order_sensitive_buffers(
self,
batch_size: int,
seq_len: int,
) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})

def forward(
self,
tokens: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
input_batch: torch.Tensor | None = None,
):
Expand All @@ -501,7 +509,7 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks=attention_masks)
h = layer(h, freqs_cis, attention_masks=attention_masks)

h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
Expand Down
7 changes: 7 additions & 0 deletions torchtitan/protocols/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,10 @@ def get_attention_masks(
raise NotImplementedError(
"This model does not support attention masking/Flex Attention."
)

def get_order_sensitive_buffers(
self,
batch_size: int,
seq_len: int,
) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
return ({}, {})
12 changes: 11 additions & 1 deletion torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,12 @@ def forward_backward_step(
extra_inputs=extra_inputs,
)

# Get the order sensitive buffers
order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
inputs.size(0), inputs.size(1)
)
extra_args.update(order_sensitive_buffers[0])

# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
Expand Down Expand Up @@ -485,7 +491,11 @@ def forward_backward_step(
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs, **extra_inputs, **extra_args)
pred = model_parts[0](
inputs,
**extra_inputs,
**extra_args,
)
loss = self.loss_fn(pred, labels)
# need to free pred before bwd to avoid peaking memory
del pred
Expand Down
Loading