From dffadc0fdcdd4a8ea73fe756f8a29ab1bc30dcbc Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Wed, 15 Oct 2025 10:46:18 -0700
Subject: [PATCH 1/2] [RFC] Lift freqs_cis as an input of models

freqs_cis is sensitive to the sequence order. CP load balancing will
shuffle the samples, so each batch will have a different order. As a
result, we have to lift these order-sensitive buffers to the inputs and
broadcast them along the batch dimension, so that PP can shard freqs_cis
without breaking correctness.

ghstack-source-id: 0612109ebd53a2220f703dc1ab815b8f380554b6
Pull Request resolved: https://github.com/pytorch/torchtitan/pull/1797

[ghstack-poisoned]
---
 torchtitan/models/llama3/model/model.py | 20 ++++++++++++++------
 torchtitan/protocols/model.py           |  7 +++++++
 torchtitan/train.py                     | 12 +++++++++++-
 3 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index 6f10719d12..0346499308 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).
 
     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
 
@@ -474,9 +473,18 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
@@ -501,7 +509,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index a713bec65b..5b633243e6 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -70,3 +70,10 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        return ({}, {})
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 1d5e0e500a..870e22b30c 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -432,6 +432,12 @@ def forward_backward_step(
                 extra_inputs=extra_inputs,
             )
 
+        # Get the order sensitive buffers
+        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+        extra_args.update(order_sensitive_buffers[0])
+
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
@@ -485,7 +491,11 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_inputs, **extra_args)
+                    pred = model_parts[0](
+                        inputs,
+                        **extra_inputs,
+                        **extra_args,
+                    )
                     loss = self.loss_fn(pred, labels)
                 # need to free pred before bwd to avoid peaking memory
                 del pred

From fd563603ee31fa991ac4467ba069c228ce2aaab3 Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Thu, 16 Oct 2025 00:09:34 -0700
Subject: [PATCH 2/2] Update on "[RFC] Lift freqs_cis as an input of models"

freqs_cis is sensitive to the sequence order. CP load balancing will
shuffle the samples, so each batch will have a different order. As a
result, we have to lift these order-sensitive buffers to the inputs and
broadcast them along the batch dimension, so that PP can shard freqs_cis
without breaking correctness.

Pull Request resolved: https://github.com/pytorch/torchtitan/pull/1797

[ghstack-poisoned]
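
Note: the snippet below is a minimal, self-contained sketch (not part of the patch above) that illustrates the shape contract these diffs introduce: get_order_sensitive_buffers slices the precomputed freqs_cis table to the current sequence length and repeats it along a new leading batch dimension, and the updated reshape_for_broadcast then expects a (batch_size, seqlen, dim) tensor. The toy rotary table and all tensor sizes here are illustrative assumptions, not values taken from torchtitan.

import torch

# Hypothetical toy sizes, for illustration only.
batch_size, max_seqlen, seq_len, n_heads, head_dim = 2, 64, 16, 4, 8

# A toy stand-in for the model's precomputed rotary table (in torchtitan this
# comes from precompute_freqs_cis): complex-valued, shape (max_seqlen, head_dim // 2).
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, head_dim, 2).float() / head_dim))
t = torch.arange(max_seqlen).float()
freqs_cis_table = torch.polar(torch.ones(max_seqlen, head_dim // 2), torch.outer(t, inv_freq))

# What get_order_sensitive_buffers does in the patch: slice the first seq_len
# positions and repeat along a new leading batch dimension, so each sample
# carries its own (reorderable) copy of the positional frequencies.
freqs_cis = freqs_cis_table[:seq_len].repeat(batch_size, 1, 1)
assert freqs_cis.shape == (batch_size, seq_len, head_dim // 2)

# Mirror of the updated reshape_for_broadcast logic: x plays the role of the
# complex-viewed queries/keys with shape (bs, seqlen, n_heads, head_dim // 2).
x = torch.randn(batch_size, seq_len, n_heads, head_dim // 2, dtype=torch.complex64)
assert freqs_cis.shape == (x.shape[0], x.shape[1], x.shape[-1])
shape = [d if i in (0, 1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
print(freqs_cis.view(*shape).shape)  # torch.Size([2, 16, 1, 4])

Repeating along the batch dimension materializes batch_size copies of the sliced table, but it is what lets CP shuffle samples within a batch and lets PP pass freqs_cis between stages like any other activation input.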