From 9461315a9023d038f89135c46cd7b7994b8b2950 Mon Sep 17 00:00:00 2001 From: Rohan Pandey Date: Tue, 12 Aug 2025 16:48:02 -0700 Subject: [PATCH 01/18] gptoss experimental support --- torchtitan/experiments/gpt_oss/__init__.py | 58 ++ torchtitan/experiments/gpt_oss/model/args.py | 146 +++++ torchtitan/experiments/gpt_oss/model/model.py | 476 ++++++++++++++++ torchtitan/experiments/gpt_oss/model/moe.py | 283 ++++++++++ .../gpt_oss/scripts/compare_hf_to_tt.py | 405 ++++++++++++++ .../gpt_oss/scripts/convert_gptoss.py | 513 ++++++++++++++++++ .../gpt_oss/train_configs/debug_model.toml | 73 +++ .../gpt_oss/train_configs/gpt_oss_120b.toml | 70 +++ .../gpt_oss/train_configs/gpt_oss_20b.toml | 70 +++ torchtitan/models/attention.py | 83 ++- 10 files changed, 2168 insertions(+), 9 deletions(-) create mode 100644 torchtitan/experiments/gpt_oss/__init__.py create mode 100644 torchtitan/experiments/gpt_oss/model/args.py create mode 100644 torchtitan/experiments/gpt_oss/model/model.py create mode 100644 torchtitan/experiments/gpt_oss/model/moe.py create mode 100644 torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py create mode 100644 torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py create mode 100644 torchtitan/experiments/gpt_oss/train_configs/debug_model.toml create mode 100644 torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml create mode 100644 torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py new file mode 100644 index 0000000000..67a74c124d --- /dev/null +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
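+
+# Usage sketch (illustrative only, not part of the registration itself): once this TrainSpec
+# is registered, a gpt_oss run can typically be launched through the standard torchtitan
+# launcher with one of the TOML configs added in this patch, e.g.
+#   CONFIG_FILE=torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml ./run_train.sh
+# where `model.name = "gpt_oss"` and `model.flavor` in the TOML select an entry from
+# `gptoss_configs` below. The launcher name and flag spelling are assumptions, not verified here.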
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers + +from torchtitan.protocols.train_spec import register_train_spec, TrainSpec + +from .infra.parallelize import parallelize_gptoss +from .model.args import GptOssModelArgs +from .model.model import GptOssModel + +__all__ = [ + "parallelize_gptoss", + "GptOssModelArgs", + "GptOssModel", + "gptoss_configs", +] + + +gptoss_configs = { + "debugmodel": GptOssModelArgs( + hidden_size=256, + num_hidden_layers=4, + ), + "20B": GptOssModelArgs( + num_hidden_layers=24, + num_local_experts=32, + ), + "120B": GptOssModelArgs( + num_hidden_layers=36, + num_local_experts=128, + ), +} + + +register_train_spec( + TrainSpec( + name="gpt_oss", + cls=GptOssModel, + config=gptoss_configs, + parallelize_fn=parallelize_gptoss, + pipelining_fn=None, + build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) +) diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py new file mode 100644 index 0000000000..227f24ddc2 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -0,0 +1,146 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +from dataclasses import dataclass +from typing import Literal + +from torch import nn + +from torchtitan.components.tokenizer import Tokenizer +from torchtitan.config_manager import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools.logging import logger + +# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION + + +# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +@dataclass +class GptOssModelArgs(BaseModelArgs): + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + n_routed_experts (int): Number of routed experts for MoE layers. + n_shared_experts (int): Number of shared experts for MoE layers. + n_activated_experts (int): Number of activated experts in MoE layers. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. + route_scale (float): Scaling factor for routing scores. + use_grouped_mm (bool): Whether to use grouped matrix multiplication for MoE layers. 
+ load_balance_coeff (float | None): Auxiliary-Loss-Free Load balancing coefficient for MoE layers. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + """ + + max_batch_size: int = 8 + max_seq_len: int = 131072 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 201088 + hidden_size: int = 2880 + num_hidden_layers: int = 24 + norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE + num_local_experts: int = 32 + num_experts_per_tok: int = 4 + use_grouped_mm: bool = True + # Multi-Head Latent Attention (MLA) + head_dim: int = 64 + num_attention_heads: int = 64 + num_key_value_heads: int = 8 + sliding_window: int = 128 + use_flex_attn: bool = True + attn_mask_type: str = "causal" + # yarn + original_seq_len: int = 4096 + rope_theta: float = 150000.0 + rope_factor: float = 32 + beta_fast: int = 32 + beta_slow: int = 1 + + def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: + """ + Update the model_config config from the given job config. + """ + # self.vocab_size = tokenizer.vocab_size # TODO: add tiktokenizer support? + self.max_seq_len = job_config.training.seq_len + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + """ + Adopted from llama4 implementation. + """ + nparams_embedding = 0 + nparams_moe_router = 0 + nparams_shared_expert = 0 + nparams_experts = 0 + nparams_dense = 0 + + for name, p in model.named_parameters(): + if "embedding" in name: + nparams_embedding += p.numel() + nparams_dense += p.numel() + elif "moe.shared_expert" in name: + nparams_shared_expert += p.numel() + elif "moe.router" in name: + nparams_moe_router += p.numel() + elif "moe.experts" in name: + nparams_experts += p.numel() + else: + nparams_dense += p.numel() + + nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts + nparams = nparams_dense + nparams_sparse + nparams_sparse_active = ( + nparams_moe_router + + nparams_shared_expert + + nparams_experts * self.num_experts_per_tok // self.num_local_experts + ) + + logger.info( + f"Total parameter count: dense {nparams_dense:,}, " + f"sparse {nparams_sparse:,}, active {nparams_dense + nparams_sparse_active:,}" + ) + + l, h, q, t = ( + self.num_hidden_layers, + self.num_attention_heads, + self.hidden_size // self.num_attention_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + num_flops_per_token = ( + 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + + 12 * l * h * q * t + ) + + return nparams, num_flops_per_token diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py new file mode 100644 index 0000000000..835816c2a0 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -0,0 +1,476 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Tuple + +import torch +from torch import nn +from torchtitan.models.attention import build_attention +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import GptOssModelArgs +from .moe import MoE + +# TODO: may be able to remove this once parallelized properly +def convert_submodules_to_bf16( + module: nn.Module, + exclude_names: tuple[str, ...] = ("freqs_cis", "attention_norm", "ffn_norm", "norm"), + attr_opt_out: str = "no_bf16", # if a submodule sets `self.no_bf16 = True`, it will be skipped + ) -> None: + """ + Recursively convert parameters & buffers of submodules to bfloat16, + except: + - modules whose *qualified name* ends with any of `exclude_names` + - modules with attribute `{attr_opt_out} == True` + Conversion is *shallow per-module* so exclusions are respected even deep in the tree. + """ + + def should_skip(qname: str, mod: nn.Module) -> bool: + base = qname.rsplit(".", 1)[-1] # local (leaf) name + if base in exclude_names: + return True + if getattr(mod, attr_opt_out, False): + return True + return False + + def convert_shallow(mod: nn.Module): + # convert parameters owned by this module + for _, p in mod.named_parameters(recurse=False): + if p.is_floating_point(): + p.data = p.data.to(torch.bfloat16) + # convert buffers owned by this module + for _, b in mod.named_buffers(recurse=False): + # keep non-float buffers (e.g., ints, bool masks) as-is + if torch.is_floating_point(b): + b.data = b.data.to(torch.bfloat16) + + # walk the module tree; convert only *this* module's tensors if not skipped + for qname, mod in module.named_modules(): + # skip the root container name (empty) check gracefully + local_name = qname.rsplit(".", 1)[-1] if qname else "" + if local_name and should_skip(qname, mod): + continue + convert_shallow(mod) + +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (GptOssModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + original_seq_len = args.original_seq_len + + # YaRN default m-scale (attention_factor). Matches HF when attention_factor is None. + mscale = 0.1 * math.log(factor) + 1.0 + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. 
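+
+        Derivation sketch (follows from the frequency definition used in
+        ``precompute_freqs_cis`` below): the RoPE frequency of dimension-pair index d is
+        base ** (-2d / dim), so that pair completes
+        max_seq_len * base ** (-2d / dim) / (2 * pi) full rotations over the original
+        context. Setting this equal to ``num_rotations`` and solving for d gives
+        d = dim * log(max_seq_len / (2 * pi * num_rotations)) / (2 * log(base)),
+        which is exactly the value returned.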
+ + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> Tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. + """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. + if seqlen > original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + # Create position indices + t = torch.arange(seqlen) + + # Outer product: [positions] Ɨ [frequencies] + freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) + freqs_cis = torch.polar(torch.full_like(freqs, fill_value=mscale), freqs) + + return freqs_cis + + +def apply_rotary_emb_inner(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + +def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): + """ + HF-style inputs (half-split last dim) -> interleave -> Torchtitan complex RoPE -> de-interleave. 
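+    For head_dim D, the HF layout stores the rotary halves as
+    [x_0, ..., x_{D/2-1}, y_0, ..., y_{D/2-1}], while the complex path in
+    ``apply_rotary_emb_inner`` expects interleaved pairs (x_0, y_0, x_1, y_1, ...);
+    this wrapper converts to the interleaved layout, applies RoPE, and converts back.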
+ Shapes: + q, k: [B, T, H, D] with D even (HF half-split: first D/2 real, last D/2 imag) + freqs_cis: complex, last dim == D/2. Typically [T, D/2] or [1, T, D/2]. + Returns: + q_out, k_out in HF half-split layout (same shape as q, k). + """ + B, T, H, D = q.shape + assert D % 2 == 0, "head_dim must be even for RoPE" + rot = D // 2 + assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2" + freqs_cis = freqs_cis[:T, :] + + # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...) + # q_i, k_i: [B, T, H, D] + q_i = torch.empty_like(q) + k_i = torch.empty_like(k) + q_i[..., 0::2] = q[..., :rot] + q_i[..., 1::2] = q[..., rot:] + k_i[..., 0::2] = k[..., :rot] + k_i[..., 1::2] = k[..., rot:] + + # --- Torchtitan default complex apply (expects interleaved last dim) + # freqs_cis will be reshaped inside to [1, T, 1, rot] + q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis) # uses TT's complex path + k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis) + + # --- inline: interleaved -> HF half-split + q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1) + k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1) + return q_out, k_out + +# Torch Attention backup implementation (for debugging and sampling) from HuggingFace +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + +def eager_attention_forward( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + sinks: torch.Tensor, + attention_mask: torch.Tensor, + scaling: float, + dropout: float = 0.0, + num_key_value_groups: int = 1, + **kwargs, +): + key_states = repeat_kv(key, num_key_value_groups) + value_states = repeat_kv(value, num_key_value_groups) + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk] + # Convert boolean "allowed" -> additive mask + if attention_mask.dtype == torch.bool: + m = attention_mask + add_mask = torch.zeros_like(m, dtype=attn_weights.dtype) + add_mask = add_mask.masked_fill(~m, -float("inf")) + else: + add_mask = attention_mask.to(attn_weights.dtype) + + # Truncate to current key length and add (broadcasts if needed) + add_mask = add_mask[..., : key_states.shape[-2]] + attn_weights = attn_weights + add_mask + + sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. 
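+    # Subtracting the per-row max is the standard log-sum-exp stabilization: softmax is
+    # invariant to adding a constant to every logit in a row, so the resulting probabilities
+    # (including the sink column dropped below) are unchanged; only the numeric range shrinks.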
+ + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = nn.functional.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value_states) + return attn_output + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. + """ + + def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = False): + super().__init__() + + self.sliding_window = model_args.sliding_window if use_sliding_attention else None + self.head_dim = model_args.head_dim + + self.wq = nn.Linear( + model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True + ) + self.wk = nn.Linear( + model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + ) + self.wv = nn.Linear( + model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + ) + self.wo = nn.Linear( + model_args.num_attention_heads * model_args.head_dim, model_args.hidden_size, bias=True + ) + self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads)) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.attn = build_attention(True, model_args.attn_mask_type) + else: + # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed + self.attn = eager_attention_forward + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + hidden_shape = (bsz, seqlen, -1, self.head_dim) + + q = self.wq(x).view(hidden_shape) + k = self.wk(x).view(hidden_shape) + v = self.wv(x).view(hidden_shape) + + q, k = apply_rotary_emb(q, k, freqs_cis) + + q = q.transpose(1, 2).contiguous() + k = k.transpose(1, 2).contiguous() + v = v.transpose(1, 2).contiguous() + + if self.use_flex_attn: + output = self.attn(q, k, v, self.sinks, sliding_window=self.sliding_window, enable_gqa=True) + else: + output = self.attn( + q, k, v, self.sinks, + attention_mask=self.sliding_window_causal(seqlen, x.device), + scaling=self.head_dim**-0.5, + dropout=0.0, + num_key_value_groups=8, + ) + output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) + + # Reshape and project output + output = output.reshape(bsz, seqlen, -1).contiguous() # (bsz, seqlen, n_heads * v_head_dim) + output = self.wo(output) # (bsz, seqlen, dim) + return output + + def init_weights(self, init_std: float): + linear_list = [ + self.wq, + self.wk, + self.wv, + ] + + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + # TODO: statically init the mask using train.seq_len + def sliding_window_causal(self, seqlen, device): + i = torch.arange(seqlen, device=device) + q_idx = i[:, None] + kv_idx = i[None, :] + + causal_mask = q_idx >= kv_idx + if self.sliding_window is None: + return causal_mask + window_mask = q_idx - kv_idx <= self.sliding_window + return causal_mask & window_mask + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. 
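+
+    In this implementation, even-indexed layers use sliding-window attention and odd-indexed
+    layers use full attention (``use_sliding_attention = layer_id % 2 == 0`` in ``__init__``);
+    the feed-forward sublayer is always the MoE block.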
+ """ + + def __init__(self, layer_id: int, model_args: GptOssModelArgs): + + super().__init__() + use_sliding_attention = layer_id % 2 == 0 + self.attention = Attention(model_args, use_sliding_attention=use_sliding_attention) + self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + + self.moe = MoE(model_args) + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.moe(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.moe.init_weights(self.weight_init_std, buffer_device) + + +class GptOssModel(nn.Module, ModelProtocol): + """ + GPT-OSS Transformer model with attention and feed-forward layers. + """ + + def __init__(self, model_args: GptOssModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=True + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.num_hidden_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16) + convert_submodules_to_bf16(self.layers[str(layer_id)]) + + self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.output = nn.Linear( + model_args.hidden_size, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + self.init_weights() + convert_submodules_to_bf16(self) + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.hidden_size**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def forward(self, tokens: torch.Tensor): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
+ """ + h = self.tok_embeddings(tokens) + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis) + h = self.norm(h) + output = self.output(h) + return output diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py new file mode 100644 index 0000000000..667d329b93 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -0,0 +1,283 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn.functional as F +from torch import nn +from torchtitan.experiments.llama4.infra.expert_parallel import expert_parallel + +from .args import GptOssModelArgs + +def swiglu(x, alpha: float = 1.702, limit: float = 7.0): + x_glu, x_linear = x[..., ::2], x[..., 1::2] + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + +class GroupedExperts(nn.Module): + def __init__( + self, + dim: int, + num_experts: int, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.use_grouped_mm = use_grouped_mm + + self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2))) + self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2))) + self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim))) + self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.use_grouped_mm: + return GroupedExperts._run_experts_grouped_mm( + self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert + ) + else: + return GroupedExperts._run_experts_for_loop( + self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert + ) + + # TODO: keeping this for-loop implementation for comparison + # and readability, may remove later + # @expert_parallel + @staticmethod + def _run_experts_for_loop( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = torch.matmul(x_expert, mlp1_weight[expert_idx]) + mlp1_bias[expert_idx] + h = swiglu(h) + h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx] + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + else: + # x shape (num_experts, tokens_per_expert, dim) + h = torch.bmm(x, mlp1_weight) + 
mlp1_bias.unsqueeze(1) + h = swiglu(h) + out = torch.bmm(h, mlp2_weight) + mlp2_bias.unsqueeze(1) + + return out + + # @expert_parallel # TODO: e-sharding currently breaks shapes + @staticmethod + def _run_experts_grouped_mm( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if num_tokens_per_expert is not None: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + # grouped mm between a 2D tensor and a 3D tensor + assert x.dim() == 2 + else: + offsets = None + # fall back to regular bmm between 3D tensors + assert x.dim() == 3 + + num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) + + h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) + h += mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + h = swiglu(h) + h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) + h += mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + + return h + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) + + def extra_repr(self): + return (f"num_experts={self.num_experts}, " + f"use_grouped_mm={self.use_grouped_mm}, " + f"mlp1_weight={tuple(self.mlp1_weight.shape)}, " + f"mlp1_bias={tuple(self.mlp1_bias.shape)}, " + f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " + f"mlp2_bias={tuple(self.mlp2_bias.shape)}") + +class TokenChoiceTopKRouter(nn.Module): + """This class implements token-choice routing. In token-choice top-K routing, each token is + routed to top K experts based on the router scores. + + Args: + dim (int): Dimension of the input. + num_experts (int): Number of experts in each moe layer. + top_k (int): Number of experts each token will be routed to in token-choice routing. + """ + + def __init__( + self, + dim: int, + num_experts: int, + top_k: int, + ): + super().__init__() + + self.dim = dim + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(self.dim, self.num_experts, bias=True) + + def forward( + self, x: torch.Tensor, expert_bias: torch.Tensor | None = None + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + TODO: We haven't implement the group-based routing (node limit routing), + and currently EP is not supporting node limit routing yet. + + Args: + x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. + + Returns: + routed_input (torch.Tensor): + Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. + token_indices (torch.Tensor): + Token indices for routed_input with shape ``(bs*slen*top_k,)``. + num_tokens_per_expert (torch.Tensor): + Number of tokens assigned to each expert with shape ``(num_experts,)``. 
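+
+        Scoring note (describes the code below): top-k is taken over the raw router logits,
+        and softmax is then applied only over those k selected logits, so each token's routing
+        weights are normalized over its chosen experts rather than over all experts.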
+ """ + # scores shape (bs*slen, num_experts) + router_logits = self.gate(x) + + # top scores shape (bs*slen, top_k) + top_scores, selected_experts_indices = torch.topk( + router_logits, k=self.top_k, dim=1 + ) + + top_scores = F.softmax(top_scores, dim=1) + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + # Reorder the token indices to match the order of the experts + # token_indices_experts_sorted shape (bs*slen*top_k,) + token_indices_experts_sorted = torch.argsort( + selected_experts_indices.view(-1), stable=True + ) + + # reorder the scores to match the order of the token indices + top_scores = top_scores.view(-1)[token_indices_experts_sorted] + token_indices_experts_sorted = token_indices_experts_sorted // self.top_k + + return top_scores, token_indices_experts_sorted, num_tokens_per_expert + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) + + +class MoE(nn.Module): + def __init__(self, model_args: GptOssModelArgs): + + super().__init__() + dim = model_args.hidden_size + + num_experts = model_args.num_local_experts + top_k = model_args.num_experts_per_tok + + self.experts = GroupedExperts( + dim=dim, + num_experts=num_experts, + use_grouped_mm=model_args.use_grouped_mm, + ) + self.router = TokenChoiceTopKRouter( + dim=dim, + num_experts=num_experts, + top_k=top_k, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. + + Returns: + out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. + """ + bs, slen, dim = x.shape + + # top_scores and selected_indices shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + token_indices, + num_tokens_per_expert, + ) = self.router(x.reshape(bs * slen, dim)) + + # shape (bs*slen*top_k, dim) + token_indices = token_indices.reshape(-1, 1).expand(-1, dim) + + # shape (bs*slen*top_k, dim) + routed_input = torch.gather( + x.view(-1, dim), + dim=0, + index=token_indices, + ) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( + x.dtype + ) + + out = torch.zeros_like(x.reshape(bs * slen, dim)) + + # Accumulate multiple expert results becase each token can be routed to multiple experts + out = out.scatter_add(dim=0, index=token_indices, src=routed_output) + out = out.reshape(bs, slen, dim) + return out + + def init_weights( + self, + init_std: float, + buffer_device: torch.device, + ): + self.experts.init_weights(init_std) + self.router.init_weights(init_std) diff --git a/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py new file mode 100644 index 0000000000..dbbb880af5 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py @@ -0,0 +1,405 @@ +""" +Compare logits and generations of GPT-OSS implemented in TorchTitan and HuggingFace. +This requires at least a 2xH100. 
+ +First ensure you convert the HF model to a TorchTitan DCP checkpoint: +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ + +Then you can run a comparison like this: +uv run torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py \ + --tt_config torchtitan/models/gpt_oss/train_configs/gpt_oss_20b.toml \ + --tt_checkpoint_path gptoss_dcp/ \ + --hf_model_path openai/gpt-oss-20b \ + --prompt "Once upon a time, in a land far away," \ + --temperature 0.8 \ + --max_new_tokens 256 \ + --batch_size 1 \ + --out +""" + +import json +import os +import sys +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Sequence, Tuple, NamedTuple + +import torch +import torch.nn as nn +import torch.distributed.checkpoint as dcp +import tyro +from transformers import AutoModelForCausalLM, AutoTokenizer + +from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.utils import device_module, device_type +from torchtitan.components.metrics import build_device_memory_monitor +from torchtitan.config_manager import ConfigManager +from torchtitan.protocols.train_spec import get_train_spec +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torch.distributed import DeviceMesh +from torch.distributed.elastic.multiprocessing.errors import record + +# -------- Torchtitan Sampling Utils -------- +def multinomial_sample_one( + probs: torch.Tensor, rng: Optional[torch.Generator] = None +) -> torch.Tensor: + q = torch.empty_like(probs).exponential_(1, generator=rng) + return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) + + +def logits_to_probs( + logits: torch.Tensor, + temperature: float = 1.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + logits = logits / max(temperature, 1e-5) + + if top_k is not None: + v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) + pivot = v.select(dim=-1, index=-1).unsqueeze(-1) + logits = torch.where(logits < pivot, -float("Inf"), logits) + + probs = torch.nn.functional.softmax(logits, dim=-1) + return probs + + +def generate_next_token( + model, + x: torch.Tensor, + *, + temperature: float = 1.0, + top_k: Optional[int] = None, + rng: Optional[torch.Generator] = None, +) -> torch.Tensor: + logits = model(x) # (B, T, vocab_size) + probs = logits_to_probs(logits[:, -1, :], temperature, top_k) + next_token = multinomial_sample_one(probs, rng=rng) + return next_token + + +@torch.no_grad() +def tt_generate_text( + model, + input_ids: torch.Tensor, + *, + max_new_tokens: int, + temperature: float = 1.0, + top_k: Optional[int] = None, + seed: Optional[int] = None, +) -> torch.Tensor: + # ensure batch dimension (T,) --> (B, T) + if input_ids.ndim == 1: + input_ids = input_ids.unsqueeze(0) + + rng = None + if seed is not None: + rng = torch.Generator(input_ids.device).manual_seed(seed) + + generated_tokens = input_ids.clone() + + for i in range(max_new_tokens): + next_token = generate_next_token( + model, + x=generated_tokens.to(input_ids.device), + temperature=temperature, + top_k=top_k, + rng=rng, + ) + print(f"generated token {i}: {next_token}") + + generated_tokens = torch.cat([generated_tokens, next_token], dim=1) + + return generated_tokens + +@dataclass +class GenerateConfig: + """Configuration for test generation.""" + hf_model_path: Optional[str] = None + """HuggingFace model path to load (if provided).""" + tt_config: Optional[str] = None + """TOML config file path for TorchTitan 
model.""" + tt_checkpoint_path: Optional[str] = None + """Checkpoint path for the TorchTitan model (if provided).""" + tt_tokenizer_path: Optional[str] = "libs/torchtitan/torchtitan/models/gpt_oss_20b/tokenizer" + """Tokenizer path to load.""" + temperature: float = 1.0 + """Sampling temperature (0 for greedy).""" + max_new_tokens: int = 32 + """Max number of tokens to generate.""" + batch_size: int = 1 + """Batch size for inputs.""" + top_k: Optional[int] = None + """Top-k sampling (optional).""" + seed: Optional[int] = None + """Random seed for reproducibility.""" + deterministic: bool = False + """Use deterministic algorithms.""" + prompt: str = "" + """Input prompt string.""" + out: bool = False + """If true, print JSON report at end.""" + + +class LogitsComparison(NamedTuple): + max_abs_diff: float + mean_abs_diff: float + max_rel_diff: float + mean_rel_diff: float + allclose_results: Sequence[Tuple[float, float, str, bool]] + sample_diffs: Optional[torch.Tensor] + systematic_offset: Optional[Tuple[float, float]] + + +def load_hf_model(path: str, device: torch.device) -> nn.Module: + model = AutoModelForCausalLM.from_pretrained(path).to(device) + model.eval() + return model + +def print_param_dtypes_first_block(model): + """ + Prints the dtype of every parameter in the given model. + For any parameters under a 'layers' module (e.g., layers.), + only prints those from the first block (idx == "0"). + This works for both GptOssForCausalLM (with a .model submodule) + and GptOssModel architectures. + """ + for name, param in model.named_parameters(): + parts = name.split('.') + # If this parameter is under a 'layers' module, check its index + if 'layers' in parts: + idx = parts.index('layers') + 1 + if idx < len(parts) and parts[idx] != '0': + continue + print(f"{name:50s} → {param.dtype}") + +def get_logits(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + out = model(input_ids) + if hasattr(out, "logits"): + return out.logits + else: + return out + + +def compare_logits( + tt_logits: torch.Tensor, + hf_logits: torch.Tensor, + tolerances: Sequence[Tuple[float, float, str]] = ( + (1e-4, 1e-6, "Very Strict"), + (1e-2, 1e-4, "Strict"), + (1e-1, 1e-2, "Moderate"), + ), +) -> LogitsComparison: + # Apply softmax to convert logits to probabilities + hf_logits = torch.nn.functional.softmax(hf_logits.float(), dim=-1) + tt_logits = torch.nn.functional.softmax(tt_logits.float(), dim=-1) + + diff = torch.abs(tt_logits - hf_logits) + max_abs = float(torch.max(diff)) + mean_abs = float(torch.mean(diff)) + rel = diff / (torch.abs(tt_logits) + 1e-8) + max_rel = float(torch.max(rel)) + mean_rel = float(torch.mean(rel)) + + results = [] + any_match = False + for rtol, atol, name in tolerances: + match = torch.allclose(tt_logits, hf_logits, rtol=rtol, atol=atol) + results.append((rtol, atol, name, bool(match))) + if match: + any_match = True + break + + sample_diffs = None + sys_offset = None + if not any_match: + flat = (tt_logits - hf_logits).flatten() + sample_diffs = flat[:25] + sys_offset = (float(torch.mean(flat)), float(torch.std(flat))) + + return LogitsComparison(max_abs, mean_abs, max_rel, mean_rel, results, sample_diffs, sys_offset) + + +def generate_text( + model: nn.Module, + input_ids: torch.Tensor, + max_new_tokens: int, + temperature: float = 0.0, + top_k: Optional[int] = None, +) -> torch.Tensor: + do_sample = temperature > 0 + temp_arg = temperature if do_sample else None + with torch.no_grad(): + return model.generate( + input_ids, + 
max_new_tokens=max_new_tokens, + do_sample=do_sample, + temperature=temp_arg, + top_k=top_k, + ) + + +def print_logits_comparison(comp: LogitsComparison): + print("\n" + "="*70) + print("LOGITS COMPARISON") + print("="*70) + print(f"Max abs diff: {comp.max_abs_diff:.6f}") + print(f"Mean abs diff: {comp.mean_abs_diff:.6f}") + print(f"Max rel diff: {comp.max_rel_diff:.6f}") + print(f"Mean rel diff: {comp.mean_rel_diff:.6f}\n") + print("Tolerance tests:") + for rtol, atol, name, match in comp.allclose_results: + print(f" {'āœ…' if match else 'āŒ'} {name} (rtol={rtol}, atol={atol})") + if comp.sample_diffs is not None: + print("\nšŸ” Sample diffs (first 25):") + for v in comp.sample_diffs.tolist(): + print(f" {v:.6f}") + mean, std = comp.systematic_offset + print(f"\nSystematic offset: mean={mean:.6f}, std={std:.6f}") + + +def print_generation(title: str, outputs: torch.Tensor, tokenizer): + text = tokenizer.decode(outputs[0].tolist()) + print("\n" + "="*60) + print(title) + print("="*60) + print(text) + print("="*60) + + +def print_generation_comparison( + tt_out: torch.Tensor, + hf_out: torch.Tensor, + tokenizer, + prompt_len: int, +): + tt_tokens = tt_out[0][prompt_len:].tolist() + hf_tokens = hf_out[0][prompt_len:].tolist() + n = min(len(tt_tokens), len(hf_tokens)) + matches = sum(1 for i in range(n) if tt_tokens[i] == hf_tokens[i]) + print("\n" + "="*70) + print("GENERATION COMPARISON") + print("="*70) + print(f"Match rate: {matches}/{n} ({matches/n*100:.1f}%)") + if matches != n or len(tt_tokens) != len(hf_tokens): + print("First mismatches:") + for i in range(min(10, n)): + if tt_tokens[i] != hf_tokens[i]: + tt_txt = tokenizer.decode([tt_tokens[i]]) + hf_txt = tokenizer.decode([hf_tokens[i]]) + print(f" Pos {i}: TT='{tt_txt}' vs HF='{hf_txt}'") + + +@record +def test_generate(args: GenerateConfig): + init_logger() + + if not args.hf_model_path and not args.tt_config: + raise ValueError("Either hf_model_path or tt_config must be provided.") + if not args.prompt: + logger.warning("Empty prompt; generating from scratch.") + + # --- Common setup: tokenizer & inputs --- + if args.hf_model_path: + tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) + input_ids = tokenizer.encode(args.prompt, add_special_tokens=False, return_tensors="pt") + print(input_ids) + if args.tt_config: + config_mgr = ConfigManager() + config = config_mgr.parse_args([ + f"--job.config_file={args.tt_config}", + f"--model.tokenizer_path={args.tt_tokenizer_path}", + ]) + train_spec = get_train_spec(config.model.name) + + # --- HuggingFace model (optional) --- + hf_model = None + hf_logits = None + hf_out = None + if args.hf_model_path: # NOTE: comment this block out for rapid tt testing + hf_device = torch.device(f"{device_type}:0") + hf_model = load_hf_model(args.hf_model_path, hf_device) + print("\n" + "="*60) + print("HUGGINGFACE MODEL ARCHITECTURE:") + print(hf_model) + print("="*60) + print_param_dtypes_first_block(hf_model) + print("="*60) + + hf_in = input_ids.to(hf_device) + hf_logits = get_logits(hf_model, hf_in).to(input_ids.device) + print(f"hf_logits: {hf_logits[:, :, 42069:42072]}") + hf_out = generate_text( + hf_model, hf_in, + max_new_tokens=args.max_new_tokens, + temperature=0.0, + top_k=args.top_k, + ).to(input_ids.device) + + # --- TorchTitan model (optional) --- + tt_model = None + tt_logits = None + tt_out = None + if args.tt_config: + # (Original TT setup: distributed, device, checkpoint load, etc.) 
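+        # NOTE: only the single-process path is exercised here; the ``world_size > 1`` branch
+        # below is currently a stub (``pass``) and applies no parallelism to the TT model.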
+ world_size = int(os.environ.get("WORLD_SIZE", 1)) + device = torch.device(f"{device_type}:1") + device_module.set_device(device) + dist_utils.set_determinism(None, device, args.seed, args.deterministic) + + # instantiate & load TT model + model_args = train_spec.config[config.model.flavor] + model_args.update_from_config(config, tokenizer) + init_dev = "meta" if world_size > 1 else device + with torch.device(init_dev): + tt_model = train_spec.cls(model_args) + if world_size > 1: + # parallelize if needed + pass + print("\n" + "="*60) + print("TORCHTITAN MODEL ARCHITECTURE:") + print(tt_model) + print("="*60) + print_param_dtypes_first_block(tt_model) + print("="*60) + + tt_model.eval() + if args.tt_checkpoint_path: # only load checkpoint if provided + tt_state = tt_model.state_dict() + tt_state.pop("freqs_cis", None) + state = {"model": tt_state} + dcp.load(state, checkpoint_id=args.tt_checkpoint_path) + + tt_logits = get_logits(tt_model, input_ids.to(device)).to(hf_logits.device if hf_logits is not None else device) + print(f"āœ… Torchtitan model forward pass succeeded: {tt_logits.shape=}") + print(f"tt_logits: {tt_logits[:, :, 42069:42072]}") + + tt_out = tt_generate_text( + tt_model, input_ids.to(device), + max_new_tokens=args.max_new_tokens, + temperature=args.temperature, + top_k=args.top_k, + seed=args.seed, + ) + + # --- Logits comparison (if both present) --- + if hf_logits is not None and tt_logits is not None: + comp = compare_logits(tt_logits, hf_logits) + print_logits_comparison(comp) + + # --- Print generations --- + if hf_out is not None: + print_generation("HUGGINGFACE MODEL OUTPUT:", hf_out, tokenizer) + if tt_out is not None: + print_generation("TORCHTITAN MODEL OUTPUT:", tt_out, tokenizer) + + # --- Generation comparison --- + if hf_out is not None and tt_out is not None: + prompt_len = input_ids.size(1) + print_generation_comparison(tt_out, hf_out, tokenizer, prompt_len) + + +if __name__ == "__main__": + args = tyro.cli(GenerateConfig) + test_generate(args) diff --git a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py new file mode 100644 index 0000000000..f69d5898d7 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py @@ -0,0 +1,513 @@ +""" +Convert checkpoints between TorchTitan and HuggingFace. + +# Convert HF to TorchTitan DCP +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ + +# Convert TorchTitan DCP to HF +uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py dcp-to-hf --input-path gptoss_dcp/ --output-path gptoss_hf/ +""" + +import tempfile +from pathlib import Path +from typing import Union + +import torch +import torch.distributed.checkpoint as DCP +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaConfig +from tqdm import tqdm +from tyro.extras import SubcommandApp + +from torchtitan.tools.logging import init_logger, logger + +app = SubcommandApp() + + +def validate_config_compatibility(hf_config, torchtitan_config_name, torchtitan_configs): + """Validate that HF config is compatible with TorchTitan config.""" + if torchtitan_config_name not in torchtitan_configs: + available = list(torchtitan_configs.keys()) + raise ValueError(f"TorchTitan config '{torchtitan_config_name}' not found. 
Available: {available}") + + tt_config = torchtitan_configs[torchtitan_config_name] + + # Critical configuration checks with proper field mappings + checks = [ + ("vocab_size", "vocab_size"), + ("hidden_size", "hidden_size"), + ("num_hidden_layers", "num_hidden_layers"), + ("head_dim", "head_dim"), + ("num_attention_heads", "num_attention_heads"), + ("num_key_value_heads", "num_key_value_heads"), + ("sliding_window", "sliding_window"), + ("num_local_experts", "num_local_experts"), + ("num_experts_per_tok", "num_experts_per_tok"), + ("rope_theta", "rope_theta"), + # ("rope_scaling.factor", "rope_factor"), + # ("rope_scaling.beta_fast", "beta_fast"), + # ("rope_scaling.beta_slow", "beta_slow"), + ] + + mismatches = [] + warnings = [] + + for hf_attr, tt_attr in checks: + hf_val = getattr(hf_config, hf_attr, None) + tt_val = getattr(tt_config, tt_attr, None) + + if hf_val != tt_val: + mismatches.append(f"{hf_attr}: HF={hf_val} vs TT.{tt_attr}={tt_val}") + + if mismatches: + raise ValueError(f"Config mismatch for {torchtitan_config_name}:\n" + "\n".join(mismatches)) + + if warnings: + print(f"āš ļø Configuration warnings for {torchtitan_config_name}:") + for warning in warnings: + print(f" {warning}") + print(" These differences might affect model behavior but won't prevent conversion.") + + print(f"āœ“ Configuration validation passed for {torchtitan_config_name}") + return tt_config + +def validate_tt_keys(tt_sd, n_layers, strict=True): + """Ensure the TorchTitan dict looks like gpt-oss as encoded in hf->tt mapping.""" + top_expected = [ + "tok_embeddings.weight", + "output.weight", + "norm.weight", + ] + per_layer_expected = [ + # attention projections + biases + sinks + "attention.wq.weight", "attention.wq.bias", + "attention.wk.weight", "attention.wk.bias", + "attention.wv.weight", "attention.wv.bias", + "attention.wo.weight", "attention.wo.bias", + "attention.sinks", + # MoE experts (mlp1/2) + biases + "moe.experts.mlp1_weight", "moe.experts.mlp1_bias", + "moe.experts.mlp2_weight", "moe.experts.mlp2_bias", + # Router + "moe.router.gate.weight", "moe.router.gate.bias", + # Norms + "attention_norm.weight", "ffn_norm.weight", + ] + + missing = [] + for k in top_expected: + if k not in tt_sd: + missing.append(k) + + for i in range(n_layers): + base = f"layers.{i}." 
+ for suffix in per_layer_expected: + key = base + suffix + if key not in tt_sd: + missing.append(key) + + if missing and strict: + preview = "\n - " + "\n - ".join(missing[:20]) + more = "" if len(missing) <= 20 else f"\n ...and {len(missing)-20} more" + raise KeyError( + "TorchTitan checkpoint is missing keys required for gpt-oss inverse mapping:" + f"{preview}{more}" + ) + return missing # may be useful for logging if strict=False + +def validate_hf_keys(hf_state_dict, model_config, model_name): + """Validate that all expected weight keys exist in the HF state dict.""" + missing_keys = [] + n_layers = model_config.num_hidden_layers + + # Check basic weights + required_keys = [ + "model.embed_tokens.weight", + "lm_head.weight", + "model.norm.weight" + ] + + for key in required_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + # Check layer weights + for layer_idx in range(n_layers): + layer_prefix = f'model.layers.{layer_idx}' + + # Check attention weights + attention_keys = [ + f"{layer_prefix}.self_attn.q_proj.weight", + f"{layer_prefix}.self_attn.k_proj.weight", + f"{layer_prefix}.self_attn.v_proj.weight", + f"{layer_prefix}.self_attn.o_proj.weight", + f"{layer_prefix}.self_attn.q_proj.bias", + f"{layer_prefix}.self_attn.k_proj.bias", + f"{layer_prefix}.self_attn.v_proj.bias", + f"{layer_prefix}.self_attn.o_proj.bias", + f"{layer_prefix}.input_layernorm.weight", + f"{layer_prefix}.post_attention_layernorm.weight", + ] + + for key in attention_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + # Check MoE weights + mlp_keys = [ + f"{layer_prefix}.mlp.router.weight", + f"{layer_prefix}.mlp.router.bias", + f"{layer_prefix}.mlp.experts.gate_up_proj", + f"{layer_prefix}.mlp.experts.gate_up_proj_bias", + f"{layer_prefix}.mlp.experts.down_proj", + f"{layer_prefix}.mlp.experts.down_proj_bias", + ] + + for key in mlp_keys: + if key not in hf_state_dict: + missing_keys.append(key) + + if missing_keys: + logger.error(f"Missing {len(missing_keys)} expected weight keys in HF model:") + for key in missing_keys[:10]: # Show first 10 + logger.error(f" - {key}") + if len(missing_keys) > 10: + logger.error(f" ... and {len(missing_keys) - 10} more") + + # Try to diagnose the issue + logger.info("Available keys in HF model:") + available_keys = list(hf_state_dict.keys()) + for key in available_keys[:20]: # Show first 20 + logger.info(f" - {key}") + if len(available_keys) > 20: + logger.info(f" ... and {len(available_keys) - 20} more") + + raise ValueError(f"HF model '{model_name}' is missing expected weight keys. " + f"This suggests the model architecture doesn't match expectations.") + + logger.info(f"āœ“ Weight key validation passed - found all expected keys") + + +def map_hf_to_torchtitan(hf_state_dict, model_config, max_seq_len=131072, rope_theta=500000.0, model_name="meta-llama/Llama-3.1-8B"): + """Map HuggingFace state dict to TorchTitan format. + + Note: TorchTitan and HuggingFace use different RoPE implementations: + - TorchTitan: Adjacent element pairing with complex arithmetic + - HuggingFace: First/second half pairing with cos/sin arithmetic + + This difference is architectural, not a bug. Converted models will have + slightly different positional encoding but typically minimal impact on performance. 
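+
+    For gpt-oss, the attention q/k/v/o weights and biases are copied verbatim (no head
+    permutation is applied in the layer loop below); the TorchTitan model instead converts
+    between the HF half-split and interleaved RoPE layouts at runtime in ``apply_rotary_emb``.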
+ """ + + # Validate that all expected keys exist + validate_hf_keys(hf_state_dict, model_config, model_name) + + n_layers = model_config.num_hidden_layers + n_heads = model_config.num_attention_heads + dim = model_config.hidden_size + dims_per_head = dim // n_heads + + # Fix: Corrected model family detection logic + if "llama" in model_name.lower(): + model_family = "llama3" + elif "qwen" in model_name.lower(): + model_family = "qwen3" + max_seq_len = model_config.max_position_embeddings + rope_theta = model_config.rope_theta + elif "gpt-oss" in model_name.lower(): + model_family = "gptoss" + max_seq_len = model_config.max_position_embeddings + rope_theta = model_config.rope_theta + else: + raise ValueError(f"Unsupported HuggingFace model for conversion: {model_name}") + + # Determine n_kv_heads for GQA models + n_kv_heads = model_config.num_key_value_heads + head_dim = model_config.head_dim + print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}, max_seq_len={max_seq_len}, rope_theta={rope_theta}") + torchtitan_state_dict = {} + + # Convert embeddings and output + torchtitan_state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"].clone() + torchtitan_state_dict["output.weight"] = hf_state_dict["lm_head.weight"].clone() + torchtitan_state_dict["norm.weight"] = hf_state_dict["model.norm.weight"].clone() + + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Convert layers + for layer_idx in tqdm(range(n_layers), desc="Converting layers"): + hf_layer_prefix = f'model.layers.{layer_idx}' + layer_prefix = f'layers.{layer_idx}' + + wq = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] = wq.clone() + wq_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wq.bias'] = wq_bias.clone() + + wk = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] = wk.clone() + wk_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wk.bias'] = wk_bias.clone() + + wv = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'] = wv.clone() + wv_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wv.bias'] = wv_bias.clone() + + wo = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'] = wo.clone() + wo_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.bias'] + torchtitan_state_dict[f'{layer_prefix}.attention.wo.bias'] = wo_bias.clone() + + sinks = hf_state_dict[f'{hf_layer_prefix}.self_attn.sinks'] + torchtitan_state_dict[f'{layer_prefix}.attention.sinks'] = sinks.clone() + + # MoE weights + mlp1 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_weight'] = mlp1.clone() + + mlp1_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_bias'] = mlp1_bias.clone() + + mlp2 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj'] + 
torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_weight'] = mlp2.clone() + + mlp2_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj_bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_bias'] = mlp2_bias.clone() + + # router + gate = hf_state_dict[f'{hf_layer_prefix}.mlp.router.weight'] + torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.weight'] = gate.clone() + router_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] + torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.bias'] = router_bias.clone() + + # # @vwxyzjn: This is technically not needed, but we added here because we haven't figured out + # # how to tell torchtitan to ignore this parameter. + # tokens_per_expert = torch.zeros_like(expert_bias) + # torchtitan_state_dict[f'{layer_prefix}.moe.tokens_per_expert'] = tokens_per_expert.clone() + + # Layer norms + attention_norm = hf_state_dict[f'{hf_layer_prefix}.input_layernorm.weight'] + torchtitan_state_dict[f'{layer_prefix}.attention_norm.weight'] = attention_norm.clone() + ffn_norm = hf_state_dict[f'{hf_layer_prefix}.post_attention_layernorm.weight'] + torchtitan_state_dict[f'{layer_prefix}.ffn_norm.weight'] = ffn_norm.clone() + + # Precompute RoPE frequencies + # NOTE: we no longer precompute RoPE frequencies in TorchTitan + # this `model_config` is HF but needs to be TT (to include e.g. beta_fast) + # torchtitan_state_dict["freqs_cis"] = precompute_freqs_cis(model_config) + + print(f"Converted {len(torchtitan_state_dict)} parameters from HuggingFace to TorchTitan format") + return torchtitan_state_dict + + +def num_layers_from_keys(state_dict): + layer_idxs = [] + pat = re.compile(r"^layers\.(\d+)\.") + for k in state_dict.keys(): + m = pat.match(k) + if m: + layer_idxs.append(int(m.group(1))) + if not layer_idxs: + raise ValueError("Could not find any 'layers..' keys in the TorchTitan state dict.") + return max(layer_idxs) + 1 + +# TODO: correctness of map_torchtitan_to_hf is not yet tested for GPT-OSS +def map_torchtitan_to_hf(torchtitan_state_dict, *, strict=True): + """ + Map TorchTitan (DCP) state dict -> HuggingFace format for *gpt-oss only*. + + This is the exact inverse of your `map_hf_to_torchtitan`: + - No weight permutations. + - Copies biases for q/k/v/o and MoE projections. + - Preserves `.attention.sinks`. + - MoE and router parameters use the same custom names you used on the HF side + (i.e., HF bias keys are `gate_up_proj_bias` / `down_proj_bias`). + + Parameters + ---------- + torchtitan_state_dict : dict[str, Tensor-like] + TorchTitan checkpoint (flat dict). + strict : bool + If True, error on any missing keys. If False, copy what exists and skip missing. + + Returns + ------- + dict[str, Tensor-like] + HuggingFace-formatted state dict. 
+ """ + tt = torchtitan_state_dict + n_layers = num_layers_from_keys(tt) + validate_tt_keys(tt, n_layers, strict=strict) + + hf = {} + + # Top-level + if "tok_embeddings.weight" in tt: hf["model.embed_tokens.weight"] = tt["tok_embeddings.weight"].clone() + if "output.weight" in tt: hf["lm_head.weight"] = tt["output.weight"].clone() + if "norm.weight" in tt: hf["model.norm.weight"] = tt["norm.weight"].clone() + + # Per-layer mappings (exact inverse of your hf->tt) + for i in range(n_layers): + tt_pref = f"layers.{i}" + hf_pref = f"model.layers.{i}" + + # Attention projections (+biases) + m = { + f"{tt_pref}.attention.wq.weight": (f"{hf_pref}.self_attn.q_proj.weight",), + f"{tt_pref}.attention.wq.bias": (f"{hf_pref}.self_attn.q_proj.bias",), + f"{tt_pref}.attention.wk.weight": (f"{hf_pref}.self_attn.k_proj.weight",), + f"{tt_pref}.attention.wk.bias": (f"{hf_pref}.self_attn.k_proj.bias",), + f"{tt_pref}.attention.wv.weight": (f"{hf_pref}.self_attn.v_proj.weight",), + f"{tt_pref}.attention.wv.bias": (f"{hf_pref}.self_attn.v_proj.bias",), + f"{tt_pref}.attention.wo.weight": (f"{hf_pref}.self_attn.o_proj.weight",), + f"{tt_pref}.attention.wo.bias": (f"{hf_pref}.self_attn.o_proj.bias",), + + # Sinks tensor + f"{tt_pref}.attention.sinks": (f"{hf_pref}.self_attn.sinks",), + + # MoE experts (your custom naming on HF side) + f"{tt_pref}.moe.experts.mlp1_weight": (f"{hf_pref}.mlp.experts.gate_up_proj",), + f"{tt_pref}.moe.experts.mlp1_bias": (f"{hf_pref}.mlp.experts.gate_up_proj_bias",), + f"{tt_pref}.moe.experts.mlp2_weight": (f"{hf_pref}.mlp.experts.down_proj",), + f"{tt_pref}.moe.experts.mlp2_bias": (f"{hf_pref}.mlp.experts.down_proj_bias",), + + # Router + f"{tt_pref}.moe.router.gate.weight": (f"{hf_pref}.mlp.router.weight",), + f"{tt_pref}.moe.router.gate.bias": (f"{hf_pref}.mlp.router.bias",), + + # Norms + f"{tt_pref}.attention_norm.weight": (f"{hf_pref}.input_layernorm.weight",), + f"{tt_pref}.ffn_norm.weight": (f"{hf_pref}.post_attention_layernorm.weight",), + } + + for tt_key, (hf_key,) in m.items(): + if tt_key in tt: + hf[hf_key] = tt[tt_key].clone() + elif strict: + raise KeyError(f"Missing expected key in TorchTitan state dict: '{tt_key}'") + + print(f"Converted {len(hf)} parameters from TorchTitan to HuggingFace format (gpt-oss).") + return hf + + +@app.command(name="hf_to_dcp") +@torch.inference_mode() +def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, dtype: str = "auto", torchtitan_config: str = "20B"): + """Convert HuggingFace model to TorchTitan DCP format. + + Args: + input_path: HuggingFace model name or path + output_path: Output DCP checkpoint path + max_seq_len: Max sequence length for RoPE + rope_theta: RoPE theta parameter + dtype: Data type to use ("auto" to preserve original, or specific dtype like "float32") + torchtitan_config: TorchTitan model config name (e.g., "16B-A3B", "debugmodel") + """ + # Import TorchTitan configs + try: + from torchtitan.models.gpt_oss import gptoss_configs + except ImportError: + raise ImportError("Cannot import TorchTitan GPT-OSS configs. 
Make sure you're in the right environment.") + + logger.info(f"Loading model from {input_path}") + + # Load model with original dtype if "auto", otherwise use specified dtype + hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch.bfloat16) + + # Validate configuration compatibility + logger.info(f"Validating config compatibility with TorchTitan config: {torchtitan_config}") + validate_config_compatibility(hf_model.config, torchtitan_config, gptoss_configs) + + hf_state_dict = hf_model.state_dict() + logger.info(f"Loaded model with dtype: {next(iter(hf_state_dict.values())).dtype}") + + logger.info("Converting weights to TorchTitan format") + torchtitan_state_dict = map_hf_to_torchtitan(hf_state_dict, hf_model.config, max_seq_len, rope_theta, input_path) + + logger.info(f"Writing to DCP at '{output_path}'") + output_path.mkdir(parents=True, exist_ok=True) + storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8) + DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer) + + # Save metadata for reference + metadata = { + "original_hf_model": input_path, + "torchtitan_config": torchtitan_config, + "conversion_time": str(torch.tensor(0).item()), # placeholder + "hf_config": dict(hf_model.config.__dict__), + "torchtitan_config_dict": dict(gptoss_configs[torchtitan_config].__dict__), + } + with open(output_path / "conversion_metadata.json", "w") as f: + import json + json.dump(metadata, f, indent=2, default=str) + + logger.info("Conversion complete!") + logger.info(f"šŸ“‹ Saved conversion metadata to {output_path}/conversion_metadata.json") + logger.info(f"šŸš€ To use in TorchTitan, specify model config: {torchtitan_config}") + + # Final reminder about RoPE differences + if "gpt-oss" in input_path.lower(): + logger.info(f"") + logger.info(f"šŸ”” IMPORTANT: Converted GPT-OSS model uses TorchTitan's RoPE implementation") + logger.info(f" This differs from HuggingFace but is expected behavior") + logger.info(f" See conversion script documentation for details") + + +@app.command(name="dcp_to_hf") +@torch.inference_mode() +def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B"): + """Convert TorchTitan DCP format to HuggingFace model. 
+ + Args: + input_path: Input DCP checkpoint path + output_path: Output HuggingFace model path + max_seq_len: Max sequence length for RoPE + rope_theta: RoPE theta parameter + default_model: Default HuggingFace model for config + """ + from torchtitan.datasets.transformation import get_tokenizer_with_chat_template + from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner + from torch.distributed.checkpoint.state_dict_loader import _load_state_dict + logger.info(f"Loading DCP checkpoint from {input_path}") + + # Load DCP input_path + state_dict = {} + _load_state_dict( + state_dict, + storage_reader=DCP.filesystem.FileSystemReader(input_path), + planner=_EmptyStateDictLoadPlanner(), + no_dist=True, + ) + torchtitan_state_dict = state_dict["model"] + logger.info("Converting weights to HuggingFace format") + hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta) + + # Create HuggingFace config + hf_config = AutoConfig.from_pretrained(default_model) + + # Create and load model + logger.info("Creating HuggingFace model") + # tokenizer = AutoTokenizer.from_pretrained(default_model) + tokenizer = get_tokenizer_with_chat_template(default_model, "tulu", override=True) + hf_model = AutoModelForCausalLM.from_pretrained(default_model) + + # load state dict + logger.info("Loading state dict") + hf_model.load_state_dict(hf_state_dict, strict=True) + + # Save model + logger.info(f"Saving model to {output_path}") + output_path.mkdir(parents=True, exist_ok=True) + hf_model.save_pretrained(output_path) + tokenizer.save_pretrained(output_path) + logger.info("Conversion complete!") + + +if __name__ == "__main__": + init_logger() + app.cli() diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml new file mode 100644 index 0000000000..878e478ff5 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -0,0 +1,73 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS debug training" +print_args = false +use_for_integration_test = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "debugmodel" +# test tokenizer, for debug purpose only +tokenizer_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 1 +compile = false +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 2 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + 
+[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml new file mode 100644 index 0000000000..81908972ad --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml @@ -0,0 +1,70 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS 120B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 10 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "120B" +tokenizer_path = "./assets/tokenizer/GPT-OSS" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 2.2e-5 + +[training] +local_batch_size = 4 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 10_000 +compile = false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 8 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 500 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml new file mode 100644 index 0000000000..88d1c4d27f --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml @@ -0,0 +1,70 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "GPT-OSS 20B model training" +print_args = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 10 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "20B" +tokenizer_path = "./assets/tokenizer/GPT-OSS" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +lr_min = 2.2e-5 + +[training] +local_batch_size = 8 +seq_len = 4096 +max_norm = 1.0 # grad norm clipping +steps = 1000 +compile 
= false +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 + +[checkpoint] +enable_checkpoint = false +folder = "checkpoint" +interval = 10 +last_save_model_weights_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] + +[float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] +moe_fqns = ["experts"] diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f66361a6d2..a273ac563c 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -78,15 +78,80 @@ def __init__( def mask_key(self) -> FLEX_ATTN_MASK_T: return (self.attn_mask_type, self.fixed_block_size) - def forward( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: float | None = None, - ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + def forward(self, q, k, v, sink_weights=None, sliding_window=0, enable_gqa=False): + """ + q : (B, H_q, S_q, D) + k : (B, H_kv, S_kv, D) -- without sink + v : (B, H_kv, S_kv, D) + sink_weights : (H_q,) or (H, M) -- broadcast to all queries + sliding_window : int + enable_gqa : bool + """ + if sink_weights is None: + block_mask = FlexAttention.block_masks[self.mask_key] + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask) + + B, H_q, S_q, D = q.shape + _, H_kv, S_kv, _ = k.shape + sink_idx = S_kv # sink occupies final key slot + + sink_k = k.new_zeros(B, H_kv, 1, D) # this needn't be 0's since it's overwritten + sink_v = v.new_zeros(B, H_kv, 1, D) # 0 value nullifies sink weight in output + + k_ext = torch.cat([k, sink_k], dim=2) + v_ext = torch.cat([v, sink_v], dim=2) + + # masks ensure sinks are included in softmax + if sliding_window is not None and sliding_window > 0: + mask_mod = FlexAttention._get_sliding_window_with_sink_mask_mod(sliding_window, sink_idx) + else: + mask_mod = FlexAttention._get_causal_with_sink_mask_mod(sink_idx) + + block_mask = FlexAttention.compiled_create_block_mask( + mask_mod, B, H_q, S_q, S_kv+1 + ) + + # overwrite the dummy sink scores with actual sink weights + def score_mod(score, b, h_q, q_idx, kv_idx): + return torch.where( + kv_idx == sink_idx, + sink_weights[h_q].to(score.dtype) + 0.0, # cast + keep grad + score + ) + + return FlexAttention.flex_attn( + q, k_ext, v_ext, + block_mask=block_mask, + score_mod=score_mod, + enable_gqa=enable_gqa + ) + + @staticmethod + def _get_causal_with_sink_mask_mod(sink_idx): + """ + Returns a mask_mod function that + - only allows kv_idx ≤ q_idx (causal) + - or if kv_idx == sink_idx (always allow the sink) + """ + orig = FlexAttention._get_causal_mask_mod() + def causal_with_sink(b, h, q_idx, kv_idx): + return orig(b, h, q_idx, kv_idx) | (kv_idx == sink_idx) + return causal_with_sink + + @staticmethod + def _get_sliding_window_with_sink_mask_mod(window: int, sink_idx: int): + """ + Returns a mask_mod function that + - only allows kv_idx ≤ q_idx (causal) + - and only if (q_idx - kv_idx) ≤ window + - or if kv_idx == sink_idx (always allow the sink) + """ + def sliding_mod(b, h, 
q_idx, kv_idx): + # causal within window + keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window) + # always allow the sink slot + return keep | (kv_idx == sink_idx) + return sliding_mod @staticmethod def _get_causal_mask_mod() -> _mask_mod_signature: From 371f20486fcb540cae03f82a05a8c0c02ccd6952 Mon Sep 17 00:00:00 2001 From: Rohan Pandey Date: Tue, 12 Aug 2025 16:49:44 -0700 Subject: [PATCH 02/18] clean up tentative licensing --- torchtitan/experiments/gpt_oss/__init__.py | 5 ----- torchtitan/experiments/gpt_oss/model/args.py | 5 ----- torchtitan/experiments/gpt_oss/model/model.py | 3 --- torchtitan/experiments/gpt_oss/model/moe.py | 3 --- 4 files changed, 16 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index 67a74c124d..14c3600dde 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -1,10 +1,5 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py index 227f24ddc2..e91441092f 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -1,10 +1,5 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# -# Copyright (c) Meta Platforms, Inc. All Rights Reserved. from dataclasses import dataclass diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index 835816c2a0..2fea0bf2c8 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -1,6 +1,3 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py index 667d329b93..1bbd7a838a 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -1,6 +1,3 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
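A note on the attention-sink support added to torchtitan/models/attention.py in the first patch above: the FlexAttention forward appends one virtual key/value slot per head, keeps that slot visible in every mask_mod, overwrites its score with the learned per-head sink logit via score_mod, and gives it a zero value vector so the sink absorbs softmax mass without contributing to the output. The dense-math sketch below illustrates the same computation; it is only a reference under simplifying assumptions (no FlexAttention, no block masks, no GQA), and the helper name sdpa_with_sink is invented for this sketch rather than taken from the code above.

import torch
import torch.nn.functional as F

def sdpa_with_sink(q, k, v, sink_weights, sliding_window=0):
    # q, k, v: (B, H, S, D); sink_weights: (H,). Dense reference only.
    B, H, S, D = q.shape
    scores = (q @ k.transpose(-1, -2)) / D ** 0.5            # (B, H, S, S)
    causal = torch.tril(torch.ones(S, S, dtype=torch.bool, device=q.device))
    if sliding_window > 0:
        idx = torch.arange(S, device=q.device)
        # keep kv_idx <= q_idx and q_idx - kv_idx <= sliding_window
        causal = causal & (idx[None, :] >= idx[:, None] - sliding_window)
    scores = scores.masked_fill(~causal, float("-inf"))
    # one extra key slot per head whose score is the learned sink logit
    sink = sink_weights.view(1, H, 1, 1).expand(B, H, S, 1)
    probs = F.softmax(torch.cat([scores, sink], dim=-1), dim=-1)
    # the sink's value vector is zero, so only the real-key columns hit v
    return probs[..., :-1] @ v

Because the sink's value row is zero, dropping its probability column after the softmax is equivalent to carrying it through the weighted sum, which is what the FlexAttention version does with its extra zero-valued slot.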
From 4957bb0a7695ad712ae674e9023c3c81d0cc8135 Mon Sep 17 00:00:00 2001 From: Rohan Pandey Date: Thu, 4 Sep 2025 19:06:58 -0700 Subject: [PATCH 03/18] training fixes: expert load balancing, TP for sinks + experts, EP works but reduces mfu for 20b --- torchtitan/experiments/gpt_oss/__init__.py | 10 +- .../gpt_oss/infra/expert_parallel.py | 297 ++++++++++++ .../experiments/gpt_oss/infra/optimizer.py | 67 +++ .../experiments/gpt_oss/infra/parallelize.py | 431 ++++++++++++++++++ torchtitan/experiments/gpt_oss/model/args.py | 3 +- torchtitan/experiments/gpt_oss/model/model.py | 15 +- torchtitan/experiments/gpt_oss/model/moe.py | 62 ++- 7 files changed, 867 insertions(+), 18 deletions(-) create mode 100644 torchtitan/experiments/gpt_oss/infra/expert_parallel.py create mode 100644 torchtitan/experiments/gpt_oss/infra/optimizer.py create mode 100644 torchtitan/experiments/gpt_oss/infra/parallelize.py diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index 14c3600dde..715ce943e0 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -5,7 +5,7 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader -from torchtitan.experiments.llama4.optimizer import build_llama4_optimizers +from .infra.optimizer import build_gptoss_optimizers from torchtitan.protocols.train_spec import register_train_spec, TrainSpec @@ -25,12 +25,14 @@ "debugmodel": GptOssModelArgs( hidden_size=256, num_hidden_layers=4, + use_flex_attn=False, + use_grouped_mm=False, ), - "20B": GptOssModelArgs( + "20b": GptOssModelArgs( num_hidden_layers=24, num_local_experts=32, ), - "120B": GptOssModelArgs( + "120b": GptOssModelArgs( num_hidden_layers=36, num_local_experts=128, ), @@ -44,7 +46,7 @@ config=gptoss_configs, parallelize_fn=parallelize_gptoss, pipelining_fn=None, - build_optimizers_fn=build_llama4_optimizers, # use optimizer hooks to update expert weights + build_optimizers_fn=build_gptoss_optimizers, # use optimizer hooks to update expert weights build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py new file mode 100644 index 0000000000..e47bdeec58 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -0,0 +1,297 @@ +from functools import partial +from typing import Callable + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._functional_collectives import all_to_all_single_autograd +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.distributed.tensor.parallel import ParallelStyle +from torch.distributed.tensor.placement_types import Placement + + +# implementation of Tensor Parallel for the GroupedExperts in MoE +class TensorParallel(ParallelStyle): + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "mlp1_weight", nn.Parameter(distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)])) + ) # Column-wise sharding + module.register_parameter( + "mlp1_bias", + nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])), + ) # Column-wise sharding + module.register_parameter( + "mlp2_weight", 
+ nn.Parameter(distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)])), + ) # Row-wise sharding + module.register_parameter( + "mlp2_bias", + nn.Parameter(distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])), + ) # Replicate + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._partition_fn, + ) + + +# NOTE: This is to achieve replicate computation on the gate module in the MoE router. +# It does nothing other than (1) setting the module parameters as DTensors on the given mesh +# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. +# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, +# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. +class NoParallel(ParallelStyle): + def __init__( + self, + *, + input_layout: Placement | None = None, + output_layout: Placement | None = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layout = input_layout or Replicate() + self.output_layout = output_layout or Replicate() + self.desired_input_layout = Replicate() + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, (input_layout,), run_check=False + ) + + if input_layout != desired_input_layout: + input_tensor = input_tensor.redistribute( + placements=(desired_input_layout,), async_op=True + ) + return (input_tensor, *inputs[1:]) + + @staticmethod + def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): + if outputs.placements != (output_layout,): + outputs = outputs.redistribute(placements=(output_layout,), async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + None, + partial( + self._prepare_input_fn, self.input_layout, self.desired_input_layout + ), + partial(self._prepare_output_fn, self.output_layout, self.use_local_output), + ) + + +class ExpertParallel(ParallelStyle): + def __init__(self): + super().__init__() + self.input_splits = None + self.output_splits = None + + # performing all-to-all dispatch on the input + def _token_dispatch(self, mod, inputs, device_mesh): + # annotate module input placements/sharding with input_layouts + routed_input, num_tokens_per_expert = inputs + + # generate the input splits and output splits for all-to-all + with torch.no_grad(): + num_tokens_per_expert_group = num_tokens_per_expert.new_empty( + num_tokens_per_expert.shape[0] + ) + dist.all_to_all_single( + num_tokens_per_expert_group, + num_tokens_per_expert, + group=device_mesh.get_group(), + ) + # NOTE: this would incur a device-to-host sync + self.input_splits = ( + num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() + ) + self.output_splits = ( + num_tokens_per_expert_group.view(device_mesh.shape[0], -1) + .sum(dim=1) + .tolist() + ) + + # perform all-to-all + routed_input = all_to_all_single_autograd( + routed_input, + self.output_splits, + self.input_splits, + device_mesh.get_group(), + ) + + # NOTE: After this all-to-all, the routed input is 
put on proper EP rank. + # However, the num_tokens_per_expert_group is not of the final target format + # [#tokens for local expert 0, #tokens for local expert 1, ...] + # Rather, it is of the format + # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., + # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] + # We need to perform another shuffle to get the correct format -- this is done via the function + # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens + # each expert gets locally is a multiple of ALIGN_SIZE_M. + + return routed_input, num_tokens_per_expert_group + + @staticmethod + def _partition_fn(name, mod, device_mesh): + # shard on the expert dimension + for name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(name, dist_param) + + # performing all-to-all combine on the output + def _token_combine(self, mod, routed_output, device_mesh): + routed_output = all_to_all_single_autograd( + routed_output, + self.input_splits, + self.output_splits, + device_mesh.get_group(), + ) + return routed_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=ExpertParallel._partition_fn, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +# This class is for dp2ep with TP (without TP we can just use ExpertParallel) +class ExpertTensorParallel(ExpertParallel): + def __init__( + self, + tp_mesh: DeviceMesh, + ep_mesh: DeviceMesh, + ): + super().__init__() + # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, + # as DeviceMesh doesn't support slicing from a submesh. + self.tp_mesh = tp_mesh + self.ep_mesh = ep_mesh + + def _token_dispatch(self, mod, inputs, device_mesh): + # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_dispatch(mod, inputs, self.ep_mesh) + + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "mlp1_weight", + nn.Parameter(distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "mlp1_bias", + nn.Parameter(distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + mod.register_parameter( + "mlp2_weight", + nn.Parameter(distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])), + ) # Column-wise sharding + mod.register_parameter( + "mlp2_bias", + nn.Parameter(distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding + + def _token_combine(self, mod, routed_output, device_mesh): + # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh + return super()._token_combine(mod, routed_output, self.ep_mesh) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn_2d, + input_fn=self._token_dispatch, + output_fn=self._token_combine, + ) + + +def expert_parallel(func: Callable) -> Callable: + """ + This is a wrapper applied to the GroupedExperts computation, serving + the following three purposes: + 1. Convert parameters from DTensors to plain Tensors, to work with + dynamic-shape inputs which cannot be easily expressed as DTensors. + 2. 
In Expert Parallel, apply the generate_permute_indices kernel to + permute the inputs to be ordered by local experts (see the _token_dispatch + function in ExpertParallel) and permute the outputs back. + 3. In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices + kernel also helps achieve this via padding, without incurring synchronization + between device and host. Note that this will create side effects when wrapping + the for-loop implementation of GroupedExperts, as it does not need padding. + + Among the above: + 1 and 2 are needed only when expert_parallel_degree > 1. + 3 is needed even for single-device computation. + 2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. + """ + + def wrapper( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor | None = None, + ) -> torch.Tensor: + if isinstance(mlp1_weight, DTensor): + mlp1_weight = mlp1_weight.to_local() + mlp1_bias = mlp1_bias.to_local() + mlp2_weight = mlp2_weight.to_local() + mlp2_bias = mlp2_bias.to_local() + + if num_tokens_per_expert is not None: + from torchtitan.experiments.kernels.moe.indices import ( + generate_permute_indices, + ) + + experts_per_ep_rank = mlp1_weight.shape[0] + num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank + + ALIGN_SIZE_M = 16 + with torch.no_grad(): + ( + permuted_indices, + num_tokens_per_expert, + _, # offsets, + ) = generate_permute_indices( + num_tokens_per_expert, + experts_per_ep_rank, + num_ep_ranks, + x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, + ALIGN_SIZE_M, + ) + + x = torch.vstack((x, x.new_zeros((x.shape[-1])))) + input_shape = x.shape + x = x[permuted_indices, :] + + out = func(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, x, num_tokens_per_expert) + + if num_tokens_per_expert is not None: + out_unpermuted = out.new_empty(input_shape) + out_unpermuted[permuted_indices, :] = out + out = out_unpermuted[:-1] + + return out + + return wrapper diff --git a/torchtitan/experiments/gpt_oss/infra/optimizer.py b/torchtitan/experiments/gpt_oss/infra/optimizer.py new file mode 100644 index 0000000000..de8537032d --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/optimizer.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.components.ft import FTManager +from torchtitan.components.optimizer import build_optimizers, OptimizersContainer +from torchtitan.config_manager import JobConfig +from torchtitan.distributed import ParallelDims + + +# for MoE auxiliary-loss-free load balancing +def _update_expert_bias( + model_parts: list[nn.Module], + world_mesh: dict[str, DeviceMesh], + parallel_dims: ParallelDims, +): + dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None + # TODO: Currently this sync is blocking (thus exposed) and happens on the + # default compute stream. Need to assess if this is OK performance-wise. 
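+    # The loop below implements the auxiliary-loss-free load balancing noted
+    # above: experts that received fewer tokens than average get their routing
+    # bias nudged up by load_balance_coeff and over-used experts get nudged
+    # down (sign of mean - count), the delta is mean-centered so the overall
+    # bias level stays put, and the per-step token counters are then reset.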
+ for model_part in model_parts: + for transformer_block in model_part.layers.values(): + moe = transformer_block.moe + if moe.load_balance_coeff is None: + return + + if dp_cp_mesh is not None: + torch.distributed.all_reduce( + moe.tokens_per_expert, group=dp_cp_mesh.get_group() + ) + + with torch.no_grad(): + expert_bias_delta = moe.load_balance_coeff * torch.sign( + moe.tokens_per_expert.mean() - moe.tokens_per_expert + ) + expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() + moe.expert_bias.add_(expert_bias_delta) + moe.tokens_per_expert.zero_() + + +def build_gptoss_optimizers( + model_parts: list[nn.Module], + job_config: JobConfig, + parallel_dims: ParallelDims, + world_mesh: DeviceMesh, + ft_manager: FTManager, +) -> OptimizersContainer: + optimizers = build_optimizers( + model_parts=model_parts, + job_config=job_config, + parallel_dims=parallel_dims, + world_mesh=world_mesh, + ft_manager=ft_manager, + ) + + optimizers.register_step_pre_hook( + lambda *args, **kwargs: _update_expert_bias( + model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims + ) + ) + + return optimizers diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py new file mode 100644 index 0000000000..47ad01b99e --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -0,0 +1,431 @@ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard, distribute_tensor +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +if torch.__version__ >= "2.9": + from torch.distributed.tensor.parallel import PrepareModuleInputOutput +else: + print(f"Since torch version {torch.__version__} < 2.9, PrepareModuleInputOutput is not available and MoE EP TP will fail.") + +from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.tools.logging import logger + +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Partial, Replicate, Shard + +from .expert_parallel import ( + ExpertParallel, + ExpertTensorParallel, + NoParallel, + TensorParallel, +) + + +# Adapted from llama4/infra/parallelize.py +def parallelize_gptoss( + model: nn.Module, + world_mesh: DeviceMesh, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + if parallel_dims.tp_enabled: + if job_config.parallelism.enable_async_tensor_parallel: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, async TP is not tested for gptoss. \ + torch.compile is not supported yet, which is required for async TP." 
+ ) + + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, float8 tensorwise TP is not tested for gptoss" + ) + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=parallel_dims.loss_parallel_enabled, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled and parallel_dims.ep_enabled + else None + ), + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + if job_config.training.compile: + raise NotImplementedError("torch.compile is not supported yet for gptoss") + + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if dp_mod_ep_mesh_dim_names + else None + ), + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + dp_mesh = world_mesh + apply_ddp( + model, + dp_mesh, + enable_compile=job_config.training.compile, + enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. 
Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Replicate(), Replicate()), + ), + # use_local_output=False make the output to be a DTensor instead of a plain Tensor + "attention.wkv_a": NoParallel(use_local_output=False), + "attention.wkv_b": colwise_parallel(use_local_output=False), + "attention.kv_norm": NoParallel(use_local_output=False), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + if transformer_block.attention.q_lora_rank == 0: + layer_plan.update( + { + "attention.wq": colwise_parallel( + use_local_output=False + ), # This is only used when q_lora_rank==0 + } + ) + else: + layer_plan.update( + { + "attention.wq_a": NoParallel(use_local_output=False), + "attention.wq_b": colwise_parallel(use_local_output=False), + "attention.q_norm": NoParallel(use_local_output=False), + } + ) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. 
+ if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), + ), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + # shard attention.sinks across heads + attn = transformer_block.attention + attn.register_parameter( + "sinks", + nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])), + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + dp_mod_ep_mesh: DeviceMesh | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. 
+ + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + for layer_id, transformer_block in model.layers.items(): + if reshard_after_forward_policy == "always": + reshard_after_forward = True + elif reshard_after_forward_policy == "never": + reshard_after_forward = False + elif reshard_after_forward_policy == "default": + if pp_enabled: + # For PP, do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = False + else: + # As an optimization, do not reshard after forward for the last + # transformer block since FSDP would prefetch it immediately + reshard_after_forward = int(layer_id) < len(model.layers) - 1 + else: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + # NOTE: in an MoE layer, the router and the shared experts + # are sharded together with the TransformerBlock + if dp_mod_ep_mesh: + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fully_shard( + transformer_block.moe.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) + + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, +): + for transformer_block in model.layers.values(): + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + } + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + # if ep_mesh is not None: + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = TensorParallel() + elif tp_mesh is None: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + parallelize_module( + module=transformer_block.moe.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, + ) diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py index e91441092f..63c2b6bb82 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -64,6 +64,7 @@ class GptOssModelArgs(BaseModelArgs): num_local_experts: int = 32 num_experts_per_tok: int = 4 use_grouped_mm: bool = True + load_balance_coeff: float | None = 1e-3 # Multi-Head Latent Attention (MLA) head_dim: int = 64 num_attention_heads: int = 64 @@ -82,7 +83,7 @@ def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> Non """ Update the model_config config from the given job config. """ - # self.vocab_size = tokenizer.vocab_size # TODO: add tiktokenizer support? 
+ # self.vocab_size = tokenizer.vocab_size self.max_seq_len = job_config.training.seq_len def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index 2fea0bf2c8..c16cb53274 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -6,6 +6,7 @@ import torch from torch import nn +from torch.distributed.tensor import DTensor from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol @@ -323,7 +324,12 @@ def forward( v = v.transpose(1, 2).contiguous() if self.use_flex_attn: - output = self.attn(q, k, v, self.sinks, sliding_window=self.sliding_window, enable_gqa=True) + output = self.attn( + q, k, v, + self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks, + sliding_window=self.sliding_window, + enable_gqa=True, + ) else: output = self.attn( q, k, v, self.sinks, @@ -346,8 +352,9 @@ def init_weights(self, init_std: float): self.wv, ] + nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std) for linear in linear_list: - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) # TODO: statically init the mask using train.seq_len @@ -419,7 +426,7 @@ def __init__(self, model_args: GptOssModelArgs): self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.num_hidden_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16) - convert_submodules_to_bf16(self.layers[str(layer_id)]) + # convert_submodules_to_bf16(self.layers[str(layer_id)]) self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) self.output = nn.Linear( @@ -430,7 +437,7 @@ def __init__(self, model_args: GptOssModelArgs): ) self.model_args = model_args self.init_weights() - convert_submodules_to_bf16(self) + # convert_submodules_to_bf16(self) def init_weights(self, buffer_device: torch.device | None = None) -> None: buffer_device = buffer_device or self.freqs_cis.device diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py index 1bbd7a838a..c056819758 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -2,9 +2,10 @@ # LICENSE file in the root directory of this source tree. import torch +from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn -from torchtitan.experiments.llama4.infra.expert_parallel import expert_parallel +from torchtitan.models.gpt_oss.infra.expert_parallel import expert_parallel from .args import GptOssModelArgs @@ -49,7 +50,7 @@ def forward( # TODO: keeping this for-loop implementation for comparison # and readability, may remove later - # @expert_parallel + @expert_parallel @staticmethod def _run_experts_for_loop( mlp1_weight: torch.Tensor, @@ -91,7 +92,7 @@ def _run_experts_for_loop( return out - # @expert_parallel # TODO: e-sharding currently breaks shapes + # @expert_parallel # NOTE: EP currently reduces 20B MFU from 17.8% to 16.5%! 
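+    # NOTE: the grouped-mm path below converts any DTensor expert weights to
+    # local tensors before torch._grouped_mm, and zero-pads the per-token bias
+    # rows when the activation carries extra padded rows beyond offsets[-1]
+    # (e.g. the padding added by generate_permute_indices under EP).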
@staticmethod def _run_experts_grouped_mm( mlp1_weight: torch.Tensor, @@ -105,24 +106,37 @@ def _run_experts_grouped_mm( offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) # grouped mm between a 2D tensor and a 3D tensor assert x.dim() == 2 + num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) else: offsets = None # fall back to regular bmm between 3D tensors assert x.dim() == 3 - num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) + if isinstance(mlp1_weight, DTensor): + mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias = mlp1_weight.to_local(), mlp1_bias.to_local(), mlp2_weight.to_local(), mlp2_bias.to_local() h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) - h += mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + if offsets is not None: + b1 = mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: + b1 = torch.cat([b1, b1.new_zeros((tail_slack, b1.shape[-1]))], dim=0) + h = h + b1.to(h.dtype) + h = swiglu(h) h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) - h += mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + if offsets is not None: + b2 = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: + b2 = torch.cat([b2, b2.new_zeros((tail_slack, b2.shape[-1]))], dim=0) + h = h + b2.to(h.dtype) return h def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=init_std) nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std) nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) @@ -178,6 +192,9 @@ def forward( # scores shape (bs*slen, num_experts) router_logits = self.gate(x) + if expert_bias is not None: + router_logits = router_logits + expert_bias + # top scores shape (bs*slen, top_k) top_scores, selected_experts_indices = torch.topk( router_logits, k=self.top_k, dim=1 @@ -228,6 +245,21 @@ def __init__(self, model_args: GptOssModelArgs): num_experts=num_experts, top_k=top_k, ) + self.load_balance_coeff = model_args.load_balance_coeff + if self.load_balance_coeff is not None: + assert self.load_balance_coeff > 0.0 + self.register_buffer( + "expert_bias", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + self.register_buffer( + "tokens_per_expert", + torch.zeros(num_experts, dtype=torch.float32), + persistent=True, + ) + else: + self.expert_bias = None def forward(self, x: torch.Tensor) -> torch.Tensor: """ @@ -245,7 +277,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: top_scores, token_indices, num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim)) + ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) + + if self.load_balance_coeff is not None and torch.is_grad_enabled(): + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) # shape (bs*slen*top_k, dim) token_indices = token_indices.reshape(-1, 1).expand(-1, dim) @@ -278,3 +314,11 @@ def init_weights( ): self.experts.init_weights(init_std) self.router.init_weights(init_std) + if self.load_balance_coeff is not None: + with torch.device(buffer_device): + self.expert_bias = torch.zeros( + self.experts.num_experts, dtype=torch.float32 + ) + self.tokens_per_expert = torch.zeros( 
+ self.experts.num_experts, dtype=torch.float32 + ) From c3fc9e73dfc93b89bcd89640b20705655bbffed9 Mon Sep 17 00:00:00 2001 From: Rohan Pandey Date: Thu, 4 Sep 2025 19:33:32 -0700 Subject: [PATCH 04/18] only assert sdpa backends if using sdpa; improve conversion script --- torchtitan/distributed/utils.py | 3 +-- torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py | 7 ++++--- torchtitan/train.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 72700fb1ab..75e4fe4ed5 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -186,7 +186,7 @@ def create_context_parallel_ctx( def get_train_context( - enable_loss_parallel: bool, enable_compiled_autograd: bool + enable_loss_parallel: bool, enable_compiled_autograd: bool, use_sdpa: bool = True ) -> Generator[None, None, None]: @contextlib.contextmanager def context(cp_context: Generator[None, None, None] | None = None): @@ -206,7 +206,6 @@ def context(cp_context: Generator[None, None, None] | None = None): if SDPBackend.MATH in ScaledDotProductAttention.backends: ScaledDotProductAttention.backends.remove(SDPBackend.MATH) - stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py index f69d5898d7..81829c69b9 100644 --- a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py +++ b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py @@ -8,6 +8,7 @@ uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py dcp-to-hf --input-path gptoss_dcp/ --output-path gptoss_hf/ """ +import re import tempfile from pathlib import Path from typing import Union @@ -397,7 +398,7 @@ def map_torchtitan_to_hf(torchtitan_state_dict, *, strict=True): @app.command(name="hf_to_dcp") @torch.inference_mode() -def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, dtype: str = "auto", torchtitan_config: str = "20B"): +def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, dtype: str = "auto", torchtitan_config: str = "20b"): """Convert HuggingFace model to TorchTitan DCP format. Args: @@ -460,7 +461,7 @@ def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131 @app.command(name="dcp_to_hf") @torch.inference_mode() -def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B"): +def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, default_model: str = "openai/gpt-oss-20b"): """Convert TorchTitan DCP format to HuggingFace model. 
Args: @@ -485,7 +486,7 @@ def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 13 ) torchtitan_state_dict = state_dict["model"] logger.info("Converting weights to HuggingFace format") - hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta) + hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict) # Create HuggingFace config hf_config = AutoConfig.from_pretrained(default_model) diff --git a/torchtitan/train.py b/torchtitan/train.py index 008a4eebba..29ecb428fa 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -329,6 +329,7 @@ def __init__(self, job_config: JobConfig): self.train_context = dist_utils.get_train_context( loss_parallel_enabled, parallelism_config.enable_compiled_autograd, + use_sdpa=not getattr(model_args, "use_flex_attn", False), ) self.maybe_enable_amp = dist_utils.maybe_enable_amp( parallel_dims, From b696028f1118898489ac929022cb3ec3eb2dd595 Mon Sep 17 00:00:00 2001 From: Rohan Pandey Date: Thu, 4 Sep 2025 19:50:18 -0700 Subject: [PATCH 05/18] fixed conversion script with param by param --- .../gpt_oss/scripts/convert_gptoss.py | 575 +++++++++++------- 1 file changed, 361 insertions(+), 214 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py index 81829c69b9..59c15ab944 100644 --- a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py +++ b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py @@ -11,11 +11,16 @@ import re import tempfile from pathlib import Path -from typing import Union +from typing import Union, Tuple, Optional import torch import torch.distributed.checkpoint as DCP +from torch.distributed.checkpoint.format_utils import dcp_to_torch_save +from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner +from torch.distributed.checkpoint.state_dict_loader import _load_state_dict +from torchtitan.datasets.transformation import get_tokenizer_with_chat_template from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaConfig +from torchtitan.models.llama3.model import precompute_freqs_cis from tqdm import tqdm from tyro.extras import SubcommandApp @@ -24,96 +29,6 @@ app = SubcommandApp() -def validate_config_compatibility(hf_config, torchtitan_config_name, torchtitan_configs): - """Validate that HF config is compatible with TorchTitan config.""" - if torchtitan_config_name not in torchtitan_configs: - available = list(torchtitan_configs.keys()) - raise ValueError(f"TorchTitan config '{torchtitan_config_name}' not found. 
Available: {available}") - - tt_config = torchtitan_configs[torchtitan_config_name] - - # Critical configuration checks with proper field mappings - checks = [ - ("vocab_size", "vocab_size"), - ("hidden_size", "hidden_size"), - ("num_hidden_layers", "num_hidden_layers"), - ("head_dim", "head_dim"), - ("num_attention_heads", "num_attention_heads"), - ("num_key_value_heads", "num_key_value_heads"), - ("sliding_window", "sliding_window"), - ("num_local_experts", "num_local_experts"), - ("num_experts_per_tok", "num_experts_per_tok"), - ("rope_theta", "rope_theta"), - # ("rope_scaling.factor", "rope_factor"), - # ("rope_scaling.beta_fast", "beta_fast"), - # ("rope_scaling.beta_slow", "beta_slow"), - ] - - mismatches = [] - warnings = [] - - for hf_attr, tt_attr in checks: - hf_val = getattr(hf_config, hf_attr, None) - tt_val = getattr(tt_config, tt_attr, None) - - if hf_val != tt_val: - mismatches.append(f"{hf_attr}: HF={hf_val} vs TT.{tt_attr}={tt_val}") - - if mismatches: - raise ValueError(f"Config mismatch for {torchtitan_config_name}:\n" + "\n".join(mismatches)) - - if warnings: - print(f"āš ļø Configuration warnings for {torchtitan_config_name}:") - for warning in warnings: - print(f" {warning}") - print(" These differences might affect model behavior but won't prevent conversion.") - - print(f"āœ“ Configuration validation passed for {torchtitan_config_name}") - return tt_config - -def validate_tt_keys(tt_sd, n_layers, strict=True): - """Ensure the TorchTitan dict looks like gpt-oss as encoded in hf->tt mapping.""" - top_expected = [ - "tok_embeddings.weight", - "output.weight", - "norm.weight", - ] - per_layer_expected = [ - # attention projections + biases + sinks - "attention.wq.weight", "attention.wq.bias", - "attention.wk.weight", "attention.wk.bias", - "attention.wv.weight", "attention.wv.bias", - "attention.wo.weight", "attention.wo.bias", - "attention.sinks", - # MoE experts (mlp1/2) + biases - "moe.experts.mlp1_weight", "moe.experts.mlp1_bias", - "moe.experts.mlp2_weight", "moe.experts.mlp2_bias", - # Router - "moe.router.gate.weight", "moe.router.gate.bias", - # Norms - "attention_norm.weight", "ffn_norm.weight", - ] - - missing = [] - for k in top_expected: - if k not in tt_sd: - missing.append(k) - - for i in range(n_layers): - base = f"layers.{i}." - for suffix in per_layer_expected: - key = base + suffix - if key not in tt_sd: - missing.append(key) - - if missing and strict: - preview = "\n - " + "\n - ".join(missing[:20]) - more = "" if len(missing) <= 20 else f"\n ...and {len(missing)-20} more" - raise KeyError( - "TorchTitan checkpoint is missing keys required for gpt-oss inverse mapping:" - f"{preview}{more}" - ) - return missing # may be useful for logging if strict=False def validate_hf_keys(hf_state_dict, model_config, model_name): """Validate that all expected weight keys exist in the HF state dict.""" @@ -306,99 +221,351 @@ def permute(w, n_heads_arg, dim1=None, dim2=None): return torchtitan_state_dict -def num_layers_from_keys(state_dict): - layer_idxs = [] - pat = re.compile(r"^layers\.(\d+)\.") - for k in state_dict.keys(): - m = pat.match(k) - if m: - layer_idxs.append(int(m.group(1))) - if not layer_idxs: - raise ValueError("Could not find any 'layers..' keys in the TorchTitan state dict.") - return max(layer_idxs) + 1 +def map_torchtitan_to_hf_per_param(name: str, weight: torch.Tensor, model_family: str = "llama3") -> Tuple[Optional[str], Optional[torch.Tensor]]: + """Map a single TorchTitan parameter to HuggingFace format. 
-# TODO: correctness of map_torchtitan_to_hf is not yet tested for GPT-OSS -def map_torchtitan_to_hf(torchtitan_state_dict, *, strict=True): - """ - Map TorchTitan (DCP) state dict -> HuggingFace format for *gpt-oss only*. - - This is the exact inverse of your `map_hf_to_torchtitan`: - - No weight permutations. - - Copies biases for q/k/v/o and MoE projections. - - Preserves `.attention.sinks`. - - MoE and router parameters use the same custom names you used on the HF side - (i.e., HF bias keys are `gate_up_proj_bias` / `down_proj_bias`). - - Parameters - ---------- - torchtitan_state_dict : dict[str, Tensor-like] - TorchTitan checkpoint (flat dict). - strict : bool - If True, error on any missing keys. If False, copy what exists and skip missing. - - Returns - ------- - dict[str, Tensor-like] - HuggingFace-formatted state dict. + Args: + name: Parameter name in TorchTitan format + weight: Parameter tensor + model_family: Model family ("llama3", "qwen3", or "gptoss") + + Returns: + Tuple of (hf_name, hf_weight) or (None, None) if parameter should be skipped """ - tt = torchtitan_state_dict - n_layers = num_layers_from_keys(tt) - validate_tt_keys(tt, n_layers, strict=strict) - - hf = {} - - # Top-level - if "tok_embeddings.weight" in tt: hf["model.embed_tokens.weight"] = tt["tok_embeddings.weight"].clone() - if "output.weight" in tt: hf["lm_head.weight"] = tt["output.weight"].clone() - if "norm.weight" in tt: hf["model.norm.weight"] = tt["norm.weight"].clone() - - # Per-layer mappings (exact inverse of your hf->tt) - for i in range(n_layers): - tt_pref = f"layers.{i}" - hf_pref = f"model.layers.{i}" - - # Attention projections (+biases) - m = { - f"{tt_pref}.attention.wq.weight": (f"{hf_pref}.self_attn.q_proj.weight",), - f"{tt_pref}.attention.wq.bias": (f"{hf_pref}.self_attn.q_proj.bias",), - f"{tt_pref}.attention.wk.weight": (f"{hf_pref}.self_attn.k_proj.weight",), - f"{tt_pref}.attention.wk.bias": (f"{hf_pref}.self_attn.k_proj.bias",), - f"{tt_pref}.attention.wv.weight": (f"{hf_pref}.self_attn.v_proj.weight",), - f"{tt_pref}.attention.wv.bias": (f"{hf_pref}.self_attn.v_proj.bias",), - f"{tt_pref}.attention.wo.weight": (f"{hf_pref}.self_attn.o_proj.weight",), - f"{tt_pref}.attention.wo.bias": (f"{hf_pref}.self_attn.o_proj.bias",), - - # Sinks tensor - f"{tt_pref}.attention.sinks": (f"{hf_pref}.self_attn.sinks",), - - # MoE experts (your custom naming on HF side) - f"{tt_pref}.moe.experts.mlp1_weight": (f"{hf_pref}.mlp.experts.gate_up_proj",), - f"{tt_pref}.moe.experts.mlp1_bias": (f"{hf_pref}.mlp.experts.gate_up_proj_bias",), - f"{tt_pref}.moe.experts.mlp2_weight": (f"{hf_pref}.mlp.experts.down_proj",), - f"{tt_pref}.moe.experts.mlp2_bias": (f"{hf_pref}.mlp.experts.down_proj_bias",), - - # Router - f"{tt_pref}.moe.router.gate.weight": (f"{hf_pref}.mlp.router.weight",), - f"{tt_pref}.moe.router.gate.bias": (f"{hf_pref}.mlp.router.bias",), - - # Norms - f"{tt_pref}.attention_norm.weight": (f"{hf_pref}.input_layernorm.weight",), - f"{tt_pref}.ffn_norm.weight": (f"{hf_pref}.post_attention_layernorm.weight",), + # Skip freqs_cis as it's computed dynamically in HF + if name == "freqs_cis": + return None, None + + assert model_family in ("llama3", "qwen3", "gptoss"), f"Unsupported model family: {model_family}" + + # HuggingFace permutation function (exact copy from their conversion script) + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 
2).reshape(dim1, dim2) + + # Handle embeddings and output weights + if name == "tok_embeddings.weight": + return "model.embed_tokens.weight", weight.clone() + elif name == "output.weight": + return "lm_head.weight", weight.clone() + elif name == "norm.weight": + return "model.norm.weight", weight.clone() + + # Handle layer-specific parameters + layer_match = re.match(r"layers\.(\d+)\.", name) + if not layer_match: + return None, None + + layer_idx = layer_match.group(1) + layer_suffix = name[len(f"layers.{layer_idx}."):] + hf_layer_prefix = f"model.layers.{layer_idx}" + + if model_family == "gptoss": + mapping = { + "attention.wq.weight": "self_attn.q_proj.weight", + "attention.wq.bias": "self_attn.q_proj.bias", + "attention.wk.weight": "self_attn.k_proj.weight", + "attention.wk.bias": "self_attn.k_proj.bias", + "attention.wv.weight": "self_attn.v_proj.weight", + "attention.wv.bias": "self_attn.v_proj.bias", + "attention.wo.weight": "self_attn.o_proj.weight", + "attention.wo.bias": "self_attn.o_proj.bias", + "attention.sinks": "self_attn.sinks", + "moe.experts.mlp1_weight": "mlp.experts.gate_up_proj", + "moe.experts.mlp1_bias": "mlp.experts.gate_up_proj_bias", + "moe.experts.mlp2_weight": "mlp.experts.down_proj", + "moe.experts.mlp2_bias": "mlp.experts.down_proj_bias", + "moe.router.gate.weight": "mlp.router.weight", + "moe.router.gate.bias": "mlp.router.bias", + "moe.expert_bias": "mlp.router.bias", # NOTE: this gets added into router bias + "attention_norm.weight": "input_layernorm.weight", + "ffn_norm.weight": "post_attention_layernorm.weight", + } + hf_suffix = mapping.get(layer_suffix) + if hf_suffix: + return f"{hf_layer_prefix}.{hf_suffix}", weight.clone() + return None, None + + # Handle attention weights + if layer_suffix == "attention.wq.weight": + if model_family == "llama3": + # For query weights, assume standard head_dim=128 + dim = weight.shape[1] + head_dim = 128 + n_heads = dim // head_dim + transformed_weight = permute(weight, n_heads) + else: + transformed_weight = weight + return f"{hf_layer_prefix}.self_attn.q_proj.weight", transformed_weight.clone() + + elif layer_suffix == "attention.wk.weight": + if model_family == "llama3": + # For key weights, infer n_kv_heads from weight shape + dim = weight.shape[1] + head_dim = 128 + n_kv_heads = weight.shape[0] // head_dim + key_value_dim = n_kv_heads * head_dim + transformed_weight = permute(weight, n_kv_heads, key_value_dim, dim) + else: + transformed_weight = weight + return f"{hf_layer_prefix}.self_attn.k_proj.weight", transformed_weight.clone() + + elif layer_suffix == "attention.wv.weight": + return f"{hf_layer_prefix}.self_attn.v_proj.weight", weight.clone() + + elif layer_suffix == "attention.wo.weight": + return f"{hf_layer_prefix}.self_attn.o_proj.weight", weight.clone() + + # Handle qwen3-specific attention norms + elif layer_suffix == "attention.q_norm.weight" and model_family == "qwen3": + return f"{hf_layer_prefix}.self_attn.q_norm.weight", weight.clone() + + elif layer_suffix == "attention.k_norm.weight" and model_family == "qwen3": + return f"{hf_layer_prefix}.self_attn.k_norm.weight", weight.clone() + + # Handle MLP weights + elif layer_suffix == "feed_forward.w1.weight": + return f"{hf_layer_prefix}.mlp.gate_proj.weight", weight.clone() + + elif layer_suffix == "feed_forward.w2.weight": + return f"{hf_layer_prefix}.mlp.down_proj.weight", weight.clone() + + elif layer_suffix == "feed_forward.w3.weight": + return f"{hf_layer_prefix}.mlp.up_proj.weight", weight.clone() + + # Handle layer norms + elif layer_suffix 
== "attention_norm.weight": + return f"{hf_layer_prefix}.input_layernorm.weight", weight.clone() + + elif layer_suffix == "ffn_norm.weight": + return f"{hf_layer_prefix}.post_attention_layernorm.weight", weight.clone() + + # If no mapping found, return None + return None, None + + +def map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0): + """Map TorchTitan state dict to HuggingFace format.""" + if any(k.endswith('.attention.q_norm.weight') for k in torchtitan_state_dict): + model_family = 'qwen3' + elif any(k.endswith('.attention.wq.bias') for k in torchtitan_state_dict): + model_family = 'gptoss' + else: + model_family = 'llama3' + + layer_keys = [k for k in torchtitan_state_dict.keys() if k.startswith("layers.")] + assert len(layer_keys) > 0, "No layers found in state dict" + n_layers = max([int(k.split(".")[1]) for k in layer_keys]) + 1 + hf_state_dict = {} + + # Get model info from sample weight + sample_wq_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wq.weight')) + wq_weight = torchtitan_state_dict[sample_wq_key] + dim = wq_weight.shape[1] # input dimension + + # Check if we have a key weight to determine n_kv_heads + sample_wk_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wk.weight')) + wk_weight = torchtitan_state_dict[sample_wk_key] + + # Standard Llama head dim is 128 for the 3B, 8B, 70B and 405B models + # NOTE: The only exception is the 1B model: https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json#L9 + # But let's ignore that for now + head_dim = 128 + n_heads = dim // head_dim + + # For GQA models, n_kv_heads might be different + n_kv_heads = wk_weight.shape[0] // head_dim + + print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}") + + # HuggingFace permutation function (exact copy from their conversion script) + def permute(w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + # Convert embeddings and output (no permutation needed) + if 'tok_embeddings.weight' in torchtitan_state_dict: + hf_state_dict['model.embed_tokens.weight'] = torchtitan_state_dict['tok_embeddings.weight'].clone() + if 'output.weight' in torchtitan_state_dict: + hf_state_dict['lm_head.weight'] = torchtitan_state_dict['output.weight'].clone() + if 'norm.weight' in torchtitan_state_dict: + hf_state_dict['model.norm.weight'] = torchtitan_state_dict['norm.weight'].clone() + + # Convert layers + for layer_idx in tqdm(range(n_layers), desc="Converting layers"): + layer_prefix = f'layers.{layer_idx}' + hf_layer_prefix = f'model.layers.{layer_idx}' + + if model_family == 'gptoss': + # Attention projections and biases + mappings = { + f'{layer_prefix}.attention.wq.weight': f'{hf_layer_prefix}.self_attn.q_proj.weight', + f'{layer_prefix}.attention.wq.bias': f'{hf_layer_prefix}.self_attn.q_proj.bias', + f'{layer_prefix}.attention.wk.weight': f'{hf_layer_prefix}.self_attn.k_proj.weight', + f'{layer_prefix}.attention.wk.bias': f'{hf_layer_prefix}.self_attn.k_proj.bias', + f'{layer_prefix}.attention.wv.weight': f'{hf_layer_prefix}.self_attn.v_proj.weight', + f'{layer_prefix}.attention.wv.bias': f'{hf_layer_prefix}.self_attn.v_proj.bias', + f'{layer_prefix}.attention.wo.weight': f'{hf_layer_prefix}.self_attn.o_proj.weight', + f'{layer_prefix}.attention.wo.bias': 
f'{hf_layer_prefix}.self_attn.o_proj.bias', + f'{layer_prefix}.attention.sinks': f'{hf_layer_prefix}.self_attn.sinks', + f'{layer_prefix}.moe.experts.mlp1_weight': f'{hf_layer_prefix}.mlp.experts.gate_up_proj', + f'{layer_prefix}.moe.experts.mlp1_bias': f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias', + f'{layer_prefix}.moe.experts.mlp2_weight': f'{hf_layer_prefix}.mlp.experts.down_proj', + f'{layer_prefix}.moe.experts.mlp2_bias': f'{hf_layer_prefix}.mlp.experts.down_proj_bias', + f'{layer_prefix}.moe.router.gate.weight': f'{hf_layer_prefix}.mlp.router.weight', + f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight', + f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight', + } + for tt_key, hf_key in mappings.items(): + if tt_key in torchtitan_state_dict: + hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() + # Combine router gate bias with expert bias (if present) + router_bias_key = f'{layer_prefix}.moe.router.gate.bias' + expert_bias_key = f'{layer_prefix}.moe.expert_bias' + if ( + router_bias_key in torchtitan_state_dict + or expert_bias_key in torchtitan_state_dict + ): + if router_bias_key in torchtitan_state_dict: + bias = torchtitan_state_dict[router_bias_key].clone() + else: + bias = torch.zeros_like(torchtitan_state_dict[expert_bias_key]) + if expert_bias_key in torchtitan_state_dict: + bias = bias + torchtitan_state_dict[expert_bias_key] + hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] = bias + continue + + # Attention weights with proper permutation + if f'{layer_prefix}.attention.wq.weight' in torchtitan_state_dict: + wq = torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] + if model_family == "llama3": + wq = permute(wq, n_heads) + hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] = wq.clone() + + if f'{layer_prefix}.attention.wk.weight' in torchtitan_state_dict: + wk = torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] + key_value_dim = n_kv_heads * head_dim + if model_family == "llama3": + wk = permute(wk, n_kv_heads, key_value_dim, dim) + hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] = wk.clone() + + if f'{layer_prefix}.attention.wv.weight' in torchtitan_state_dict: + # Value weights don't get permuted + hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'].clone() + + if model_family == "qwen3": + if f'{layer_prefix}.attention.q_norm.weight' in torchtitan_state_dict: + hf_state_dict[f'{hf_layer_prefix}.self_attn.q_norm.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.q_norm.weight'].clone() + if f'{layer_prefix}.attention.k_norm.weight' in torchtitan_state_dict: + hf_state_dict[f'{hf_layer_prefix}.self_attn.k_norm.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.k_norm.weight'].clone() + + if f'{layer_prefix}.attention.wo.weight' in torchtitan_state_dict: + # Output projection doesn't get permuted + hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'].clone() + + # MLP weights (no permutation) + mlp_mappings = { + f'{layer_prefix}.feed_forward.w1.weight': f'{hf_layer_prefix}.mlp.gate_proj.weight', + f'{layer_prefix}.feed_forward.w2.weight': f'{hf_layer_prefix}.mlp.down_proj.weight', + f'{layer_prefix}.feed_forward.w3.weight': f'{hf_layer_prefix}.mlp.up_proj.weight', + } + + for tt_key, hf_key in mlp_mappings.items(): + if tt_key in torchtitan_state_dict: + hf_state_dict[hf_key] = 
torchtitan_state_dict[tt_key].clone() + + # Layer norms (no permutation) + norm_mappings = { + f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight', + f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight', } - for tt_key, (hf_key,) in m.items(): - if tt_key in tt: - hf[hf_key] = tt[tt_key].clone() - elif strict: - raise KeyError(f"Missing expected key in TorchTitan state dict: '{tt_key}'") + for tt_key, hf_key in norm_mappings.items(): + if tt_key in torchtitan_state_dict: + hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() + + print(f"Converted {len(hf_state_dict)} parameters from TorchTitan to HuggingFace format") + return hf_state_dict + + +def map_torchtitan_to_hf2(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0, validate_against_original=True): + """Map TorchTitan state dict to HuggingFace format using per-parameter function.""" + + # Auto-detect model family + if any(k.endswith('.attention.q_norm.weight') for k in torchtitan_state_dict): + model_family = "qwen3" + elif any(k.endswith('.attention.wq.bias') for k in torchtitan_state_dict): + model_family = "gptoss" + else: + model_family = "llama3" + + logger.info(f"Converting using per-parameter function with model_family={model_family}") + + hf_state_dict = {} + skipped_params = [] + + # Convert each parameter individually + for name, weight in tqdm(torchtitan_state_dict.items(), desc="Converting parameters"): + hf_name, hf_weight = map_torchtitan_to_hf_per_param(name, weight, model_family) + if hf_name is not None: + if hf_name in hf_state_dict: + hf_state_dict[hf_name] = hf_state_dict[hf_name] + hf_weight # NOTE: adds expert_bias into router bias + else: + hf_state_dict[hf_name] = hf_weight + else: + skipped_params.append(name) + + logger.info(f"Converted {len(hf_state_dict)} parameters, skipped {len(skipped_params)} parameters") + if skipped_params: + logger.info(f"Skipped parameters: {skipped_params}") - print(f"Converted {len(hf)} parameters from TorchTitan to HuggingFace format (gpt-oss).") - return hf + # Validation against original function + if validate_against_original: + logger.info("Validating against original conversion function...") + + # Get original result + original_hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta) + + # Compare keys + new_keys = set(hf_state_dict.keys()) + original_keys = set(original_hf_state_dict.keys()) + + if new_keys != original_keys: + missing_in_new = original_keys - new_keys + extra_in_new = new_keys - original_keys + logger.error(f"Key mismatch! Missing in new: {missing_in_new}, Extra in new: {extra_in_new}") + raise ValueError("Key sets don't match between implementations") + + # Compare tensor values + mismatched_tensors = [] + for key in original_keys: + if not torch.allclose(hf_state_dict[key], original_hf_state_dict[key], rtol=1e-5, atol=1e-8): + mismatched_tensors.append(key) + + if mismatched_tensors: + logger.error(f"Tensor value mismatches in: {mismatched_tensors}") + # Show details for first mismatch + key = mismatched_tensors[0] + logger.error(f"First mismatch in {key}:") + logger.error(f" Max abs diff: {torch.max(torch.abs(hf_state_dict[key] - original_hf_state_dict[key]))}") + logger.error(f" Original shape: {original_hf_state_dict[key].shape}") + logger.error(f" New shape: {hf_state_dict[key].shape}") + raise ValueError("Tensor values don't match between implementations") + + logger.info("āœ“ Validation passed! 
New implementation matches original.") + + return hf_state_dict @app.command(name="hf_to_dcp") @torch.inference_mode() -def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, dtype: str = "auto", torchtitan_config: str = "20b"): +def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, dtype: str = "auto"): """Convert HuggingFace model to TorchTitan DCP format. Args: @@ -407,23 +574,12 @@ def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131 max_seq_len: Max sequence length for RoPE rope_theta: RoPE theta parameter dtype: Data type to use ("auto" to preserve original, or specific dtype like "float32") - torchtitan_config: TorchTitan model config name (e.g., "16B-A3B", "debugmodel") """ - # Import TorchTitan configs - try: - from torchtitan.models.gpt_oss import gptoss_configs - except ImportError: - raise ImportError("Cannot import TorchTitan GPT-OSS configs. Make sure you're in the right environment.") - logger.info(f"Loading model from {input_path}") # Load model with original dtype if "auto", otherwise use specified dtype hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch.bfloat16) - # Validate configuration compatibility - logger.info(f"Validating config compatibility with TorchTitan config: {torchtitan_config}") - validate_config_compatibility(hf_model.config, torchtitan_config, gptoss_configs) - hf_state_dict = hf_model.state_dict() logger.info(f"Loaded model with dtype: {next(iter(hf_state_dict.values())).dtype}") @@ -434,34 +590,12 @@ def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131 output_path.mkdir(parents=True, exist_ok=True) storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8) DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer) - - # Save metadata for reference - metadata = { - "original_hf_model": input_path, - "torchtitan_config": torchtitan_config, - "conversion_time": str(torch.tensor(0).item()), # placeholder - "hf_config": dict(hf_model.config.__dict__), - "torchtitan_config_dict": dict(gptoss_configs[torchtitan_config].__dict__), - } - with open(output_path / "conversion_metadata.json", "w") as f: - import json - json.dump(metadata, f, indent=2, default=str) - logger.info("Conversion complete!") - logger.info(f"šŸ“‹ Saved conversion metadata to {output_path}/conversion_metadata.json") - logger.info(f"šŸš€ To use in TorchTitan, specify model config: {torchtitan_config}") - - # Final reminder about RoPE differences - if "gpt-oss" in input_path.lower(): - logger.info(f"") - logger.info(f"šŸ”” IMPORTANT: Converted GPT-OSS model uses TorchTitan's RoPE implementation") - logger.info(f" This differs from HuggingFace but is expected behavior") - logger.info(f" See conversion script documentation for details") @app.command(name="dcp_to_hf") @torch.inference_mode() -def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 150000.0, default_model: str = "openai/gpt-oss-20b"): +def convert_dcp_to_hf(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B", validate_against_original: bool = False): """Convert TorchTitan DCP format to HuggingFace model. 
Args: @@ -471,11 +605,16 @@ def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 13 rope_theta: RoPE theta parameter default_model: Default HuggingFace model for config """ - from torchtitan.datasets.transformation import get_tokenizer_with_chat_template - from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner - from torch.distributed.checkpoint.state_dict_loader import _load_state_dict + + if str(input_path).startswith("s3://"): + import s3_utils + local_path = s3_utils.sync_to_nvme(str(input_path)) + input_path = Path(local_path) + logger.info(f"Loading DCP checkpoint from {input_path}") + from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner + from torch.distributed.checkpoint.state_dict_loader import _load_state_dict # Load DCP input_path state_dict = {} _load_state_dict( @@ -486,7 +625,15 @@ def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 13 ) torchtitan_state_dict = state_dict["model"] logger.info("Converting weights to HuggingFace format") - hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict) + hf_state_dict = map_torchtitan_to_hf2(torchtitan_state_dict, max_seq_len, rope_theta, validate_against_original=validate_against_original) + + if '/' not in default_model: + if 'qwen' in default_model.lower(): + default_model = f'Qwen/{default_model}' + elif 'llama' in default_model.lower(): + default_model = f'meta-llama/{default_model}' + else: + raise ValueError(f"Unsupported model: {default_model}") # Create HuggingFace config hf_config = AutoConfig.from_pretrained(default_model) @@ -495,7 +642,7 @@ def convert_dcp_to_hf(input_path: Path, output_path: Path, max_seq_len: int = 13 logger.info("Creating HuggingFace model") # tokenizer = AutoTokenizer.from_pretrained(default_model) tokenizer = get_tokenizer_with_chat_template(default_model, "tulu", override=True) - hf_model = AutoModelForCausalLM.from_pretrained(default_model) + hf_model = AutoModelForCausalLM.from_pretrained(default_model, device_map="auto") # NOTE: need device_map="auto" to avoid CPU OOM # load state dict logger.info("Loading state dict") From 4010fa2cc5d0647f433dfc7e0048011856f5c8b3 Mon Sep 17 00:00:00 2001 From: Rohan Pandey Date: Fri, 5 Sep 2025 12:19:10 -0700 Subject: [PATCH 06/18] new lse-based flexattn implementation for sinks --- torchtitan/models/attention.py | 92 +++++++++++++++------------------- 1 file changed, 40 insertions(+), 52 deletions(-) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index a273ac563c..3c3b607571 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -72,85 +72,73 @@ def __init__( self.attn_mask_type = attn_mask_type self.fixed_block_size = fixed_block_size + self.mask_cache = {} FlexAttention.used_attn_mask_types.add(self.mask_key) @property def mask_key(self) -> FLEX_ATTN_MASK_T: return (self.attn_mask_type, self.fixed_block_size) - def forward(self, q, k, v, sink_weights=None, sliding_window=0, enable_gqa=False): - """ - q : (B, H_q, S_q, D) - k : (B, H_kv, S_kv, D) -- without sink - v : (B, H_kv, S_kv, D) - sink_weights : (H_q,) or (H, M) -- broadcast to all queries - sliding_window : int - enable_gqa : bool - """ + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, + sink_weights: torch.Tensor | None = None, + sliding_window: int = 0, + enable_gqa: bool = False, + ) -> torch.Tensor: if sink_weights is None: block_mask = 
FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask) + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) B, H_q, S_q, D = q.shape _, H_kv, S_kv, _ = k.shape - sink_idx = S_kv # sink occupies final key slot - - sink_k = k.new_zeros(B, H_kv, 1, D) # this needn't be 0's since it's overwritten - sink_v = v.new_zeros(B, H_kv, 1, D) # 0 value nullifies sink weight in output - - k_ext = torch.cat([k, sink_k], dim=2) - v_ext = torch.cat([v, sink_v], dim=2) - # masks ensure sinks are included in softmax - if sliding_window is not None and sliding_window > 0: - mask_mod = FlexAttention._get_sliding_window_with_sink_mask_mod(sliding_window, sink_idx) - else: - mask_mod = FlexAttention._get_causal_with_sink_mask_mod(sink_idx) - - block_mask = FlexAttention.compiled_create_block_mask( - mask_mod, B, H_q, S_q, S_kv+1 - ) - - # overwrite the dummy sink scores with actual sink weights - def score_mod(score, b, h_q, q_idx, kv_idx): - return torch.where( - kv_idx == sink_idx, - sink_weights[h_q].to(score.dtype) + 0.0, # cast + keep grad - score + # regular (no-sink) mask + no extra KV col + mask_key = (sliding_window, S_q, S_kv) + if mask_key not in self.mask_cache: + if sliding_window is not None and sliding_window > 0: + mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window) + else: + mask_mod = FlexAttention._get_causal_mask_mod() + block_mask = create_block_mask( + mask_mod, B, H_q, S_q, S_kv, + _compile=True, device=q.device # NOTE: set _compile=False if sampling for debugging ) + self.mask_cache[mask_key] = block_mask - return FlexAttention.flex_attn( - q, k_ext, v_ext, + block_mask = self.mask_cache[mask_key] + + # run fast flex_attn and return LSE + out, lse = FlexAttention.flex_attn( + q, k, v, block_mask=block_mask, - score_mod=score_mod, - enable_gqa=enable_gqa + enable_gqa=enable_gqa, + return_lse=True ) - @staticmethod - def _get_causal_with_sink_mask_mod(sink_idx): - """ - Returns a mask_mod function that - - only allows kv_idx ≤ q_idx (causal) - - or if kv_idx == sink_idx (always allow the sink) - """ - orig = FlexAttention._get_causal_mask_mod() - def causal_with_sink(b, h, q_idx, kv_idx): - return orig(b, h, q_idx, kv_idx) | (kv_idx == sink_idx) - return causal_with_sink + # rescale by sigma(lse - w[h]) and broadcast over D + if sink_weights is not None: + w = sink_weights # [H] + scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1] + out = out * scale + + out = out.to(q.dtype) + return out @staticmethod - def _get_sliding_window_with_sink_mask_mod(window: int, sink_idx: int): + def _get_sliding_window_mask_mod(window: int): """ Returns a mask_mod function that - only allows kv_idx ≤ q_idx (causal) - and only if (q_idx - kv_idx) ≤ window - - or if kv_idx == sink_idx (always allow the sink) """ def sliding_mod(b, h, q_idx, kv_idx): # causal within window keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window) - # always allow the sink slot - return keep | (kv_idx == sink_idx) + return keep return sliding_mod @staticmethod From 2e71aaf01cbc01c47577a6c043155d1eb02a37f8 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 22 Sep 2025 22:49:17 -0700 Subject: [PATCH 07/18] test --- torchtitan/experiments/gpt_oss/README.py | 0 torchtitan/experiments/gpt_oss/__init__.py | 44 +- .../gpt_oss/infra/expert_parallel.py | 137 +--- .../experiments/gpt_oss/infra/optimizer.py | 67 -- .../experiments/gpt_oss/infra/parallelize.py | 1 + torchtitan/experiments/gpt_oss/model/args.py | 38 +- 
torchtitan/experiments/gpt_oss/model/model.py | 49 +- torchtitan/experiments/gpt_oss/model/moe.py | 212 +----- .../gpt_oss/scripts/compare_hf_to_tt.py | 405 ----------- .../gpt_oss/scripts/convert_gptoss.py | 661 ------------------ .../gpt_oss/train_configs/debug_model.toml | 2 +- .../gpt_oss/train_configs/gpt_oss_120b.toml | 70 -- .../gpt_oss/train_configs/gpt_oss_20b.toml | 70 -- 13 files changed, 108 insertions(+), 1648 deletions(-) create mode 100644 torchtitan/experiments/gpt_oss/README.py delete mode 100644 torchtitan/experiments/gpt_oss/infra/optimizer.py delete mode 100644 torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py delete mode 100644 torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py delete mode 100644 torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml delete mode 100644 torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml diff --git a/torchtitan/experiments/gpt_oss/README.py b/torchtitan/experiments/gpt_oss/README.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index 715ce943e0..8603588f55 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -5,9 +5,11 @@ from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.moe import MoEArgs from .infra.optimizer import build_gptoss_optimizers from torchtitan.protocols.train_spec import register_train_spec, TrainSpec +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from .infra.parallelize import parallelize_gptoss from .model.args import GptOssModelArgs @@ -26,15 +28,45 @@ hidden_size=256, num_hidden_layers=4, use_flex_attn=False, - use_grouped_mm=False, + moe_args = MoEArgs( + num_experts=8, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=False, + load_balance_coeff=1e-3, + ) ), "20b": GptOssModelArgs( num_hidden_layers=24, - num_local_experts=32, + moe_args = MoEArgs( + num_experts=32, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ) ), "120b": GptOssModelArgs( num_hidden_layers=36, - num_local_experts=128, + moe_args = MoEArgs( + num_experts=128, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ) ), } @@ -42,11 +74,11 @@ register_train_spec( TrainSpec( name="gpt_oss", - cls=GptOssModel, - config=gptoss_configs, + model_cls=GptOssModel, + model_args=gptoss_configs, parallelize_fn=parallelize_gptoss, pipelining_fn=None, - build_optimizers_fn=build_gptoss_optimizers, # use optimizer hooks to update expert weights + build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py index e47bdeec58..62775b8b67 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -15,6 +15,7 
@@ ) from torch.distributed.tensor.parallel import ParallelStyle from torch.distributed.tensor.placement_types import Placement +from torchtitan.distributed.expert_parallel import ExpertParallel # implementation of Tensor Parallel for the GroupedExperts in MoE @@ -43,138 +44,6 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: self._partition_fn, ) - -# NOTE: This is to achieve replicate computation on the gate module in the MoE router. -# It does nothing other than (1) setting the module parameters as DTensors on the given mesh -# and (2) inserting hooks to module boundary to change torch.Tensor to DTensor and back. -# The reason we need this wrapping is to ensure all parameters are on the same 1D/2D mesh, -# which is assumed by (1) gradient norm clipping, and (2) optimizer fused implementation. -class NoParallel(ParallelStyle): - def __init__( - self, - *, - input_layout: Placement | None = None, - output_layout: Placement | None = None, - use_local_output: bool = True, - ): - super().__init__() - self.input_layout = input_layout or Replicate() - self.output_layout = output_layout or Replicate() - self.desired_input_layout = Replicate() - self.use_local_output = use_local_output - - @staticmethod - def _prepare_input_fn(input_layout, desired_input_layout, mod, inputs, device_mesh): - # annotate module input placements/sharding with input_layouts - input_tensor = inputs[0] - if not isinstance(input_tensor, DTensor): - input_tensor = DTensor.from_local( - input_tensor, device_mesh, (input_layout,), run_check=False - ) - - if input_layout != desired_input_layout: - input_tensor = input_tensor.redistribute( - placements=(desired_input_layout,), async_op=True - ) - return (input_tensor, *inputs[1:]) - - @staticmethod - def _prepare_output_fn(output_layout, use_local_output, mod, outputs, device_mesh): - if outputs.placements != (output_layout,): - outputs = outputs.redistribute(placements=(output_layout,), async_op=True) - # back to local tensor - return outputs.to_local() if use_local_output else outputs - - def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - return distribute_module( - module, - device_mesh, - None, - partial( - self._prepare_input_fn, self.input_layout, self.desired_input_layout - ), - partial(self._prepare_output_fn, self.output_layout, self.use_local_output), - ) - - -class ExpertParallel(ParallelStyle): - def __init__(self): - super().__init__() - self.input_splits = None - self.output_splits = None - - # performing all-to-all dispatch on the input - def _token_dispatch(self, mod, inputs, device_mesh): - # annotate module input placements/sharding with input_layouts - routed_input, num_tokens_per_expert = inputs - - # generate the input splits and output splits for all-to-all - with torch.no_grad(): - num_tokens_per_expert_group = num_tokens_per_expert.new_empty( - num_tokens_per_expert.shape[0] - ) - dist.all_to_all_single( - num_tokens_per_expert_group, - num_tokens_per_expert, - group=device_mesh.get_group(), - ) - # NOTE: this would incur a device-to-host sync - self.input_splits = ( - num_tokens_per_expert.view(device_mesh.shape[0], -1).sum(dim=1).tolist() - ) - self.output_splits = ( - num_tokens_per_expert_group.view(device_mesh.shape[0], -1) - .sum(dim=1) - .tolist() - ) - - # perform all-to-all - routed_input = all_to_all_single_autograd( - routed_input, - self.output_splits, - self.input_splits, - device_mesh.get_group(), - ) - - # NOTE: After this all-to-all, the routed input is put on proper EP rank. 
- # However, the num_tokens_per_expert_group is not of the final target format - # [#tokens for local expert 0, #tokens for local expert 1, ...] - # Rather, it is of the format - # [#tokens for local expert 0 from EP rank 0, #tokens for local expert 1 from EP rank 0, ..., - # #tokens for local expert 0 from EP rank 1, #tokens for local expert 1 from EP rank 1, ...] - # We need to perform another shuffle to get the correct format -- this is done via the function - # generate_permute_indices in moe.py, which also does padding to make sure the number of tokens - # each expert gets locally is a multiple of ALIGN_SIZE_M. - - return routed_input, num_tokens_per_expert_group - - @staticmethod - def _partition_fn(name, mod, device_mesh): - # shard on the expert dimension - for name, param in mod.named_parameters(recurse=False): - dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) - mod.register_parameter(name, dist_param) - - # performing all-to-all combine on the output - def _token_combine(self, mod, routed_output, device_mesh): - routed_output = all_to_all_single_autograd( - routed_output, - self.input_splits, - self.output_splits, - device_mesh.get_group(), - ) - return routed_output - - def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: - return distribute_module( - module, - device_mesh, - partition_fn=ExpertParallel._partition_fn, - input_fn=self._token_dispatch, - output_fn=self._token_combine, - ) - - # This class is for dp2ep with TP (without TP we can just use ExpertParallel) class ExpertTensorParallel(ExpertParallel): def __init__( @@ -224,6 +93,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ) +# TODO(jianiw): This need to be merged with def expert_parallel(func: Callable) -> Callable: """ This is a wrapper applied to the GroupedExperts computation, serving @@ -250,6 +120,7 @@ def wrapper( mlp1_bias: torch.Tensor, mlp2_weight: torch.Tensor, mlp2_bias: torch.Tensor, + swiglu_limit: float, x: torch.Tensor, num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: @@ -285,7 +156,7 @@ def wrapper( input_shape = x.shape x = x[permuted_indices, :] - out = func(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, x, num_tokens_per_expert) + out = func(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, swiglu_limit, x, num_tokens_per_expert) if num_tokens_per_expert is not None: out_unpermuted = out.new_empty(input_shape) diff --git a/torchtitan/experiments/gpt_oss/infra/optimizer.py b/torchtitan/experiments/gpt_oss/infra/optimizer.py deleted file mode 100644 index de8537032d..0000000000 --- a/torchtitan/experiments/gpt_oss/infra/optimizer.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
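For reference, the all-to-all dispatch removed here (now imported from torchtitan.distributed.expert_parallel) derives its send/receive splits from the per-expert token counts. A minimal sketch of that derivation, with made-up counts, no process group, and assuming experts are laid out contiguously per EP rank as in the deleted _token_dispatch:

import torch

ep_degree = 2                      # hypothetical: 2 EP ranks
num_local_experts = 4              # 8 global experts in total
# tokens this rank wants to send to each global expert
num_tokens_per_expert = torch.tensor([3, 1, 0, 2, 5, 4, 0, 1])

# input_splits[r] = total tokens this rank sends to EP rank r
# (experts are contiguous per rank, hence the view + sum)
input_splits = num_tokens_per_expert.view(ep_degree, -1).sum(dim=1).tolist()
print(input_splits)  # [6, 10]

# In the real _token_dispatch, output_splits is obtained the same way from an
# all_to_all of the counts, and the received tokens are later re-shuffled from
# (source rank, local expert) order into per-expert order via generate_permute_indices.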
- -import torch -import torch.nn as nn -from torch.distributed.device_mesh import DeviceMesh - -from torchtitan.components.ft import FTManager -from torchtitan.components.optimizer import build_optimizers, OptimizersContainer -from torchtitan.config_manager import JobConfig -from torchtitan.distributed import ParallelDims - - -# for MoE auxiliary-loss-free load balancing -def _update_expert_bias( - model_parts: list[nn.Module], - world_mesh: dict[str, DeviceMesh], - parallel_dims: ParallelDims, -): - dp_cp_mesh = world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None - # TODO: Currently this sync is blocking (thus exposed) and happens on the - # default compute stream. Need to assess if this is OK performance-wise. - for model_part in model_parts: - for transformer_block in model_part.layers.values(): - moe = transformer_block.moe - if moe.load_balance_coeff is None: - return - - if dp_cp_mesh is not None: - torch.distributed.all_reduce( - moe.tokens_per_expert, group=dp_cp_mesh.get_group() - ) - - with torch.no_grad(): - expert_bias_delta = moe.load_balance_coeff * torch.sign( - moe.tokens_per_expert.mean() - moe.tokens_per_expert - ) - expert_bias_delta = expert_bias_delta - expert_bias_delta.mean() - moe.expert_bias.add_(expert_bias_delta) - moe.tokens_per_expert.zero_() - - -def build_gptoss_optimizers( - model_parts: list[nn.Module], - job_config: JobConfig, - parallel_dims: ParallelDims, - world_mesh: DeviceMesh, - ft_manager: FTManager, -) -> OptimizersContainer: - optimizers = build_optimizers( - model_parts=model_parts, - job_config=job_config, - parallel_dims=parallel_dims, - world_mesh=world_mesh, - ft_manager=ft_manager, - ) - - optimizers.register_step_pre_hook( - lambda *args, **kwargs: _update_expert_bias( - model_parts, world_mesh=world_mesh, parallel_dims=parallel_dims - ) - ) - - return optimizers diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 47ad01b99e..0d3f9a9f54 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -29,6 +29,7 @@ NoParallel, TensorParallel, ) +from torchtitan.distributed import NoParallel # Adapted from llama4/infra/parallelize.py diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py index 63c2b6bb82..e5d09e856a 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -2,17 +2,17 @@ # LICENSE file in the root directory of this source tree. 
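The aux-loss-free load balancing that this deleted optimizer hook implemented now comes from build_optimizers_with_moe_load_balancing. The core update rule of the removed _update_expert_bias is small enough to sketch standalone; the snippet below is illustrative only and omits the dp_cp all-reduce of tokens_per_expert:

import torch

def update_expert_bias(expert_bias, tokens_per_expert, load_balance_coeff=1e-3):
    # Raise the bias of under-used experts and lower it for over-used ones,
    # re-centering the delta so the average bias stays constant.
    with torch.no_grad():
        delta = load_balance_coeff * torch.sign(
            tokens_per_expert.mean() - tokens_per_expert
        )
        delta = delta - delta.mean()
        expert_bias.add_(delta)
        tokens_per_expert.zero_()

bias = torch.zeros(4)
counts = torch.tensor([10.0, 0.0, 5.0, 1.0])
update_expert_bias(bias, counts)
print(bias)  # experts 1 and 3 (under-used) receive a positive bias for the next routing step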
-from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Literal from torch import nn -from torchtitan.components.tokenizer import Tokenizer + from torchtitan.config_manager import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger - -# from transformers.models.gpt_oss.modeling_gpt_oss import GPT_OSS_PRETRAINED_INIT_CONFIGURATION +from torchtitan.models.moe import MoEArgs +from torchtitan.tools.utils import has_cuda_capability # Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py @@ -61,10 +61,8 @@ class GptOssModelArgs(BaseModelArgs): num_hidden_layers: int = 24 norm_eps: float = 1e-5 # eps used for RMSNorm # MoE - num_local_experts: int = 32 - num_experts_per_tok: int = 4 - use_grouped_mm: bool = True - load_balance_coeff: float | None = 1e-3 + moe_args: MoEArgs = field(default_factory=MoEArgs) + swiglu_limit: float = 7.0 # Multi-Head Latent Attention (MLA) head_dim: int = 64 num_attention_heads: int = 64 @@ -79,12 +77,24 @@ class GptOssModelArgs(BaseModelArgs): beta_fast: int = 32 beta_slow: int = 1 - def update_from_config(self, job_config: JobConfig, tokenizer: Tokenizer) -> None: - """ - Update the model_config config from the given job config. - """ - # self.vocab_size = tokenizer.vocab_size - self.max_seq_len = job_config.training.seq_len + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): + logger.warning( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + self.moe_args.use_grouped_mm = False + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: """ diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index c16cb53274..cbc2a4a25d 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -13,46 +13,6 @@ from .args import GptOssModelArgs from .moe import MoE -# TODO: may be able to remove this once parallelized properly -def convert_submodules_to_bf16( - module: nn.Module, - exclude_names: tuple[str, ...] = ("freqs_cis", "attention_norm", "ffn_norm", "norm"), - attr_opt_out: str = "no_bf16", # if a submodule sets `self.no_bf16 = True`, it will be skipped - ) -> None: - """ - Recursively convert parameters & buffers of submodules to bfloat16, - except: - - modules whose *qualified name* ends with any of `exclude_names` - - modules with attribute `{attr_opt_out} == True` - Conversion is *shallow per-module* so exclusions are respected even deep in the tree. 
- """ - - def should_skip(qname: str, mod: nn.Module) -> bool: - base = qname.rsplit(".", 1)[-1] # local (leaf) name - if base in exclude_names: - return True - if getattr(mod, attr_opt_out, False): - return True - return False - - def convert_shallow(mod: nn.Module): - # convert parameters owned by this module - for _, p in mod.named_parameters(recurse=False): - if p.is_floating_point(): - p.data = p.data.to(torch.bfloat16) - # convert buffers owned by this module - for _, b in mod.named_buffers(recurse=False): - # keep non-float buffers (e.g., ints, bool masks) as-is - if torch.is_floating_point(b): - b.data = b.data.to(torch.bfloat16) - - # walk the module tree; convert only *this* module's tensors if not skipped - for qname, mod in module.named_modules(): - # skip the root container name (empty) check gracefully - local_name = qname.rsplit(".", 1)[-1] if qname else "" - if local_name and should_skip(qname, mod): - continue - convert_shallow(mod) # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor: @@ -191,6 +151,12 @@ def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2" freqs_cis = freqs_cis[:T, :] + # Memory layout comparison for head_dim=8: + # HF Format: [r0][r1][r2][r3][i0][i1][i2][i3] + # ↑-- reals --↑ ↑-- imags --↑ + + # Interleaved: [r0][i0][r1][i1][r2][i2][r3][i3] + # ↑-pair-↑ ↑-pair-↑ ↑-pair-↑ ↑-pair-↑ # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...) # q_i, k_i: [B, T, H, D] q_i = torch.empty_like(q) @@ -202,10 +168,12 @@ def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): # --- Torchtitan default complex apply (expects interleaved last dim) # freqs_cis will be reshaped inside to [1, T, 1, rot] + # TODO(jianiw): I think we shoud go with sin/cos representation to simplify the conversion between paired real/imaginary <-> half-split real/imaginary q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis) # uses TT's complex path k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis) # --- inline: interleaved -> HF half-split + # TODO(jianiw): convert it back q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1) k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1) return q_out, k_out @@ -331,6 +299,7 @@ def forward( enable_gqa=True, ) else: + # eager attention forward output = self.attn( q, k, v, self.sinks, attention_mask=self.sliding_window_causal(seqlen, x.device), diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py index c056819758..2fbbc38e28 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -6,8 +6,10 @@ import torch.nn.functional as F from torch import nn from torchtitan.models.gpt_oss.infra.expert_parallel import expert_parallel +from torchtitan.protocols import model from .args import GptOssModelArgs +from torchtitan.models.moe import MoE, MoEArgs, GroupedExperts def swiglu(x, alpha: float = 1.702, limit: float = 7.0): x_glu, x_linear = x[..., ::2], x[..., 1::2] @@ -18,18 +20,20 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0): # Note we add an extra bias of 1 to the linear layer return out_glu * (x_linear + 1) -class GroupedExperts(nn.Module): +class GptOssGroupedExperts(nn.Module): def __init__( self, dim: int, num_experts: int, + swiglu_limit: float, use_grouped_mm: bool, ): 
super().__init__() self.num_experts = num_experts self.use_grouped_mm = use_grouped_mm + self.swiglu_limit = swiglu_limit - self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2))) + self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2))) # w1 and w3 self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2))) self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim))) self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) @@ -40,12 +44,12 @@ def forward( num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: if self.use_grouped_mm: - return GroupedExperts._run_experts_grouped_mm( - self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert + return GptOssGroupedExperts._run_experts_grouped_mm( + self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, self.swiglu_limit, x, num_tokens_per_expert ) else: - return GroupedExperts._run_experts_for_loop( - self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, x, num_tokens_per_expert + return GptOssGroupedExperts._run_experts_for_loop( + self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, self.swiglu_limit, x, num_tokens_per_expert ) # TODO: keeping this for-loop implementation for comparison @@ -57,6 +61,7 @@ def _run_experts_for_loop( mlp1_bias: torch.Tensor, mlp2_weight: torch.Tensor, mlp2_bias: torch.Tensor, + swiglu_limit: float, x: torch.Tensor, num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: @@ -77,7 +82,7 @@ def _run_experts_for_loop( out_experts_splits = [] for expert_idx, x_expert in enumerate(x): h = torch.matmul(x_expert, mlp1_weight[expert_idx]) + mlp1_bias[expert_idx] - h = swiglu(h) + h = swiglu(h, limit=swiglu_limit) h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx] out_experts_splits.append(h) out = torch.cat(out_experts_splits, dim=0) @@ -87,18 +92,19 @@ def _run_experts_for_loop( else: # x shape (num_experts, tokens_per_expert, dim) h = torch.bmm(x, mlp1_weight) + mlp1_bias.unsqueeze(1) - h = swiglu(h) + h = swiglu(h, limit=swiglu_limit) out = torch.bmm(h, mlp2_weight) + mlp2_bias.unsqueeze(1) return out - # @expert_parallel # NOTE: EP currently reduces 20B MFU from 17.8% to 16.5%! + @expert_parallel # NOTE: EP currently reduces 20B MFU from 17.8% to 16.5%! @staticmethod def _run_experts_grouped_mm( mlp1_weight: torch.Tensor, mlp1_bias: torch.Tensor, mlp2_weight: torch.Tensor, mlp2_bias: torch.Tensor, + swiglu_limit: float, x: torch.Tensor, num_tokens_per_expert: torch.Tensor | None = None, ) -> torch.Tensor: @@ -123,7 +129,7 @@ def _run_experts_grouped_mm( b1 = torch.cat([b1, b1.new_zeros((tail_slack, b1.shape[-1]))], dim=0) h = h + b1.to(h.dtype) - h = swiglu(h) + h = swiglu(h, limit=swiglu_limit) h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) if offsets is not None: b2 = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) @@ -148,177 +154,21 @@ def extra_repr(self): f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " f"mlp2_bias={tuple(self.mlp2_bias.shape)}") -class TokenChoiceTopKRouter(nn.Module): - """This class implements token-choice routing. In token-choice top-K routing, each token is - routed to top K experts based on the router scores. - Args: - dim (int): Dimension of the input. - num_experts (int): Number of experts in each moe layer. - top_k (int): Number of experts each token will be routed to in token-choice routing. 
- """ - - def __init__( - self, - dim: int, - num_experts: int, - top_k: int, - ): - super().__init__() - - self.dim = dim - self.num_experts = num_experts - self.top_k = top_k - self.gate = nn.Linear(self.dim, self.num_experts, bias=True) - - def forward( - self, x: torch.Tensor, expert_bias: torch.Tensor | None = None - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """ - TODO: We haven't implement the group-based routing (node limit routing), - and currently EP is not supporting node limit routing yet. - - Args: - x (torch.Tensor): Input tensor with shape ``(bs*slen, dim)``. - - Returns: - routed_input (torch.Tensor): - Tokens grouped together by experts indices with shape ``(bs*slen*top_k,)``. - token_indices (torch.Tensor): - Token indices for routed_input with shape ``(bs*slen*top_k,)``. - num_tokens_per_expert (torch.Tensor): - Number of tokens assigned to each expert with shape ``(num_experts,)``. - """ - # scores shape (bs*slen, num_experts) - router_logits = self.gate(x) - - if expert_bias is not None: - router_logits = router_logits + expert_bias - - # top scores shape (bs*slen, top_k) - top_scores, selected_experts_indices = torch.topk( - router_logits, k=self.top_k, dim=1 - ) - - top_scores = F.softmax(top_scores, dim=1) - - # group tokens together by expert indices from 0 to num_experts and pass that to experts forward - num_tokens_per_expert = torch.histc( - selected_experts_indices.view(-1), - bins=self.num_experts, - min=0, - max=self.num_experts, - ) - - # Reorder the token indices to match the order of the experts - # token_indices_experts_sorted shape (bs*slen*top_k,) - token_indices_experts_sorted = torch.argsort( - selected_experts_indices.view(-1), stable=True - ) - - # reorder the scores to match the order of the token indices - top_scores = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k - - return top_scores, token_indices_experts_sorted, num_tokens_per_expert - - def init_weights(self, init_std: float): - nn.init.trunc_normal_(self.gate.weight, mean=0.0, std=init_std) - - -class MoE(nn.Module): - def __init__(self, model_args: GptOssModelArgs): - - super().__init__() - dim = model_args.hidden_size - - num_experts = model_args.num_local_experts - top_k = model_args.num_experts_per_tok - - self.experts = GroupedExperts( - dim=dim, - num_experts=num_experts, - use_grouped_mm=model_args.use_grouped_mm, - ) - self.router = TokenChoiceTopKRouter( +class GptOssMoE(MoE): + """GptOss MoE implementation that inherits from the base MoE class.""" + + def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): + # Convert GptOssModelArgs to MoEArgs for base class compatibility + moe_args = model_args.moe_args + + # Initialize the base MoE class + super().__init__(moe_args, dim, hidden_dim) + + # Override the base GroupedExperts with GptOssGroupedExperts + self.experts = GptOssGroupedExperts( dim=dim, - num_experts=num_experts, - top_k=top_k, - ) - self.load_balance_coeff = model_args.load_balance_coeff - if self.load_balance_coeff is not None: - assert self.load_balance_coeff > 0.0 - self.register_buffer( - "expert_bias", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - self.register_buffer( - "tokens_per_expert", - torch.zeros(num_experts, dtype=torch.float32), - persistent=True, - ) - else: - self.expert_bias = None - - def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Args: - x (torch.Tensor): Input tensor with shape ``(bs, slen, dim)``. 
- - Returns: - out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``. - """ - bs, slen, dim = x.shape - - # top_scores and selected_indices shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - ( - top_scores, - token_indices, - num_tokens_per_expert, - ) = self.router(x.reshape(bs * slen, dim), self.expert_bias) - - if self.load_balance_coeff is not None and torch.is_grad_enabled(): - with torch.no_grad(): - self.tokens_per_expert.add_(num_tokens_per_expert) - - # shape (bs*slen*top_k, dim) - token_indices = token_indices.reshape(-1, 1).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather( - x.view(-1, dim), - dim=0, - index=token_indices, - ) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - routed_output = (routed_output.to(torch.float32) * top_scores.unsqueeze(-1)).to( - x.dtype + num_experts=moe_args.num_experts, + swiglu_limit=model_args.swiglu_limit, + use_grouped_mm=moe_args.use_grouped_mm, ) - - out = torch.zeros_like(x.reshape(bs * slen, dim)) - - # Accumulate multiple expert results becase each token can be routed to multiple experts - out = out.scatter_add(dim=0, index=token_indices, src=routed_output) - out = out.reshape(bs, slen, dim) - return out - - def init_weights( - self, - init_std: float, - buffer_device: torch.device, - ): - self.experts.init_weights(init_std) - self.router.init_weights(init_std) - if self.load_balance_coeff is not None: - with torch.device(buffer_device): - self.expert_bias = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) - self.tokens_per_expert = torch.zeros( - self.experts.num_experts, dtype=torch.float32 - ) diff --git a/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py b/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py deleted file mode 100644 index dbbb880af5..0000000000 --- a/torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py +++ /dev/null @@ -1,405 +0,0 @@ -""" -Compare logits and generations of GPT-OSS implemented in TorchTitan and HuggingFace. -This requires at least a 2xH100. 
- -First ensure you convert the HF model to a TorchTitan DCP checkpoint: -uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ - -Then you can run a comparison like this: -uv run torchtitan/experiments/gpt_oss/scripts/compare_hf_to_tt.py \ - --tt_config torchtitan/models/gpt_oss/train_configs/gpt_oss_20b.toml \ - --tt_checkpoint_path gptoss_dcp/ \ - --hf_model_path openai/gpt-oss-20b \ - --prompt "Once upon a time, in a land far away," \ - --temperature 0.8 \ - --max_new_tokens 256 \ - --batch_size 1 \ - --out -""" - -import json -import os -import sys -import time -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Sequence, Tuple, NamedTuple - -import torch -import torch.nn as nn -import torch.distributed.checkpoint as dcp -import tyro -from transformers import AutoModelForCausalLM, AutoTokenizer - -from torchtitan.tools.logging import init_logger, logger -from torchtitan.tools.utils import device_module, device_type -from torchtitan.components.metrics import build_device_memory_monitor -from torchtitan.config_manager import ConfigManager -from torchtitan.protocols.train_spec import get_train_spec -from torchtitan.distributed import ParallelDims, utils as dist_utils -from torch.distributed import DeviceMesh -from torch.distributed.elastic.multiprocessing.errors import record - -# -------- Torchtitan Sampling Utils -------- -def multinomial_sample_one( - probs: torch.Tensor, rng: Optional[torch.Generator] = None -) -> torch.Tensor: - q = torch.empty_like(probs).exponential_(1, generator=rng) - return torch.argmax(probs / q, dim=-1, keepdim=True).to(dtype=torch.long) - - -def logits_to_probs( - logits: torch.Tensor, - temperature: float = 1.0, - top_k: Optional[int] = None, -) -> torch.Tensor: - logits = logits / max(temperature, 1e-5) - - if top_k is not None: - v, _ = torch.topk(logits, k=min(top_k, logits.size(-1))) - pivot = v.select(dim=-1, index=-1).unsqueeze(-1) - logits = torch.where(logits < pivot, -float("Inf"), logits) - - probs = torch.nn.functional.softmax(logits, dim=-1) - return probs - - -def generate_next_token( - model, - x: torch.Tensor, - *, - temperature: float = 1.0, - top_k: Optional[int] = None, - rng: Optional[torch.Generator] = None, -) -> torch.Tensor: - logits = model(x) # (B, T, vocab_size) - probs = logits_to_probs(logits[:, -1, :], temperature, top_k) - next_token = multinomial_sample_one(probs, rng=rng) - return next_token - - -@torch.no_grad() -def tt_generate_text( - model, - input_ids: torch.Tensor, - *, - max_new_tokens: int, - temperature: float = 1.0, - top_k: Optional[int] = None, - seed: Optional[int] = None, -) -> torch.Tensor: - # ensure batch dimension (T,) --> (B, T) - if input_ids.ndim == 1: - input_ids = input_ids.unsqueeze(0) - - rng = None - if seed is not None: - rng = torch.Generator(input_ids.device).manual_seed(seed) - - generated_tokens = input_ids.clone() - - for i in range(max_new_tokens): - next_token = generate_next_token( - model, - x=generated_tokens.to(input_ids.device), - temperature=temperature, - top_k=top_k, - rng=rng, - ) - print(f"generated token {i}: {next_token}") - - generated_tokens = torch.cat([generated_tokens, next_token], dim=1) - - return generated_tokens - -@dataclass -class GenerateConfig: - """Configuration for test generation.""" - hf_model_path: Optional[str] = None - """HuggingFace model path to load (if provided).""" - tt_config: Optional[str] = None - """TOML config file path for TorchTitan 
model.""" - tt_checkpoint_path: Optional[str] = None - """Checkpoint path for the TorchTitan model (if provided).""" - tt_tokenizer_path: Optional[str] = "libs/torchtitan/torchtitan/models/gpt_oss_20b/tokenizer" - """Tokenizer path to load.""" - temperature: float = 1.0 - """Sampling temperature (0 for greedy).""" - max_new_tokens: int = 32 - """Max number of tokens to generate.""" - batch_size: int = 1 - """Batch size for inputs.""" - top_k: Optional[int] = None - """Top-k sampling (optional).""" - seed: Optional[int] = None - """Random seed for reproducibility.""" - deterministic: bool = False - """Use deterministic algorithms.""" - prompt: str = "" - """Input prompt string.""" - out: bool = False - """If true, print JSON report at end.""" - - -class LogitsComparison(NamedTuple): - max_abs_diff: float - mean_abs_diff: float - max_rel_diff: float - mean_rel_diff: float - allclose_results: Sequence[Tuple[float, float, str, bool]] - sample_diffs: Optional[torch.Tensor] - systematic_offset: Optional[Tuple[float, float]] - - -def load_hf_model(path: str, device: torch.device) -> nn.Module: - model = AutoModelForCausalLM.from_pretrained(path).to(device) - model.eval() - return model - -def print_param_dtypes_first_block(model): - """ - Prints the dtype of every parameter in the given model. - For any parameters under a 'layers' module (e.g., layers.), - only prints those from the first block (idx == "0"). - This works for both GptOssForCausalLM (with a .model submodule) - and GptOssModel architectures. - """ - for name, param in model.named_parameters(): - parts = name.split('.') - # If this parameter is under a 'layers' module, check its index - if 'layers' in parts: - idx = parts.index('layers') + 1 - if idx < len(parts) and parts[idx] != '0': - continue - print(f"{name:50s} → {param.dtype}") - -def get_logits(model: nn.Module, input_ids: torch.Tensor) -> torch.Tensor: - with torch.no_grad(): - out = model(input_ids) - if hasattr(out, "logits"): - return out.logits - else: - return out - - -def compare_logits( - tt_logits: torch.Tensor, - hf_logits: torch.Tensor, - tolerances: Sequence[Tuple[float, float, str]] = ( - (1e-4, 1e-6, "Very Strict"), - (1e-2, 1e-4, "Strict"), - (1e-1, 1e-2, "Moderate"), - ), -) -> LogitsComparison: - # Apply softmax to convert logits to probabilities - hf_logits = torch.nn.functional.softmax(hf_logits.float(), dim=-1) - tt_logits = torch.nn.functional.softmax(tt_logits.float(), dim=-1) - - diff = torch.abs(tt_logits - hf_logits) - max_abs = float(torch.max(diff)) - mean_abs = float(torch.mean(diff)) - rel = diff / (torch.abs(tt_logits) + 1e-8) - max_rel = float(torch.max(rel)) - mean_rel = float(torch.mean(rel)) - - results = [] - any_match = False - for rtol, atol, name in tolerances: - match = torch.allclose(tt_logits, hf_logits, rtol=rtol, atol=atol) - results.append((rtol, atol, name, bool(match))) - if match: - any_match = True - break - - sample_diffs = None - sys_offset = None - if not any_match: - flat = (tt_logits - hf_logits).flatten() - sample_diffs = flat[:25] - sys_offset = (float(torch.mean(flat)), float(torch.std(flat))) - - return LogitsComparison(max_abs, mean_abs, max_rel, mean_rel, results, sample_diffs, sys_offset) - - -def generate_text( - model: nn.Module, - input_ids: torch.Tensor, - max_new_tokens: int, - temperature: float = 0.0, - top_k: Optional[int] = None, -) -> torch.Tensor: - do_sample = temperature > 0 - temp_arg = temperature if do_sample else None - with torch.no_grad(): - return model.generate( - input_ids, - 
max_new_tokens=max_new_tokens, - do_sample=do_sample, - temperature=temp_arg, - top_k=top_k, - ) - - -def print_logits_comparison(comp: LogitsComparison): - print("\n" + "="*70) - print("LOGITS COMPARISON") - print("="*70) - print(f"Max abs diff: {comp.max_abs_diff:.6f}") - print(f"Mean abs diff: {comp.mean_abs_diff:.6f}") - print(f"Max rel diff: {comp.max_rel_diff:.6f}") - print(f"Mean rel diff: {comp.mean_rel_diff:.6f}\n") - print("Tolerance tests:") - for rtol, atol, name, match in comp.allclose_results: - print(f" {'āœ…' if match else 'āŒ'} {name} (rtol={rtol}, atol={atol})") - if comp.sample_diffs is not None: - print("\nšŸ” Sample diffs (first 25):") - for v in comp.sample_diffs.tolist(): - print(f" {v:.6f}") - mean, std = comp.systematic_offset - print(f"\nSystematic offset: mean={mean:.6f}, std={std:.6f}") - - -def print_generation(title: str, outputs: torch.Tensor, tokenizer): - text = tokenizer.decode(outputs[0].tolist()) - print("\n" + "="*60) - print(title) - print("="*60) - print(text) - print("="*60) - - -def print_generation_comparison( - tt_out: torch.Tensor, - hf_out: torch.Tensor, - tokenizer, - prompt_len: int, -): - tt_tokens = tt_out[0][prompt_len:].tolist() - hf_tokens = hf_out[0][prompt_len:].tolist() - n = min(len(tt_tokens), len(hf_tokens)) - matches = sum(1 for i in range(n) if tt_tokens[i] == hf_tokens[i]) - print("\n" + "="*70) - print("GENERATION COMPARISON") - print("="*70) - print(f"Match rate: {matches}/{n} ({matches/n*100:.1f}%)") - if matches != n or len(tt_tokens) != len(hf_tokens): - print("First mismatches:") - for i in range(min(10, n)): - if tt_tokens[i] != hf_tokens[i]: - tt_txt = tokenizer.decode([tt_tokens[i]]) - hf_txt = tokenizer.decode([hf_tokens[i]]) - print(f" Pos {i}: TT='{tt_txt}' vs HF='{hf_txt}'") - - -@record -def test_generate(args: GenerateConfig): - init_logger() - - if not args.hf_model_path and not args.tt_config: - raise ValueError("Either hf_model_path or tt_config must be provided.") - if not args.prompt: - logger.warning("Empty prompt; generating from scratch.") - - # --- Common setup: tokenizer & inputs --- - if args.hf_model_path: - tokenizer = AutoTokenizer.from_pretrained(args.hf_model_path) - input_ids = tokenizer.encode(args.prompt, add_special_tokens=False, return_tensors="pt") - print(input_ids) - if args.tt_config: - config_mgr = ConfigManager() - config = config_mgr.parse_args([ - f"--job.config_file={args.tt_config}", - f"--model.tokenizer_path={args.tt_tokenizer_path}", - ]) - train_spec = get_train_spec(config.model.name) - - # --- HuggingFace model (optional) --- - hf_model = None - hf_logits = None - hf_out = None - if args.hf_model_path: # NOTE: comment this block out for rapid tt testing - hf_device = torch.device(f"{device_type}:0") - hf_model = load_hf_model(args.hf_model_path, hf_device) - print("\n" + "="*60) - print("HUGGINGFACE MODEL ARCHITECTURE:") - print(hf_model) - print("="*60) - print_param_dtypes_first_block(hf_model) - print("="*60) - - hf_in = input_ids.to(hf_device) - hf_logits = get_logits(hf_model, hf_in).to(input_ids.device) - print(f"hf_logits: {hf_logits[:, :, 42069:42072]}") - hf_out = generate_text( - hf_model, hf_in, - max_new_tokens=args.max_new_tokens, - temperature=0.0, - top_k=args.top_k, - ).to(input_ids.device) - - # --- TorchTitan model (optional) --- - tt_model = None - tt_logits = None - tt_out = None - if args.tt_config: - # (Original TT setup: distributed, device, checkpoint load, etc.) 
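# For reference, a self-contained sketch of the kind of check compare_logits above performs,
# with the same default tolerance ladder (tensors here are whatever two models produced;
# nothing below depends on the rest of this script):

import torch

def quick_logits_check(tt_logits: torch.Tensor, hf_logits: torch.Tensor) -> None:
    # Compare post-softmax distributions, as compare_logits does above.
    tt_p = torch.softmax(tt_logits.float(), dim=-1)
    hf_p = torch.softmax(hf_logits.float(), dim=-1)
    diff = (tt_p - hf_p).abs()
    print(f"max abs diff:  {diff.max().item():.6e}")
    print(f"mean abs diff: {diff.mean().item():.6e}")
    for rtol, atol, name in [(1e-4, 1e-6, "Very Strict"), (1e-2, 1e-4, "Strict"), (1e-1, 1e-2, "Moderate")]:
        ok = torch.allclose(tt_p, hf_p, rtol=rtol, atol=atol)
        print(f"{name} (rtol={rtol}, atol={atol}): {'pass' if ok else 'fail'}")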
- world_size = int(os.environ.get("WORLD_SIZE", 1)) - device = torch.device(f"{device_type}:1") - device_module.set_device(device) - dist_utils.set_determinism(None, device, args.seed, args.deterministic) - - # instantiate & load TT model - model_args = train_spec.config[config.model.flavor] - model_args.update_from_config(config, tokenizer) - init_dev = "meta" if world_size > 1 else device - with torch.device(init_dev): - tt_model = train_spec.cls(model_args) - if world_size > 1: - # parallelize if needed - pass - print("\n" + "="*60) - print("TORCHTITAN MODEL ARCHITECTURE:") - print(tt_model) - print("="*60) - print_param_dtypes_first_block(tt_model) - print("="*60) - - tt_model.eval() - if args.tt_checkpoint_path: # only load checkpoint if provided - tt_state = tt_model.state_dict() - tt_state.pop("freqs_cis", None) - state = {"model": tt_state} - dcp.load(state, checkpoint_id=args.tt_checkpoint_path) - - tt_logits = get_logits(tt_model, input_ids.to(device)).to(hf_logits.device if hf_logits is not None else device) - print(f"āœ… Torchtitan model forward pass succeeded: {tt_logits.shape=}") - print(f"tt_logits: {tt_logits[:, :, 42069:42072]}") - - tt_out = tt_generate_text( - tt_model, input_ids.to(device), - max_new_tokens=args.max_new_tokens, - temperature=args.temperature, - top_k=args.top_k, - seed=args.seed, - ) - - # --- Logits comparison (if both present) --- - if hf_logits is not None and tt_logits is not None: - comp = compare_logits(tt_logits, hf_logits) - print_logits_comparison(comp) - - # --- Print generations --- - if hf_out is not None: - print_generation("HUGGINGFACE MODEL OUTPUT:", hf_out, tokenizer) - if tt_out is not None: - print_generation("TORCHTITAN MODEL OUTPUT:", tt_out, tokenizer) - - # --- Generation comparison --- - if hf_out is not None and tt_out is not None: - prompt_len = input_ids.size(1) - print_generation_comparison(tt_out, hf_out, tokenizer, prompt_len) - - -if __name__ == "__main__": - args = tyro.cli(GenerateConfig) - test_generate(args) diff --git a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py b/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py deleted file mode 100644 index 59c15ab944..0000000000 --- a/torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py +++ /dev/null @@ -1,661 +0,0 @@ -""" -Convert checkpoints between TorchTitan and HuggingFace. 
- -# Convert HF to TorchTitan DCP -uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py hf-to-dcp --input-path openai/gpt-oss-20b --output-path gptoss_dcp/ - -# Convert TorchTitan DCP to HF -uv run torchtitan/experiments/gpt_oss/scripts/convert_gptoss.py dcp-to-hf --input-path gptoss_dcp/ --output-path gptoss_hf/ -""" - -import re -import tempfile -from pathlib import Path -from typing import Union, Tuple, Optional - -import torch -import torch.distributed.checkpoint as DCP -from torch.distributed.checkpoint.format_utils import dcp_to_torch_save -from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner -from torch.distributed.checkpoint.state_dict_loader import _load_state_dict -from torchtitan.datasets.transformation import get_tokenizer_with_chat_template -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, LlamaConfig -from torchtitan.models.llama3.model import precompute_freqs_cis -from tqdm import tqdm -from tyro.extras import SubcommandApp - -from torchtitan.tools.logging import init_logger, logger - -app = SubcommandApp() - - - -def validate_hf_keys(hf_state_dict, model_config, model_name): - """Validate that all expected weight keys exist in the HF state dict.""" - missing_keys = [] - n_layers = model_config.num_hidden_layers - - # Check basic weights - required_keys = [ - "model.embed_tokens.weight", - "lm_head.weight", - "model.norm.weight" - ] - - for key in required_keys: - if key not in hf_state_dict: - missing_keys.append(key) - - # Check layer weights - for layer_idx in range(n_layers): - layer_prefix = f'model.layers.{layer_idx}' - - # Check attention weights - attention_keys = [ - f"{layer_prefix}.self_attn.q_proj.weight", - f"{layer_prefix}.self_attn.k_proj.weight", - f"{layer_prefix}.self_attn.v_proj.weight", - f"{layer_prefix}.self_attn.o_proj.weight", - f"{layer_prefix}.self_attn.q_proj.bias", - f"{layer_prefix}.self_attn.k_proj.bias", - f"{layer_prefix}.self_attn.v_proj.bias", - f"{layer_prefix}.self_attn.o_proj.bias", - f"{layer_prefix}.input_layernorm.weight", - f"{layer_prefix}.post_attention_layernorm.weight", - ] - - for key in attention_keys: - if key not in hf_state_dict: - missing_keys.append(key) - - # Check MoE weights - mlp_keys = [ - f"{layer_prefix}.mlp.router.weight", - f"{layer_prefix}.mlp.router.bias", - f"{layer_prefix}.mlp.experts.gate_up_proj", - f"{layer_prefix}.mlp.experts.gate_up_proj_bias", - f"{layer_prefix}.mlp.experts.down_proj", - f"{layer_prefix}.mlp.experts.down_proj_bias", - ] - - for key in mlp_keys: - if key not in hf_state_dict: - missing_keys.append(key) - - if missing_keys: - logger.error(f"Missing {len(missing_keys)} expected weight keys in HF model:") - for key in missing_keys[:10]: # Show first 10 - logger.error(f" - {key}") - if len(missing_keys) > 10: - logger.error(f" ... and {len(missing_keys) - 10} more") - - # Try to diagnose the issue - logger.info("Available keys in HF model:") - available_keys = list(hf_state_dict.keys()) - for key in available_keys[:20]: # Show first 20 - logger.info(f" - {key}") - if len(available_keys) > 20: - logger.info(f" ... and {len(available_keys) - 20} more") - - raise ValueError(f"HF model '{model_name}' is missing expected weight keys. 
" - f"This suggests the model architecture doesn't match expectations.") - - logger.info(f"āœ“ Weight key validation passed - found all expected keys") - - -def map_hf_to_torchtitan(hf_state_dict, model_config, max_seq_len=131072, rope_theta=500000.0, model_name="meta-llama/Llama-3.1-8B"): - """Map HuggingFace state dict to TorchTitan format. - - Note: TorchTitan and HuggingFace use different RoPE implementations: - - TorchTitan: Adjacent element pairing with complex arithmetic - - HuggingFace: First/second half pairing with cos/sin arithmetic - - This difference is architectural, not a bug. Converted models will have - slightly different positional encoding but typically minimal impact on performance. - """ - - # Validate that all expected keys exist - validate_hf_keys(hf_state_dict, model_config, model_name) - - n_layers = model_config.num_hidden_layers - n_heads = model_config.num_attention_heads - dim = model_config.hidden_size - dims_per_head = dim // n_heads - - # Fix: Corrected model family detection logic - if "llama" in model_name.lower(): - model_family = "llama3" - elif "qwen" in model_name.lower(): - model_family = "qwen3" - max_seq_len = model_config.max_position_embeddings - rope_theta = model_config.rope_theta - elif "gpt-oss" in model_name.lower(): - model_family = "gptoss" - max_seq_len = model_config.max_position_embeddings - rope_theta = model_config.rope_theta - else: - raise ValueError(f"Unsupported HuggingFace model for conversion: {model_name}") - - # Determine n_kv_heads for GQA models - n_kv_heads = model_config.num_key_value_heads - head_dim = model_config.head_dim - print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}, max_seq_len={max_seq_len}, rope_theta={rope_theta}") - torchtitan_state_dict = {} - - # Convert embeddings and output - torchtitan_state_dict["tok_embeddings.weight"] = hf_state_dict["model.embed_tokens.weight"].clone() - torchtitan_state_dict["output.weight"] = hf_state_dict["lm_head.weight"].clone() - torchtitan_state_dict["norm.weight"] = hf_state_dict["model.norm.weight"].clone() - - def permute(w, n_heads_arg, dim1=None, dim2=None): - if dim1 is None: - dim1 = w.shape[0] - if dim2 is None: - dim2 = w.shape[1] - return w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - # Convert layers - for layer_idx in tqdm(range(n_layers), desc="Converting layers"): - hf_layer_prefix = f'model.layers.{layer_idx}' - layer_prefix = f'layers.{layer_idx}' - - wq = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] - torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] = wq.clone() - wq_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.bias'] - torchtitan_state_dict[f'{layer_prefix}.attention.wq.bias'] = wq_bias.clone() - - wk = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] - torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] = wk.clone() - wk_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.bias'] - torchtitan_state_dict[f'{layer_prefix}.attention.wk.bias'] = wk_bias.clone() - - wv = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] - torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'] = wv.clone() - wv_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.bias'] - torchtitan_state_dict[f'{layer_prefix}.attention.wv.bias'] = wv_bias.clone() - - wo = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] - torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'] 
= wo.clone() - wo_bias = hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.bias'] - torchtitan_state_dict[f'{layer_prefix}.attention.wo.bias'] = wo_bias.clone() - - sinks = hf_state_dict[f'{hf_layer_prefix}.self_attn.sinks'] - torchtitan_state_dict[f'{layer_prefix}.attention.sinks'] = sinks.clone() - - # MoE weights - mlp1 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj'] - torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_weight'] = mlp1.clone() - - mlp1_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias'] - torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp1_bias'] = mlp1_bias.clone() - - mlp2 = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj'] - torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_weight'] = mlp2.clone() - - mlp2_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.experts.down_proj_bias'] - torchtitan_state_dict[f'{layer_prefix}.moe.experts.mlp2_bias'] = mlp2_bias.clone() - - # router - gate = hf_state_dict[f'{hf_layer_prefix}.mlp.router.weight'] - torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.weight'] = gate.clone() - router_bias = hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] - torchtitan_state_dict[f'{layer_prefix}.moe.router.gate.bias'] = router_bias.clone() - - # # @vwxyzjn: This is technically not needed, but we added here because we haven't figured out - # # how to tell torchtitan to ignore this parameter. - # tokens_per_expert = torch.zeros_like(expert_bias) - # torchtitan_state_dict[f'{layer_prefix}.moe.tokens_per_expert'] = tokens_per_expert.clone() - - # Layer norms - attention_norm = hf_state_dict[f'{hf_layer_prefix}.input_layernorm.weight'] - torchtitan_state_dict[f'{layer_prefix}.attention_norm.weight'] = attention_norm.clone() - ffn_norm = hf_state_dict[f'{hf_layer_prefix}.post_attention_layernorm.weight'] - torchtitan_state_dict[f'{layer_prefix}.ffn_norm.weight'] = ffn_norm.clone() - - # Precompute RoPE frequencies - # NOTE: we no longer precompute RoPE frequencies in TorchTitan - # this `model_config` is HF but needs to be TT (to include e.g. beta_fast) - # torchtitan_state_dict["freqs_cis"] = precompute_freqs_cis(model_config) - - print(f"Converted {len(torchtitan_state_dict)} parameters from HuggingFace to TorchTitan format") - return torchtitan_state_dict - - -def map_torchtitan_to_hf_per_param(name: str, weight: torch.Tensor, model_family: str = "llama3") -> Tuple[Optional[str], Optional[torch.Tensor]]: - """Map a single TorchTitan parameter to HuggingFace format. 
- - Args: - name: Parameter name in TorchTitan format - weight: Parameter tensor - model_family: Model family ("llama3", "qwen3", or "gptoss") - - Returns: - Tuple of (hf_name, hf_weight) or (None, None) if parameter should be skipped - """ - # Skip freqs_cis as it's computed dynamically in HF - if name == "freqs_cis": - return None, None - - assert model_family in ("llama3", "qwen3", "gptoss"), f"Unsupported model family: {model_family}" - - # HuggingFace permutation function (exact copy from their conversion script) - def permute(w, n_heads_arg, dim1=None, dim2=None): - if dim1 is None: - dim1 = w.shape[0] - if dim2 is None: - dim2 = w.shape[1] - return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - # Handle embeddings and output weights - if name == "tok_embeddings.weight": - return "model.embed_tokens.weight", weight.clone() - elif name == "output.weight": - return "lm_head.weight", weight.clone() - elif name == "norm.weight": - return "model.norm.weight", weight.clone() - - # Handle layer-specific parameters - layer_match = re.match(r"layers\.(\d+)\.", name) - if not layer_match: - return None, None - - layer_idx = layer_match.group(1) - layer_suffix = name[len(f"layers.{layer_idx}."):] - hf_layer_prefix = f"model.layers.{layer_idx}" - - if model_family == "gptoss": - mapping = { - "attention.wq.weight": "self_attn.q_proj.weight", - "attention.wq.bias": "self_attn.q_proj.bias", - "attention.wk.weight": "self_attn.k_proj.weight", - "attention.wk.bias": "self_attn.k_proj.bias", - "attention.wv.weight": "self_attn.v_proj.weight", - "attention.wv.bias": "self_attn.v_proj.bias", - "attention.wo.weight": "self_attn.o_proj.weight", - "attention.wo.bias": "self_attn.o_proj.bias", - "attention.sinks": "self_attn.sinks", - "moe.experts.mlp1_weight": "mlp.experts.gate_up_proj", - "moe.experts.mlp1_bias": "mlp.experts.gate_up_proj_bias", - "moe.experts.mlp2_weight": "mlp.experts.down_proj", - "moe.experts.mlp2_bias": "mlp.experts.down_proj_bias", - "moe.router.gate.weight": "mlp.router.weight", - "moe.router.gate.bias": "mlp.router.bias", - "moe.expert_bias": "mlp.router.bias", # NOTE: this gets added into router bias - "attention_norm.weight": "input_layernorm.weight", - "ffn_norm.weight": "post_attention_layernorm.weight", - } - hf_suffix = mapping.get(layer_suffix) - if hf_suffix: - return f"{hf_layer_prefix}.{hf_suffix}", weight.clone() - return None, None - - # Handle attention weights - if layer_suffix == "attention.wq.weight": - if model_family == "llama3": - # For query weights, assume standard head_dim=128 - dim = weight.shape[1] - head_dim = 128 - n_heads = dim // head_dim - transformed_weight = permute(weight, n_heads) - else: - transformed_weight = weight - return f"{hf_layer_prefix}.self_attn.q_proj.weight", transformed_weight.clone() - - elif layer_suffix == "attention.wk.weight": - if model_family == "llama3": - # For key weights, infer n_kv_heads from weight shape - dim = weight.shape[1] - head_dim = 128 - n_kv_heads = weight.shape[0] // head_dim - key_value_dim = n_kv_heads * head_dim - transformed_weight = permute(weight, n_kv_heads, key_value_dim, dim) - else: - transformed_weight = weight - return f"{hf_layer_prefix}.self_attn.k_proj.weight", transformed_weight.clone() - - elif layer_suffix == "attention.wv.weight": - return f"{hf_layer_prefix}.self_attn.v_proj.weight", weight.clone() - - elif layer_suffix == "attention.wo.weight": - return f"{hf_layer_prefix}.self_attn.o_proj.weight", weight.clone() - - # Handle qwen3-specific 
attention norms - elif layer_suffix == "attention.q_norm.weight" and model_family == "qwen3": - return f"{hf_layer_prefix}.self_attn.q_norm.weight", weight.clone() - - elif layer_suffix == "attention.k_norm.weight" and model_family == "qwen3": - return f"{hf_layer_prefix}.self_attn.k_norm.weight", weight.clone() - - # Handle MLP weights - elif layer_suffix == "feed_forward.w1.weight": - return f"{hf_layer_prefix}.mlp.gate_proj.weight", weight.clone() - - elif layer_suffix == "feed_forward.w2.weight": - return f"{hf_layer_prefix}.mlp.down_proj.weight", weight.clone() - - elif layer_suffix == "feed_forward.w3.weight": - return f"{hf_layer_prefix}.mlp.up_proj.weight", weight.clone() - - # Handle layer norms - elif layer_suffix == "attention_norm.weight": - return f"{hf_layer_prefix}.input_layernorm.weight", weight.clone() - - elif layer_suffix == "ffn_norm.weight": - return f"{hf_layer_prefix}.post_attention_layernorm.weight", weight.clone() - - # If no mapping found, return None - return None, None - - -def map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0): - """Map TorchTitan state dict to HuggingFace format.""" - if any(k.endswith('.attention.q_norm.weight') for k in torchtitan_state_dict): - model_family = 'qwen3' - elif any(k.endswith('.attention.wq.bias') for k in torchtitan_state_dict): - model_family = 'gptoss' - else: - model_family = 'llama3' - - layer_keys = [k for k in torchtitan_state_dict.keys() if k.startswith("layers.")] - assert len(layer_keys) > 0, "No layers found in state dict" - n_layers = max([int(k.split(".")[1]) for k in layer_keys]) + 1 - hf_state_dict = {} - - # Get model info from sample weight - sample_wq_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wq.weight')) - wq_weight = torchtitan_state_dict[sample_wq_key] - dim = wq_weight.shape[1] # input dimension - - # Check if we have a key weight to determine n_kv_heads - sample_wk_key = next(k for k in torchtitan_state_dict.keys() if k.endswith('.attention.wk.weight')) - wk_weight = torchtitan_state_dict[sample_wk_key] - - # Standard Llama head dim is 128 for the 3B, 8B, 70B and 405B models - # NOTE: The only exception is the 1B model: https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json#L9 - # But let's ignore that for now - head_dim = 128 - n_heads = dim // head_dim - - # For GQA models, n_kv_heads might be different - n_kv_heads = wk_weight.shape[0] // head_dim - - print(f"Model info: dim={dim}, n_heads={n_heads}, n_kv_heads={n_kv_heads}, head_dim={head_dim}, model_family={model_family}") - - # HuggingFace permutation function (exact copy from their conversion script) - def permute(w, n_heads_arg, dim1=None, dim2=None): - if dim1 is None: - dim1 = w.shape[0] - if dim2 is None: - dim2 = w.shape[1] - return w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) - - # Convert embeddings and output (no permutation needed) - if 'tok_embeddings.weight' in torchtitan_state_dict: - hf_state_dict['model.embed_tokens.weight'] = torchtitan_state_dict['tok_embeddings.weight'].clone() - if 'output.weight' in torchtitan_state_dict: - hf_state_dict['lm_head.weight'] = torchtitan_state_dict['output.weight'].clone() - if 'norm.weight' in torchtitan_state_dict: - hf_state_dict['model.norm.weight'] = torchtitan_state_dict['norm.weight'].clone() - - # Convert layers - for layer_idx in tqdm(range(n_layers), desc="Converting layers"): - layer_prefix = f'layers.{layer_idx}' - hf_layer_prefix = f'model.layers.{layer_idx}' 
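The permute helper used on this TT-to-HF path appears to be the inverse of the permute defined in map_hf_to_torchtitan above (they reorder rows within each head between the two rotary layouts; the gptoss branch skips the permutation entirely). A quick self-contained round-trip check, with illustrative sizes:

import torch

def permute_hf_to_tt(w, n_heads):
    d1, d2 = w.shape
    return w.view(n_heads, 2, d1 // n_heads // 2, d2).transpose(1, 2).reshape(d1, d2)

def permute_tt_to_hf(w, n_heads):
    d1, d2 = w.shape
    return w.view(n_heads, d1 // n_heads // 2, 2, d2).transpose(1, 2).reshape(d1, d2)

w = torch.randn(32 * 128, 4096)  # (n_heads * head_dim, dim); sizes are illustrative
assert torch.equal(permute_tt_to_hf(permute_hf_to_tt(w, 32), 32), w)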
- - if model_family == 'gptoss': - # Attention projections and biases - mappings = { - f'{layer_prefix}.attention.wq.weight': f'{hf_layer_prefix}.self_attn.q_proj.weight', - f'{layer_prefix}.attention.wq.bias': f'{hf_layer_prefix}.self_attn.q_proj.bias', - f'{layer_prefix}.attention.wk.weight': f'{hf_layer_prefix}.self_attn.k_proj.weight', - f'{layer_prefix}.attention.wk.bias': f'{hf_layer_prefix}.self_attn.k_proj.bias', - f'{layer_prefix}.attention.wv.weight': f'{hf_layer_prefix}.self_attn.v_proj.weight', - f'{layer_prefix}.attention.wv.bias': f'{hf_layer_prefix}.self_attn.v_proj.bias', - f'{layer_prefix}.attention.wo.weight': f'{hf_layer_prefix}.self_attn.o_proj.weight', - f'{layer_prefix}.attention.wo.bias': f'{hf_layer_prefix}.self_attn.o_proj.bias', - f'{layer_prefix}.attention.sinks': f'{hf_layer_prefix}.self_attn.sinks', - f'{layer_prefix}.moe.experts.mlp1_weight': f'{hf_layer_prefix}.mlp.experts.gate_up_proj', - f'{layer_prefix}.moe.experts.mlp1_bias': f'{hf_layer_prefix}.mlp.experts.gate_up_proj_bias', - f'{layer_prefix}.moe.experts.mlp2_weight': f'{hf_layer_prefix}.mlp.experts.down_proj', - f'{layer_prefix}.moe.experts.mlp2_bias': f'{hf_layer_prefix}.mlp.experts.down_proj_bias', - f'{layer_prefix}.moe.router.gate.weight': f'{hf_layer_prefix}.mlp.router.weight', - f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight', - f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight', - } - for tt_key, hf_key in mappings.items(): - if tt_key in torchtitan_state_dict: - hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() - # Combine router gate bias with expert bias (if present) - router_bias_key = f'{layer_prefix}.moe.router.gate.bias' - expert_bias_key = f'{layer_prefix}.moe.expert_bias' - if ( - router_bias_key in torchtitan_state_dict - or expert_bias_key in torchtitan_state_dict - ): - if router_bias_key in torchtitan_state_dict: - bias = torchtitan_state_dict[router_bias_key].clone() - else: - bias = torch.zeros_like(torchtitan_state_dict[expert_bias_key]) - if expert_bias_key in torchtitan_state_dict: - bias = bias + torchtitan_state_dict[expert_bias_key] - hf_state_dict[f'{hf_layer_prefix}.mlp.router.bias'] = bias - continue - - # Attention weights with proper permutation - if f'{layer_prefix}.attention.wq.weight' in torchtitan_state_dict: - wq = torchtitan_state_dict[f'{layer_prefix}.attention.wq.weight'] - if model_family == "llama3": - wq = permute(wq, n_heads) - hf_state_dict[f'{hf_layer_prefix}.self_attn.q_proj.weight'] = wq.clone() - - if f'{layer_prefix}.attention.wk.weight' in torchtitan_state_dict: - wk = torchtitan_state_dict[f'{layer_prefix}.attention.wk.weight'] - key_value_dim = n_kv_heads * head_dim - if model_family == "llama3": - wk = permute(wk, n_kv_heads, key_value_dim, dim) - hf_state_dict[f'{hf_layer_prefix}.self_attn.k_proj.weight'] = wk.clone() - - if f'{layer_prefix}.attention.wv.weight' in torchtitan_state_dict: - # Value weights don't get permuted - hf_state_dict[f'{hf_layer_prefix}.self_attn.v_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wv.weight'].clone() - - if model_family == "qwen3": - if f'{layer_prefix}.attention.q_norm.weight' in torchtitan_state_dict: - hf_state_dict[f'{hf_layer_prefix}.self_attn.q_norm.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.q_norm.weight'].clone() - if f'{layer_prefix}.attention.k_norm.weight' in torchtitan_state_dict: - hf_state_dict[f'{hf_layer_prefix}.self_attn.k_norm.weight'] = 
torchtitan_state_dict[f'{layer_prefix}.attention.k_norm.weight'].clone() - - if f'{layer_prefix}.attention.wo.weight' in torchtitan_state_dict: - # Output projection doesn't get permuted - hf_state_dict[f'{hf_layer_prefix}.self_attn.o_proj.weight'] = torchtitan_state_dict[f'{layer_prefix}.attention.wo.weight'].clone() - - # MLP weights (no permutation) - mlp_mappings = { - f'{layer_prefix}.feed_forward.w1.weight': f'{hf_layer_prefix}.mlp.gate_proj.weight', - f'{layer_prefix}.feed_forward.w2.weight': f'{hf_layer_prefix}.mlp.down_proj.weight', - f'{layer_prefix}.feed_forward.w3.weight': f'{hf_layer_prefix}.mlp.up_proj.weight', - } - - for tt_key, hf_key in mlp_mappings.items(): - if tt_key in torchtitan_state_dict: - hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() - - # Layer norms (no permutation) - norm_mappings = { - f'{layer_prefix}.attention_norm.weight': f'{hf_layer_prefix}.input_layernorm.weight', - f'{layer_prefix}.ffn_norm.weight': f'{hf_layer_prefix}.post_attention_layernorm.weight', - } - - for tt_key, hf_key in norm_mappings.items(): - if tt_key in torchtitan_state_dict: - hf_state_dict[hf_key] = torchtitan_state_dict[tt_key].clone() - - print(f"Converted {len(hf_state_dict)} parameters from TorchTitan to HuggingFace format") - return hf_state_dict - - -def map_torchtitan_to_hf2(torchtitan_state_dict, max_seq_len=131072, rope_theta=500000.0, validate_against_original=True): - """Map TorchTitan state dict to HuggingFace format using per-parameter function.""" - - # Auto-detect model family - if any(k.endswith('.attention.q_norm.weight') for k in torchtitan_state_dict): - model_family = "qwen3" - elif any(k.endswith('.attention.wq.bias') for k in torchtitan_state_dict): - model_family = "gptoss" - else: - model_family = "llama3" - - logger.info(f"Converting using per-parameter function with model_family={model_family}") - - hf_state_dict = {} - skipped_params = [] - - # Convert each parameter individually - for name, weight in tqdm(torchtitan_state_dict.items(), desc="Converting parameters"): - hf_name, hf_weight = map_torchtitan_to_hf_per_param(name, weight, model_family) - if hf_name is not None: - if hf_name in hf_state_dict: - hf_state_dict[hf_name] = hf_state_dict[hf_name] + hf_weight # NOTE: adds expert_bias into router bias - else: - hf_state_dict[hf_name] = hf_weight - else: - skipped_params.append(name) - - logger.info(f"Converted {len(hf_state_dict)} parameters, skipped {len(skipped_params)} parameters") - if skipped_params: - logger.info(f"Skipped parameters: {skipped_params}") - - # Validation against original function - if validate_against_original: - logger.info("Validating against original conversion function...") - - # Get original result - original_hf_state_dict = map_torchtitan_to_hf(torchtitan_state_dict, max_seq_len, rope_theta) - - # Compare keys - new_keys = set(hf_state_dict.keys()) - original_keys = set(original_hf_state_dict.keys()) - - if new_keys != original_keys: - missing_in_new = original_keys - new_keys - extra_in_new = new_keys - original_keys - logger.error(f"Key mismatch! 
Missing in new: {missing_in_new}, Extra in new: {extra_in_new}") - raise ValueError("Key sets don't match between implementations") - - # Compare tensor values - mismatched_tensors = [] - for key in original_keys: - if not torch.allclose(hf_state_dict[key], original_hf_state_dict[key], rtol=1e-5, atol=1e-8): - mismatched_tensors.append(key) - - if mismatched_tensors: - logger.error(f"Tensor value mismatches in: {mismatched_tensors}") - # Show details for first mismatch - key = mismatched_tensors[0] - logger.error(f"First mismatch in {key}:") - logger.error(f" Max abs diff: {torch.max(torch.abs(hf_state_dict[key] - original_hf_state_dict[key]))}") - logger.error(f" Original shape: {original_hf_state_dict[key].shape}") - logger.error(f" New shape: {hf_state_dict[key].shape}") - raise ValueError("Tensor values don't match between implementations") - - logger.info("āœ“ Validation passed! New implementation matches original.") - - return hf_state_dict - - -@app.command(name="hf_to_dcp") -@torch.inference_mode() -def convert_hf_to_dcp(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, dtype: str = "auto"): - """Convert HuggingFace model to TorchTitan DCP format. - - Args: - input_path: HuggingFace model name or path - output_path: Output DCP checkpoint path - max_seq_len: Max sequence length for RoPE - rope_theta: RoPE theta parameter - dtype: Data type to use ("auto" to preserve original, or specific dtype like "float32") - """ - logger.info(f"Loading model from {input_path}") - - # Load model with original dtype if "auto", otherwise use specified dtype - hf_model = AutoModelForCausalLM.from_pretrained(input_path, torch_dtype=torch.bfloat16) - - hf_state_dict = hf_model.state_dict() - logger.info(f"Loaded model with dtype: {next(iter(hf_state_dict.values())).dtype}") - - logger.info("Converting weights to TorchTitan format") - torchtitan_state_dict = map_hf_to_torchtitan(hf_state_dict, hf_model.config, max_seq_len, rope_theta, input_path) - - logger.info(f"Writing to DCP at '{output_path}'") - output_path.mkdir(parents=True, exist_ok=True) - storage_writer = DCP.filesystem.FileSystemWriter(output_path, thread_count=8) - DCP.save({"model": torchtitan_state_dict}, storage_writer=storage_writer) - logger.info("Conversion complete!") - - -@app.command(name="dcp_to_hf") -@torch.inference_mode() -def convert_dcp_to_hf(input_path: str, output_path: Path, max_seq_len: int = 131072, rope_theta: float = 500000.0, default_model: str = "meta-llama/Meta-Llama-3.1-8B", validate_against_original: bool = False): - """Convert TorchTitan DCP format to HuggingFace model. 
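A minimal sketch of the DCP round trip the two conversion commands rely on, using a toy state dict and a temporary directory; the FileSystemWriter/FileSystemReader and planner usage follow the calls made elsewhere in this file:

import tempfile
import torch
import torch.distributed.checkpoint as DCP
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

w = torch.randn(4, 4)
with tempfile.TemporaryDirectory() as ckpt_dir:
    # save: same writer call as the hf_to_dcp command
    DCP.save({"model": {"w": w}},
             storage_writer=DCP.filesystem.FileSystemWriter(ckpt_dir, thread_count=8))
    # load without a process group: same planner trick used later in this function
    state: dict = {}
    _load_state_dict(state,
                     storage_reader=DCP.filesystem.FileSystemReader(ckpt_dir),
                     planner=_EmptyStateDictLoadPlanner(),
                     no_dist=True)
    assert torch.allclose(state["model"]["w"], w)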
- - Args: - input_path: Input DCP checkpoint path - output_path: Output HuggingFace model path - max_seq_len: Max sequence length for RoPE - rope_theta: RoPE theta parameter - default_model: Default HuggingFace model for config - """ - - if str(input_path).startswith("s3://"): - import s3_utils - local_path = s3_utils.sync_to_nvme(str(input_path)) - input_path = Path(local_path) - - logger.info(f"Loading DCP checkpoint from {input_path}") - - from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner - from torch.distributed.checkpoint.state_dict_loader import _load_state_dict - # Load DCP input_path - state_dict = {} - _load_state_dict( - state_dict, - storage_reader=DCP.filesystem.FileSystemReader(input_path), - planner=_EmptyStateDictLoadPlanner(), - no_dist=True, - ) - torchtitan_state_dict = state_dict["model"] - logger.info("Converting weights to HuggingFace format") - hf_state_dict = map_torchtitan_to_hf2(torchtitan_state_dict, max_seq_len, rope_theta, validate_against_original=validate_against_original) - - if '/' not in default_model: - if 'qwen' in default_model.lower(): - default_model = f'Qwen/{default_model}' - elif 'llama' in default_model.lower(): - default_model = f'meta-llama/{default_model}' - else: - raise ValueError(f"Unsupported model: {default_model}") - - # Create HuggingFace config - hf_config = AutoConfig.from_pretrained(default_model) - - # Create and load model - logger.info("Creating HuggingFace model") - # tokenizer = AutoTokenizer.from_pretrained(default_model) - tokenizer = get_tokenizer_with_chat_template(default_model, "tulu", override=True) - hf_model = AutoModelForCausalLM.from_pretrained(default_model, device_map="auto") # NOTE: need device_map="auto" to avoid CPU OOM - - # load state dict - logger.info("Loading state dict") - hf_model.load_state_dict(hf_state_dict, strict=True) - - # Save model - logger.info(f"Saving model to {output_path}") - output_path.mkdir(parents=True, exist_ok=True) - hf_model.save_pretrained(output_path) - tokenizer.save_pretrained(output_path) - logger.info("Conversion complete!") - - -if __name__ == "__main__": - init_logger() - app.cli() diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index 878e478ff5..78c8aface0 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -24,7 +24,7 @@ enable_wandb = false name = "gpt_oss" flavor = "debugmodel" # test tokenizer, for debug purpose only -tokenizer_path = "./tests/assets/tokenizer" +hf_asset_path = "./tests/assets/tokenizer" # converters = ["float8"] [optimizer] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml deleted file mode 100644 index 81908972ad..0000000000 --- a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_120b.toml +++ /dev/null @@ -1,70 +0,0 @@ -# torchtitan Config.toml - -[job] -dump_folder = "./outputs" -description = "GPT-OSS 120B model training" -print_args = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 10 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "gpt_oss" -flavor = "120B" -tokenizer_path = "./assets/tokenizer/GPT-OSS" -# 
converters = ["float8"] - -[optimizer] -name = "AdamW" -lr = 2.2e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps -decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "linear" -lr_min = 2.2e-5 - -[training] -local_batch_size = 4 -seq_len = 4096 -max_norm = 1.0 # grad norm clipping -steps = 10_000 -compile = false -dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 8 -enable_async_tensor_parallel = false -expert_parallel_degree = 1 - -[checkpoint] -enable_checkpoint = false -folder = "checkpoint" -interval = 500 -last_save_model_weights_only = false -export_dtype = "float32" -async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" - -[activation_checkpoint] -mode = "full" # ["none", "selective", "full"] - -[float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] -moe_fqns = ["experts"] diff --git a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml b/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml deleted file mode 100644 index 88d1c4d27f..0000000000 --- a/torchtitan/experiments/gpt_oss/train_configs/gpt_oss_20b.toml +++ /dev/null @@ -1,70 +0,0 @@ -# torchtitan Config.toml - -[job] -dump_folder = "./outputs" -description = "GPT-OSS 20B model training" -print_args = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 10 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "gpt_oss" -flavor = "20B" -tokenizer_path = "./assets/tokenizer/GPT-OSS" -# converters = ["float8"] - -[optimizer] -name = "AdamW" -lr = 2.2e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps -decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "linear" -lr_min = 2.2e-5 - -[training] -local_batch_size = 8 -seq_len = 4096 -max_norm = 1.0 # grad norm clipping -steps = 1000 -compile = false -dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -expert_parallel_degree = 1 - -[checkpoint] -enable_checkpoint = false -folder = "checkpoint" -interval = 10 -last_save_model_weights_only = false -export_dtype = "float32" -async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" - -[activation_checkpoint] -mode = "full" # ["none", "selective", "full"] - -[float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] -moe_fqns = ["experts"] From 122e93a18ad5b62a00cbaf9775c065e3b57e16e8 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 23 Sep 2025 16:00:43 -0700 Subject: [PATCH 08/18] rebase --- torchtitan/experiments/__init__.py | 2 +- torchtitan/experiments/gpt_oss/__init__.py | 11 +- .../gpt_oss/infra/expert_parallel.py | 2 +- .../experiments/gpt_oss/infra/parallelize.py | 148 ++++++------------ 
torchtitan/experiments/gpt_oss/model/args.py | 5 +- torchtitan/experiments/gpt_oss/model/model.py | 23 ++- torchtitan/experiments/gpt_oss/model/moe.py | 10 +- .../gpt_oss/train_configs/debug_model.toml | 36 +++-- torchtitan/models/attention.py | 83 ++++++---- 9 files changed, 157 insertions(+), 163 deletions(-) diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 74c7eaec9b..b73ddc8458 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp", "vlm"]) +_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp", "vlm", "gpt_oss"]) diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index 8603588f55..e4e288c09f 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -6,7 +6,6 @@ from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.moe import MoEArgs -from .infra.optimizer import build_gptoss_optimizers from torchtitan.protocols.train_spec import register_train_spec, TrainSpec from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing @@ -27,7 +26,6 @@ "debugmodel": GptOssModelArgs( hidden_size=256, num_hidden_layers=4, - use_flex_attn=False, moe_args = MoEArgs( num_experts=8, num_shared_experts=0, @@ -38,7 +36,9 @@ top_k=4, use_grouped_mm=False, load_balance_coeff=1e-3, - ) + ), + use_flex_attn=True, + attn_mask_type="causal" ), "20b": GptOssModelArgs( num_hidden_layers=24, @@ -71,8 +71,8 @@ } -register_train_spec( - TrainSpec( +def get_train_spec() -> TrainSpec: + return TrainSpec( name="gpt_oss", model_cls=GptOssModel, model_args=gptoss_configs, @@ -84,4 +84,3 @@ build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, ) -) diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py index 62775b8b67..512bc8f6fd 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -93,7 +93,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ) -# TODO(jianiw): This need to be merged with +# TODO(jianiw): This need to be merged with expert_parallel def expert_parallel(func: Callable) -> Callable: """ This is a wrapper applied to the GroupedExperts computation, serving diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 0d3f9a9f54..1a71abff66 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -15,7 +15,8 @@ else: print(f"Since torch version {torch.__version__} < 2.9, PrepareModuleInputOutput is not available and MoE EP TP will fail.") -from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config.job_config import JobConfig +from torchtitan.config import TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger @@ -26,22 +27,46 @@ from .expert_parallel import ( ExpertParallel, 
ExpertTensorParallel, - NoParallel, TensorParallel, ) from torchtitan.distributed import NoParallel +# for selective op activation checkpointing +_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_to_all_single.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + # Adapted from llama4/infra/parallelize.py def parallelize_gptoss( model: nn.Module, - world_mesh: DeviceMesh, parallel_dims: ParallelDims, job_config: JobConfig, ): + world_mesh = parallel_dims.world_mesh + + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: - # TODO(jianiw): This branch needs to be tested and enabled raise NotImplementedError( "Currently, async TP is not tested for gptoss. \ torch.compile is not supported yet, which is required for async TP." @@ -55,7 +80,6 @@ def parallelize_gptoss( enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise if enable_float8_tensorwise_tp: - # TODO(jianiw): This branch needs to be tested and enabled raise NotImplementedError( "Currently, float8 tensorwise TP is not tested for gptoss" ) @@ -79,13 +103,20 @@ def parallelize_gptoss( else None ), ) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if job_config.activation_checkpoint.mode != "none": - apply_ac(model, job_config.activation_checkpoint) - - if job_config.training.compile: - raise NotImplementedError("torch.compile is not supported yet for gptoss") - + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + save_list=_save_list, + ) + dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel @@ -134,96 +165,12 @@ def parallelize_gptoss( apply_ddp( model, dp_mesh, - enable_compile=job_config.training.compile, + enable_compile=model_compile_enabled, enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) return model - -def apply_non_moe_tp( - model: nn.Module, - tp_mesh: DeviceMesh, - loss_parallel: bool, - enable_float8_tensorwise_tp: bool, - enable_async_tp: bool, -): - """Apply tensor parallelism.""" - # 1. Parallelize the embedding and shard its outputs (which are the first - # transformer block's inputs) - # 2. Parallelize the root norm layer over the sequence dim - # 3. 
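# Sketch of the precondition asserted at the top of parallelize_gptoss above. The degree
# values are illustrative, and seq_len_divisor is assumed to equal TP degree * 2 * CP degree,
# as stated in that assert's error message:

tp_degree, cp_degree = 8, 1
seq_len = 4096
seq_len_divisor = tp_degree * 2 * cp_degree
assert seq_len % seq_len_divisor == 0, (
    f"seq_len {seq_len} must be divisible by TP degree ({tp_degree}) "
    f"times 2 * CP degree ({cp_degree})"
)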
Parallelize the final linear output layer - parallelize_module( - model, - tp_mesh, - { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "norm": SequenceParallel(), - "output": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, - ), - }, - ) - - rowwise_parallel, colwise_parallel, prepare_module_input = ( - RowwiseParallel, - ColwiseParallel, - PrepareModuleInput, - ) - - # Apply tensor + sequence parallelism to every transformer block - # NOTE: At the cost of model code change, we can accelerate Sequence Parallel - # by folding (and unfolding) the batch dimension and the sequence dimension. - # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 - for transformer_block in model.layers.values(): - layer_plan = { - "attention_norm": SequenceParallel(), - "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), - ), - # use_local_output=False make the output to be a DTensor instead of a plain Tensor - "attention.wkv_a": NoParallel(use_local_output=False), - "attention.wkv_b": colwise_parallel(use_local_output=False), - "attention.kv_norm": NoParallel(use_local_output=False), - "attention.wo": rowwise_parallel(output_layouts=Shard(1)), - "ffn_norm": SequenceParallel(), - } - - if transformer_block.attention.q_lora_rank == 0: - layer_plan.update( - { - "attention.wq": colwise_parallel( - use_local_output=False - ), # This is only used when q_lora_rank==0 - } - ) - else: - layer_plan.update( - { - "attention.wq_a": NoParallel(use_local_output=False), - "attention.wq_b": colwise_parallel(use_local_output=False), - "attention.q_norm": NoParallel(use_local_output=False), - } - ) - - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=layer_plan, - ) - - logger.info( - f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" - "Tensor Parallelism to the model" - ) - - def apply_non_moe_tp( model: nn.Module, tp_mesh: DeviceMesh, @@ -280,12 +227,12 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), None), - desired_input_layouts=(Replicate(), None), + input_layouts=(Shard(1), Replicate()), + desired_input_layouts=(Replicate(), Replicate()), ), - "attention.wq": colwise_parallel(), - "attention.wk": colwise_parallel(), - "attention.wv": colwise_parallel(), + "attention.wq": colwise_parallel(use_local_output=False), + "attention.wk": colwise_parallel(use_local_output=False), + "attention.wv": colwise_parallel(use_local_output=False), "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } @@ -297,6 +244,7 @@ def apply_non_moe_tp( ) # shard attention.sinks across heads + # TODO(jianiw): Fix the sink implementation attn = transformer_block.attention attn.register_parameter( "sinks", diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py index e5d09e856a..583dbc848f 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -8,7 +8,7 @@ from torch import nn -from torchtitan.config_manager import JobConfig +from torchtitan.config.job_config import JobConfig from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger from 
torchtitan.models.moe import MoEArgs @@ -58,6 +58,7 @@ class GptOssModelArgs(BaseModelArgs): dtype: Literal["bf16", "fp8"] = "bf16" vocab_size: int = 201088 hidden_size: int = 2880 + moe_inter_dim: int = 2880 num_hidden_layers: int = 24 norm_eps: float = 1e-5 # eps used for RMSNorm # MoE @@ -124,7 +125,7 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in nparams_sparse_active = ( nparams_moe_router + nparams_shared_expert - + nparams_experts * self.num_experts_per_tok // self.num_local_experts + + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts ) logger.info( diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index cbc2a4a25d..07345d7658 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -7,11 +7,12 @@ import torch from torch import nn from torch.distributed.tensor import DTensor +from torchtitan.experiments.simple_fsdp import model from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol from .args import GptOssModelArgs -from .moe import MoE +from .moe import GptOssMoE # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 @@ -190,6 +191,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) +# TODO(jianw): This is eager version from HuggingFace def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, @@ -257,8 +259,13 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads)) self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: - self.attn = build_attention(True, model_args.attn_mask_type) + # Only apply sliding window to every other layer + if use_sliding_attention: + self.attn = build_attention(use_flex_attn=True, attn_mask_type="sliding_window", sliding_window=self.sliding_window) + else: + self.attn = build_attention(use_flex_attn=True, attn_mask_type=model_args.attn_mask_type) else: # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed self.attn = eager_attention_forward @@ -294,10 +301,11 @@ def forward( if self.use_flex_attn: output = self.attn( q, k, v, - self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks, - sliding_window=self.sliding_window, - enable_gqa=True, - ) + scale=None, + sink_weights=self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks, + # sliding_window=self.sliding_window, + enable_gqa=True if self.sliding_window else False, + ) else: # eager attention forward output = self.attn( @@ -352,7 +360,8 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs): self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) - self.moe = MoE(model_args) + self.moe = GptOssMoE(model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim) + self.moe_enabled = True # for composability with load balancing self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py index 2fbbc38e28..40fdac1887 
100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -5,7 +5,7 @@ from torch.distributed.tensor import DTensor import torch.nn.functional as F from torch import nn -from torchtitan.models.gpt_oss.infra.expert_parallel import expert_parallel +from torchtitan.experiments.gpt_oss.infra.expert_parallel import expert_parallel from torchtitan.protocols import model from .args import GptOssModelArgs @@ -24,6 +24,7 @@ class GptOssGroupedExperts(nn.Module): def __init__( self, dim: int, + hidden_dim: int, num_experts: int, swiglu_limit: float, use_grouped_mm: bool, @@ -33,9 +34,9 @@ def __init__( self.use_grouped_mm = use_grouped_mm self.swiglu_limit = swiglu_limit - self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, dim * 2))) # w1 and w3 - self.mlp1_bias = nn.Parameter(torch.empty((num_experts, dim * 2))) - self.mlp2_weight = nn.Parameter(torch.empty((num_experts, dim, dim))) + self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2))) # w1 and w3 + self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2))) + self.mlp2_weight = nn.Parameter(torch.empty((num_experts, hidden_dim, dim))) self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) def forward( @@ -168,6 +169,7 @@ def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): # Override the base GroupedExperts with GptOssGroupedExperts self.experts = GptOssGroupedExperts( dim=dim, + hidden_dim=hidden_dim, num_experts=moe_args.num_experts, swiglu_limit=model_args.swiglu_limit, use_grouped_mm=moe_args.use_grouped_mm, diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index 78c8aface0..22bfd10e1b 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -1,10 +1,7 @@ -# torchtitan Config.toml - [job] dump_folder = "./outputs" -description = "GPT-OSS debug training" +description = "Gpt-oss debug training" print_args = false -use_for_integration_test = true [profiling] enable_profiling = false @@ -23,8 +20,8 @@ enable_wandb = false [model] name = "gpt_oss" flavor = "debugmodel" -# test tokenizer, for debug purpose only -hf_asset_path = "./tests/assets/tokenizer" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" # converters = ["float8"] [optimizer] @@ -36,29 +33,29 @@ eps = 1e-8 warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps decay_type = "linear" -lr_min = 0.0 +min_lr_factor = 0.0 [training] local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping -steps = 1 -compile = false +steps = 10 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false -expert_parallel_degree = 1 +pipeline_parallel_degree = 1 +context_parallel_degree = 1 [checkpoint] -enable_checkpoint = false +enable = false folder = "checkpoint" interval = 10 -last_save_model_weights_only = false +last_save_model_only = false export_dtype = "float32" async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] @@ -66,8 +63,17 @@ async_mode = "disabled" # ["disabled", 
"async", "async_with_pinned_mem"] mode = "none" # ["none", "selective", "full"] selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy +[compile] +enable=false +components = ["model", "loss"] + [float8] enable_fsdp_float8_all_gather = false precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] -moe_fqns = ["experts"] +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 3c3b607571..b2a42d0c9b 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -25,7 +25,7 @@ # FlexAttention mask type. For each mask type, we initialize it at most once per # batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to # track the initialized mask. -FLEX_ATTN_MASK_T = tuple[str, int | None] +FLEX_ATTN_MASK_T = tuple[str, int | None, int | None] # (mask_type, fixed_block_size, sliding_window) class FlexAttention(torch.nn.Module): @@ -64,20 +64,21 @@ class FlexAttention(torch.nn.Module): attn_mask_type: str def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None + self, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None ) -> None: super().__init__() - if attn_mask_type not in ["causal", "block_causal"]: + if attn_mask_type not in ["causal", "block_causal", "sliding_window"]: raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") self.attn_mask_type = attn_mask_type self.fixed_block_size = fixed_block_size + self.sliding_window = sliding_window self.mask_cache = {} FlexAttention.used_attn_mask_types.add(self.mask_key) @property def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size) + return (self.attn_mask_type, self.fixed_block_size, self.sliding_window) def forward( self, @@ -86,47 +87,60 @@ def forward( v: torch.Tensor, scale: float | None = None, sink_weights: torch.Tensor | None = None, - sliding_window: int = 0, + # sliding_window: int = 0, enable_gqa: bool = False, ) -> torch.Tensor: - if sink_weights is None: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - + + # Use sink logic when sliding_window is used and sink_weights is provided + if self.attn_mask_type == "sliding_window" and sink_weights is not None: + return self._forward_with_sink(q, k, v, scale, sink_weights, enable_gqa) + + # Regular path without sink - use pre-compiled block masks + block_mask = FlexAttention.block_masks[self.mask_key] + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + + def _forward_with_sink( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + scale: float | None = None, + sink_weights: torch.Tensor | None = None, + enable_gqa: bool = False, + ) -> torch.Tensor: + """Forward pass with attention sink for sliding window attention.""" B, H_q, S_q, D = q.shape _, H_kv, S_kv, _ = k.shape - # regular (no-sink) mask + no extra KV col - mask_key = (sliding_window, S_q, S_kv) + if self.sliding_window is None or self.sliding_window <= 0: + raise RuntimeError("sliding_window must be configured for sliding_window attention type") + mask_key = ("sliding_window_sink", self.sliding_window, S_q, S_kv) if mask_key not in self.mask_cache: - if sliding_window is not None and sliding_window > 0: - mask_mod = 
FlexAttention._get_sliding_window_mask_mod(sliding_window) - else: - mask_mod = FlexAttention._get_causal_mask_mod() + mask_mod = FlexAttention._get_sliding_window_mask_mod(self.sliding_window) block_mask = create_block_mask( mask_mod, B, H_q, S_q, S_kv, - _compile=True, device=q.device # NOTE: set _compile=False if sampling for debugging + _compile=True, device=q.device ) self.mask_cache[mask_key] = block_mask - block_mask = self.mask_cache[mask_key] - # run fast flex_attn and return LSE + # Run flex_attn and return LSE for sink computation out, lse = FlexAttention.flex_attn( q, k, v, block_mask=block_mask, enable_gqa=enable_gqa, - return_lse=True + return_lse=True, + scale=scale ) - # rescale by sigma(lse - w[h]) and broadcast over D + # Apply attention sink rescaling: rescale by σ(lse - w[h]) + # This is mathematically equivalent to concatenating learnable sink weights if sink_weights is not None: w = sink_weights # [H] - scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1] - out = out * scale + sink_scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1] + out = out * sink_scale - out = out.to(q.dtype) - return out + return out.to(q.dtype) @staticmethod def _get_sliding_window_mask_mod(window: int): @@ -208,7 +222,7 @@ def blocked_mask_mod( def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: # batch is [b, s, h, d] shape for mask_key in FlexAttention.used_attn_mask_types: - attn_mask_type, fixed_block_size = mask_key + attn_mask_type, fixed_block_size, sliding_window = mask_key match attn_mask_type: case "causal": if FlexAttention.block_masks.get(mask_key, None) is not None: @@ -224,6 +238,17 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: ) batch_dimension = batch.shape[0] mask_mod = FlexAttention._get_block_causal_mask_mod(batch, eos_id) + case "sliding_window": + if sliding_window is None or sliding_window <= 0: + raise RuntimeError( + "sliding_window must be provided and > 0 for sliding_window mask." + ) + if FlexAttention.block_masks.get(mask_key, None) is not None: + continue + # We don't care about batch dimension -- + # all samples have the same sliding window mask. + batch_dimension = 1 + mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window) case _: raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") @@ -278,15 +303,19 @@ def forward( def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None + use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None ): if use_flex_attn: - return FlexAttention(attn_mask_type, fixed_block_size) + return FlexAttention(attn_mask_type, fixed_block_size, sliding_window) else: if fixed_block_size is not None: raise ValueError( "TorchTitan with SDPA currently does not support fixed_block_size." ) + if sliding_window is not None: + raise ValueError( + "TorchTitan with SDPA currently does not support sliding_window." + ) if attn_mask_type != "causal": raise ValueError( "TorchTitan with SDPA currently only supports causal mask." 
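
Note on the sink rescaling used by the FlexAttention path above: multiplying the attention output by sigmoid(lse - w[h]) is algebraically the same as appending one per-head sink logit to the softmax whose value contribution is zero. A minimal, self-contained check of that equivalence follows; the shapes and names (B, H, S, D, sinks) are illustrative only and not taken from the patch.

    import torch

    torch.manual_seed(0)
    B, H, S, D = 2, 4, 8, 16
    q, k, v = (torch.randn(B, H, S, D) for _ in range(3))
    sinks = torch.randn(H)          # one learnable sink logit per head
    scale = D ** -0.5

    scores = (q @ k.transpose(-2, -1)) * scale                   # [B, H, S, S]
    causal = torch.tril(torch.ones(S, S, dtype=torch.bool))
    scores = scores.masked_fill(~causal, float("-inf"))

    # Path 1: explicit sink column appended to the logits; the sink row contributes no value.
    sink_col = sinks.view(1, H, 1, 1).expand(B, H, S, 1)
    probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
    out_explicit = probs[..., :-1] @ v

    # Path 2: plain attention output rescaled by sigmoid(lse - sink), as in _forward_with_sink.
    lse = torch.logsumexp(scores, dim=-1)                        # [B, H, S]
    out_rescaled = (torch.softmax(scores, dim=-1) @ v) * torch.sigmoid(
        lse - sinks.view(1, H, 1)
    ).unsqueeze(-1)

    print(torch.allclose(out_explicit, out_rescaled, atol=1e-6))  # True
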
From 589ce620cf14f259452f1b3324b2790d72e8a3b0 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 24 Sep 2025 13:21:36 -0700 Subject: [PATCH 09/18] fix flexattn --- torchtitan/experiments/gpt_oss/model/model.py | 41 ++++++------ torchtitan/models/attention.py | 64 +++++++++---------- 2 files changed, 53 insertions(+), 52 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index 07345d7658..e57ba30e0d 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -180,16 +180,17 @@ def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): return q_out, k_out # Torch Attention backup implementation (for debugging and sampling) from HuggingFace -def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: - """ - This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, - num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) - """ - batch, num_key_value_heads, slen, head_dim = hidden_states.shape +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape if n_rep == 1: - return hidden_states - hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) - return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + # TODO(jianw): This is eager version from HuggingFace def eager_attention_forward( @@ -200,12 +201,9 @@ def eager_attention_forward( attention_mask: torch.Tensor, scaling: float, dropout: float = 0.0, - num_key_value_groups: int = 1, **kwargs, ): - key_states = repeat_kv(key, num_key_value_groups) - value_states = repeat_kv(value, num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk] # Convert boolean "allowed" -> additive mask @@ -230,7 +228,7 @@ def eager_attention_forward( probs = nn.functional.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) scores = probs[..., :-1] # we drop the sink here attn_weights = nn.functional.dropout(scores, p=dropout, training=False) - attn_output = torch.matmul(attn_weights, value_states) + attn_output = torch.matmul(attn_weights, value) return attn_output class Attention(nn.Module): @@ -243,6 +241,10 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa self.sliding_window = model_args.sliding_window if use_sliding_attention else None self.head_dim = model_args.head_dim + self.n_heads = model_args.num_attention_heads + self.n_kv_heads = model_args.num_key_value_heads + + self.n_rep = self.n_heads // self.n_kv_heads self.wq = nn.Linear( model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True @@ -294,17 +296,19 @@ def forward( q, k = apply_rotary_emb(q, k, freqs_cis) + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(k, self.n_rep) + values = repeat_kv(v, self.n_rep) + q = q.transpose(1, 2).contiguous() - k = k.transpose(1, 2).contiguous() - v = v.transpose(1, 2).contiguous() + k = keys.transpose(1, 2).contiguous() + v = 
values.transpose(1, 2).contiguous() if self.use_flex_attn: output = self.attn( q, k, v, scale=None, sink_weights=self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks, - # sliding_window=self.sliding_window, - enable_gqa=True if self.sliding_window else False, ) else: # eager attention forward @@ -313,7 +317,6 @@ def forward( attention_mask=self.sliding_window_causal(seqlen, x.device), scaling=self.head_dim**-0.5, dropout=0.0, - num_key_value_groups=8, ) output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index b2a42d0c9b..54df26ef9d 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -25,7 +25,9 @@ # FlexAttention mask type. For each mask type, we initialize it at most once per # batch. To record what it is initialized, FLEX_ATTN_MASK_T is used as the key to # track the initialized mask. -FLEX_ATTN_MASK_T = tuple[str, int | None, int | None] # (mask_type, fixed_block_size, sliding_window) +FLEX_ATTN_MASK_T = tuple[ + str, int | None, int | None +] # (mask_type, fixed_block_size, sliding_window) class FlexAttention(torch.nn.Module): @@ -64,7 +66,10 @@ class FlexAttention(torch.nn.Module): attn_mask_type: str def __init__( - self, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None + self, + attn_mask_type: str, + fixed_block_size: int | None = None, + sliding_window: int | None = None, ) -> None: super().__init__() if attn_mask_type not in ["causal", "block_causal", "sliding_window"]: @@ -73,7 +78,6 @@ def __init__( self.fixed_block_size = fixed_block_size self.sliding_window = sliding_window - self.mask_cache = {} FlexAttention.used_attn_mask_types.add(self.mask_key) @property @@ -87,57 +91,44 @@ def forward( v: torch.Tensor, scale: float | None = None, sink_weights: torch.Tensor | None = None, - # sliding_window: int = 0, - enable_gqa: bool = False, ) -> torch.Tensor: - + # Use sink logic when sliding_window is used and sink_weights is provided if self.attn_mask_type == "sliding_window" and sink_weights is not None: - return self._forward_with_sink(q, k, v, scale, sink_weights, enable_gqa) - - # Regular path without sink - use pre-compiled block masks + return self._forward_with_sink(q, k, v, scale, sink_weights) + + # Regular path without sink block_mask = FlexAttention.block_masks[self.mask_key] return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - + def _forward_with_sink( self, q: torch.Tensor, - k: torch.Tensor, + k: torch.Tensor, v: torch.Tensor, scale: float | None = None, sink_weights: torch.Tensor | None = None, - enable_gqa: bool = False, ) -> torch.Tensor: """Forward pass with attention sink for sliding window attention.""" - B, H_q, S_q, D = q.shape - _, H_kv, S_kv, _ = k.shape - - if self.sliding_window is None or self.sliding_window <= 0: - raise RuntimeError("sliding_window must be configured for sliding_window attention type") - mask_key = ("sliding_window_sink", self.sliding_window, S_q, S_kv) - if mask_key not in self.mask_cache: - mask_mod = FlexAttention._get_sliding_window_mask_mod(self.sliding_window) - block_mask = create_block_mask( - mask_mod, B, H_q, S_q, S_kv, - _compile=True, device=q.device - ) - self.mask_cache[mask_key] = block_mask - block_mask = self.mask_cache[mask_key] + # Use the pre-compiled static block mask + block_mask = FlexAttention.block_masks[self.mask_key] # Run flex_attn and return LSE for sink computation out, lse = 
FlexAttention.flex_attn( - q, k, v, + q, + k, + v, block_mask=block_mask, - enable_gqa=enable_gqa, return_lse=True, - scale=scale + scale=scale, ) # Apply attention sink rescaling: rescale by σ(lse - w[h]) # This is mathematically equivalent to concatenating learnable sink weights if sink_weights is not None: - w = sink_weights # [H] - sink_scale = torch.sigmoid(lse - w.view(1, -1, 1)).unsqueeze(-1) # [B,H,S,1] + sink_scale = torch.sigmoid(lse - sink_weights.view(1, -1, 1)).unsqueeze( + -1 + ) # [B,H,S,1] out = out * sink_scale return out.to(q.dtype) @@ -149,10 +140,12 @@ def _get_sliding_window_mask_mod(window: int): - only allows kv_idx ≤ q_idx (causal) - and only if (q_idx - kv_idx) ≤ window """ + def sliding_mod(b, h, q_idx, kv_idx): # causal within window keep = (kv_idx <= q_idx) & (q_idx - kv_idx <= window) return keep + return sliding_mod @staticmethod @@ -248,7 +241,9 @@ def init_attention_mask(batch: torch.Tensor, eos_id: int | None) -> None: # We don't care about batch dimension -- # all samples have the same sliding window mask. batch_dimension = 1 - mask_mod = FlexAttention._get_sliding_window_mask_mod(sliding_window) + mask_mod = FlexAttention._get_sliding_window_mask_mod( + sliding_window + ) case _: raise RuntimeError(f"Shouldn't reach here. {attn_mask_type}") @@ -303,7 +298,10 @@ def forward( def build_attention( - use_flex_attn: bool, attn_mask_type: str, fixed_block_size: int | None = None, sliding_window: int | None = None + use_flex_attn: bool, + attn_mask_type: str, + fixed_block_size: int | None = None, + sliding_window: int | None = None, ): if use_flex_attn: return FlexAttention(attn_mask_type, fixed_block_size, sliding_window) From 4fc78a38e802fa1a887f77d17583432538f6b285 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Wed, 24 Sep 2025 15:54:04 -0700 Subject: [PATCH 10/18] check and replace rope --- torchtitan/experiments/gpt_oss/model/model.py | 241 ++++++------------ 1 file changed, 77 insertions(+), 164 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index e57ba30e0d..f774b0ffd9 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -15,171 +15,78 @@ from .moe import GptOssMoE -# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 -def precompute_freqs_cis(args: GptOssModelArgs) -> torch.Tensor: - """ - Precomputes frequency-based complex exponential values for rotary positional embeddings. - - Args: - args (GptOssModelArgs): Model arguments containing positional embedding parameters. - - Returns: - torch.Tensor: Precomputed complex exponential values for positional embeddings. - """ - dim = args.head_dim - seqlen = args.max_seq_len - beta_fast = args.beta_fast - beta_slow = args.beta_slow - base = args.rope_theta - factor = args.rope_factor - original_seq_len = args.original_seq_len - - # YaRN default m-scale (attention_factor). Matches HF when attention_factor is None. - mscale = 0.1 * math.log(factor) + 1.0 - - def find_correction_dim( - num_rotations: float, dim: int, base: float, max_seq_len: int - ) -> float: - """ - Computes the correction dimension for a given number of rotations in the rotary positional embedding. - - Args: - num_rotations (float): Number of rotations to compute the correction for. - dim (int): Dimensionality of the embedding space. - base (float): Base value for the exponential computation. - max_seq_len (int): Maximum sequence length. 
- - Returns: - float: The correction dimension based on the input parameters. - """ - return ( - dim - * math.log(max_seq_len / (num_rotations * 2 * math.pi)) - / (2 * math.log(base)) - ) - - def find_correction_range( - low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int - ) -> Tuple[int, int]: - """ - Computes the range of correction dimensions for rotary positional embeddings. - - Args: - low_rot (float): Lower bound for the number of rotations. - high_rot (float): Upper bound for the number of rotations. - dim (int): Dimensionality of the embedding space. - base (float): Base value for the exponential computation. - max_seq_len (int): Maximum sequence length. +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) - Returns: - Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. - """ - low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) - high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) - return max(low, 0), min(high, dim - 1) + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.outer(t, freqs).float() - def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: - """ - Computes a linear ramp function used to smooth values between a minimum and maximum range. - - Args: - min (float): Minimum value for the ramp function. - max (float): Maximum value for the ramp function. - dim (int): Dimensionality of the ramp tensor. - - Returns: - torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, - clamped to the range [0, 1]. - """ - if min == max: - max += 0.001 - linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) - ramp_func = torch.clamp(linear_func, 0, 1) - return ramp_func - - # Basic RoPE frequency calculation - freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) - - # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. - if seqlen > original_seq_len: - low, high = find_correction_range( - beta_fast, beta_slow, dim, base, original_seq_len - ) - smooth = 1 - linear_ramp_factor(low, high, dim // 2) - freqs = freqs / factor * (1 - smooth) + freqs * smooth + # We cache the cos and sin embeddings instead of the IDs. This helps + # ensure we have correct behavior when training with bf16 + # Size: [max_seq_len, (dim * 2)] + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache - # Create position indices - t = torch.arange(seqlen) - # Outer product: [positions] Ɨ [frequencies] - freqs = torch.outer(t, freqs) +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) - # Convert to complex exponentials: e^(i*freq*pos) - freqs_cis = torch.polar(torch.full_like(freqs, fill_value=mscale), freqs) - return freqs_cis +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. 
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. -def apply_rotary_emb_inner(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: - """ - Applies rotary positional embeddings to the input tensor. + The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), + and the first seqlen elements will be sliced, but dim must match x. Args: - x (torch.Tensor): Input tensor with positional embeddings to be applied. - freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. Returns: - torch.Tensor: Tensor with rotary embeddings applied. + torch.Tensor: Reshaped frequency tensor. """ - dtype = x.dtype - x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) - freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) - y = torch.view_as_real(x * freqs_cis).flatten(3) - return y.to(dtype) + ndim = x.ndim + assert ndim > 1 + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # input tensor x has shape [bsz, seq_len, num_heads, head_dim] + head_dim = xq.shape[-1] + + # reshape for broadcast + rope_cache = reshape_for_broadcast(rope_cache, xq) + + # [bsz, seq_len, 1, head_dim] + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + + # xq: [bsz, seq_len, num_heads, head_dim] + # xk: [bsz, seq_len, num_kv_heads, head_dim] + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) -def apply_rotary_emb(q: torch.Tensor, k: torch.Tensor, freqs_cis: torch.Tensor): - """ - HF-style inputs (half-split last dim) -> interleave -> Torchtitan complex RoPE -> de-interleave. - Shapes: - q, k: [B, T, H, D] with D even (HF half-split: first D/2 real, last D/2 imag) - freqs_cis: complex, last dim == D/2. Typically [T, D/2] or [1, T, D/2]. - Returns: - q_out, k_out in HF half-split layout (same shape as q, k). - """ - B, T, H, D = q.shape - assert D % 2 == 0, "head_dim must be even for RoPE" - rot = D // 2 - assert freqs_cis.shape[-1] == rot, "freqs_cis last dim must be D/2" - freqs_cis = freqs_cis[:T, :] - - # Memory layout comparison for head_dim=8: - # HF Format: [r0][r1][r2][r3][i0][i1][i2][i3] - # ↑-- reals --↑ ↑-- imags --↑ - - # Interleaved: [r0][i0][r1][i1][r2][i2][r3][i3] - # ↑-pair-↑ ↑-pair-↑ ↑-pair-↑ ↑-pair-↑ - # --- inline: HF half-split -> interleaved (real0, imag0, real1, imag1, ...) 
- # q_i, k_i: [B, T, H, D] - q_i = torch.empty_like(q) - k_i = torch.empty_like(k) - q_i[..., 0::2] = q[..., :rot] - q_i[..., 1::2] = q[..., rot:] - k_i[..., 0::2] = k[..., :rot] - k_i[..., 1::2] = k[..., rot:] - - # --- Torchtitan default complex apply (expects interleaved last dim) - # freqs_cis will be reshaped inside to [1, T, 1, rot] - # TODO(jianiw): I think we shoud go with sin/cos representation to simplify the conversion between paired real/imaginary <-> half-split real/imaginary - q_rot_i = apply_rotary_emb_inner(q_i, freqs_cis) # uses TT's complex path - k_rot_i = apply_rotary_emb_inner(k_i, freqs_cis) - - # --- inline: interleaved -> HF half-split - # TODO(jianiw): convert it back - q_out = torch.cat([q_rot_i[..., 0::2], q_rot_i[..., 1::2]], dim=-1) - k_out = torch.cat([k_rot_i[..., 0::2], k_rot_i[..., 1::2]], dim=-1) - return q_out, k_out - -# Torch Attention backup implementation (for debugging and sampling) from HuggingFace def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape @@ -215,7 +122,7 @@ def eager_attention_forward( add_mask = attention_mask.to(attn_weights.dtype) # Truncate to current key length and add (broadcasts if needed) - add_mask = add_mask[..., : key_states.shape[-2]] + add_mask = add_mask[..., : key.shape[-2]] attn_weights = attn_weights + add_mask sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) @@ -275,14 +182,14 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa def forward( self, x: torch.Tensor, - freqs_cis: torch.Tensor, + rope_cache: torch.Tensor, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). - freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies for rope embedding. Returns: torch.Tensor: Output tensor with the same shape as the input. @@ -294,7 +201,7 @@ def forward( k = self.wk(x).view(hidden_shape) v = self.wv(x).view(hidden_shape) - q, k = apply_rotary_emb(q, k, freqs_cis) + q, k = apply_rotary_emb(q, k, rope_cache) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(k, self.n_rep) @@ -369,18 +276,18 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs): self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 self.layer_id = layer_id - def forward(self, x: torch.Tensor, freqs_cis: torch.Tensor): + def forward(self, x: torch.Tensor, rope_cache: torch.Tensor): """ Forward pass for the Transformer block. Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). - freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. Returns: torch.Tensor: Output tensor with the same shape as the input. 
""" - x = x + self.attention(self.attention_norm(x), freqs_cis) + x = x + self.attention(self.attention_norm(x), rope_cache) x = x + self.moe(self.ffn_norm(x)) return x @@ -398,16 +305,16 @@ class GptOssModel(nn.Module, ModelProtocol): def __init__(self, model_args: GptOssModelArgs): super().__init__() + self.model_args = model_args self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size) self.register_buffer( - "freqs_cis", precompute_freqs_cis(model_args), persistent=True + "rope_cache", self._precompute_rope_cache(), persistent=False ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.num_hidden_layers): self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16) - # convert_submodules_to_bf16(self.layers[str(layer_id)]) self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) self.output = nn.Linear( @@ -418,12 +325,11 @@ def __init__(self, model_args: GptOssModelArgs): ) self.model_args = model_args self.init_weights() - # convert_submodules_to_bf16(self) def init_weights(self, buffer_device: torch.device | None = None) -> None: - buffer_device = buffer_device or self.freqs_cis.device + buffer_device = buffer_device or self.rope_cache.device with torch.device(buffer_device): - self.freqs_cis = precompute_freqs_cis(self.model_args) + self.rope_cache = self._precompute_rope_cache() if self.tok_embeddings is not None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): @@ -442,6 +348,13 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: b=cutoff_factor * final_out_std, ) + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + def forward(self, tokens: torch.Tensor): """ Forward pass for the Transformer model. 
@@ -455,7 +368,7 @@ def forward(self, tokens: torch.Tensor): h = self.tok_embeddings(tokens) for layer in self.layers.values(): - h = layer(h, self.freqs_cis) + h = layer(h, self.rope_cache) h = self.norm(h) output = self.output(h) return output From b28fe7c092b88f058e81406622a16d3b9bd1e93f Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 25 Sep 2025 16:48:48 -0700 Subject: [PATCH 11/18] FSDP work, TP doesn't work --- torchtitan/distributed/utils.py | 2 +- torchtitan/experiments/gpt_oss/__init__.py | 2 +- .../experiments/gpt_oss/infra/parallelize.py | 118 ++++++------------ torchtitan/experiments/gpt_oss/model/model.py | 26 +++- .../gpt_oss/train_configs/debug_model.toml | 2 +- torchtitan/models/attention.py | 47 +------ torchtitan/train.py | 1 - 7 files changed, 68 insertions(+), 130 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 75e4fe4ed5..b3ecf2c80a 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -186,7 +186,7 @@ def create_context_parallel_ctx( def get_train_context( - enable_loss_parallel: bool, enable_compiled_autograd: bool, use_sdpa: bool = True + enable_loss_parallel: bool, enable_compiled_autograd: bool ) -> Generator[None, None, None]: @contextlib.contextmanager def context(cp_context: Generator[None, None, None] | None = None): diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index e4e288c09f..532b1ccfcc 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -34,7 +34,7 @@ route_scale=1.0, score_before_experts=False, top_k=4, - use_grouped_mm=False, + use_grouped_mm=True, load_balance_coeff=1e-3, ), use_flex_attn=True, diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 1a71abff66..e3f3842c60 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -17,23 +17,26 @@ from torchtitan.config.job_config import JobConfig from torchtitan.config import TORCH_DTYPE_MAP -from torchtitan.distributed import ParallelDims +from torchtitan.distributed import ParallelDims, NoParallel +from torchtitan.distributed.expert_parallel import ExpertParallel, ReordererSequenceParallel from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.experiments.llama4.infra.parallelize import ( + apply_fsdp, + apply_moe_ep_tp, +) + from torchtitan.tools.logging import logger -from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy from torch.distributed.tensor import Partial, Replicate, Shard from .expert_parallel import ( - ExpertParallel, ExpertTensorParallel, TensorParallel, ) -from torchtitan.distributed import NoParallel # for selective op activation checkpointing -_save_list = { +_op_sac_save_list = { torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, @@ -87,7 +90,7 @@ def parallelize_gptoss( apply_non_moe_tp( model, world_mesh["tp"], - loss_parallel=parallel_dims.loss_parallel_enabled, + loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, ) @@ -99,9 +102,10 @@ def parallelize_gptoss( ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, ep_tp_mesh=( world_mesh["ep", "tp"] - if parallel_dims.tp_enabled and 
parallel_dims.ep_enabled + if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled else None ), + etp_enabled=parallel_dims.etp_enabled, ) model_compile_enabled = ( @@ -114,7 +118,7 @@ def parallelize_gptoss( job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, use_flex_attn=use_flex_attn, - save_list=_save_list, + save_list=_op_sac_save_list, ) dp_mesh: DeviceMesh | None = None @@ -263,83 +267,18 @@ def apply_non_moe_tp( ) -def apply_fsdp( - model: nn.Module, - dp_mesh: DeviceMesh, - param_dtype: torch.dtype, - reduce_dtype: torch.dtype, - pp_enabled: bool, - cpu_offload: bool = False, - reshard_after_forward_policy: str = "default", - dp_mod_ep_mesh: DeviceMesh | None = None, -): - """ - Apply data parallelism (via FSDP2) to the model. - - Args: - model (nn.Module): The model to apply data parallelism to. - dp_mesh (DeviceMesh): The device mesh to use for data parallelism. - param_dtype (torch.dtype): The data type to use for model parameters. - reduce_dtype (torch.dtype): The data type to use for reduction operations. - pp_enabled (bool): Whether pipeline parallelism is enabled. - cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. - reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". - Other options: "never", "always". - - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - - "always" will enable `reshard_after_forward` for all forward passes. - - "never" will disable `reshard_after_forward` for all forward passes. - - """ - mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} - if cpu_offload: - fsdp_config["offload_policy"] = CPUOffloadPolicy() - - for layer_id, transformer_block in model.layers.items(): - if reshard_after_forward_policy == "always": - reshard_after_forward = True - elif reshard_after_forward_policy == "never": - reshard_after_forward = False - elif reshard_after_forward_policy == "default": - if pp_enabled: - # For PP, do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = False - else: - # As an optimization, do not reshard after forward for the last - # transformer block since FSDP would prefetch it immediately - reshard_after_forward = int(layer_id) < len(model.layers) - 1 - else: - raise ValueError( - f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." 
- ) - - # NOTE: in an MoE layer, the router and the shared experts - # are sharded together with the TransformerBlock - if dp_mod_ep_mesh: - fsdp_mod_ep_config = fsdp_config.copy() - fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh - fully_shard( - transformer_block.moe.experts, - **fsdp_mod_ep_config, - reshard_after_forward=reshard_after_forward, - ) - - fully_shard( - transformer_block, - **fsdp_config, - reshard_after_forward=reshard_after_forward, - ) - fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled) - - +# NOTE(jianiw): The function can not be reused now because reimplemented ExpertTensorParallel def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, ep_tp_mesh: DeviceMesh | None, + etp_enabled: bool, ): for transformer_block in model.layers.values(): + if not transformer_block.moe_enabled: + continue + if tp_mesh is not None: moe_layer_plan = { # input / output sharding on the seqlen dim @@ -354,13 +293,28 @@ def apply_moe_ep_tp( # replicate computation for the router "moe.router.gate": NoParallel(), } + if ep_mesh is not None and not etp_enabled: + # If TP is borrowed for EP, then split the tokens across TP ranks so that + # the reorderer, the all-to-all comms, and routed experts computation + # are effectively running Sequence Parallel (split along the folded bs*slen dim) + moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + if transformer_block.moe.shared_experts is not None: + # input Replicate, output Partial + moe_layer_plan.update( + { + "moe.shared_experts.w1": ColwiseParallel(), + "moe.shared_experts.w2": RowwiseParallel( + output_layouts=Partial() + ), + "moe.shared_experts.w3": ColwiseParallel(), + } + ) parallelize_module( module=transformer_block, device_mesh=tp_mesh, parallelize_plan=moe_layer_plan, ) - # if ep_mesh is not None: experts_mesh, experts_plan = None, None if ep_mesh is None: experts_mesh = tp_mesh @@ -370,9 +324,13 @@ def apply_moe_ep_tp( experts_mesh = ep_mesh # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() - else: + elif etp_enabled: experts_mesh = ep_tp_mesh experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh) + else: + experts_mesh = ep_mesh + experts_plan = ExpertParallel() + parallelize_module( module=transformer_block.moe.experts, device_mesh=experts_mesh, diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index f774b0ffd9..6684575a5c 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -99,7 +99,7 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: ) -# TODO(jianw): This is eager version from HuggingFace +# TODO(jianw): This is eager version from HuggingFace. Remove it once FlexAttention is ready. 
def eager_attention_forward( query: torch.Tensor, key: torch.Tensor, @@ -109,8 +109,15 @@ def eager_attention_forward( scaling: float, dropout: float = 0.0, **kwargs, -): - attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling +): + key_values = key.transpose(2, 3) # When TP is enabled, key should be shard() + print(f"key_values : {key_values.placements} {key_values.shape}") + print(f"query : {query.placements} {query.shape}") + + # [rank0]:key_values : (Shard(dim=1),) torch.Size([8, 64, 64, 2048]) + # [rank0]:query : (Shard(dim=1),) torch.Size([8, 64, 2048, 64]) + + attn_weights = query @ key_values * scaling if attention_mask is not None: # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk] # Convert boolean "allowed" -> additive mask @@ -212,11 +219,20 @@ def forward( v = values.transpose(1, 2).contiguous() if self.use_flex_attn: - output = self.attn( + # FlexAttention + output, lse = self.attn( q, k, v, scale=None, - sink_weights=self.sinks.to_local() if isinstance(self.sinks, DTensor) else self.sinks, + return_lse=True, ) + + # Apply attention sink rescaling: rescale by σ(lse - w[h]) + # This is mathematically equivalent to concatenating learnable sink weights + sink_scale = torch.sigmoid(lse - self.sink.view(1, -1, 1)).unsqueeze( + -1 + ) # [B,H,S,1] + output = output * sink_scale + else: # eager attention forward output = self.attn( diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index 22bfd10e1b..215ae69dd8 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -46,7 +46,7 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false pipeline_parallel_degree = 1 context_parallel_degree = 1 diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 54df26ef9d..90a2caced7 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -69,20 +69,20 @@ def __init__( self, attn_mask_type: str, fixed_block_size: int | None = None, - sliding_window: int | None = None, + sliding_window_size: int | None = None, ) -> None: super().__init__() if attn_mask_type not in ["causal", "block_causal", "sliding_window"]: raise ValueError(f"Unrecognized attn_mask_type {attn_mask_type}.") self.attn_mask_type = attn_mask_type self.fixed_block_size = fixed_block_size - self.sliding_window = sliding_window + self.sliding_window_size = sliding_window_size FlexAttention.used_attn_mask_types.add(self.mask_key) @property def mask_key(self) -> FLEX_ATTN_MASK_T: - return (self.attn_mask_type, self.fixed_block_size, self.sliding_window) + return (self.attn_mask_type, self.fixed_block_size, self.sliding_window_size) def forward( self, @@ -90,48 +90,13 @@ def forward( k: torch.Tensor, v: torch.Tensor, scale: float | None = None, - sink_weights: torch.Tensor | None = None, - ) -> torch.Tensor: - - # Use sink logic when sliding_window is used and sink_weights is provided - if self.attn_mask_type == "sliding_window" and sink_weights is not None: - return self._forward_with_sink(q, k, v, scale, sink_weights) + return_lse: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: # Regular path without sink block_mask = 
FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) - - def _forward_with_sink( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - scale: float | None = None, - sink_weights: torch.Tensor | None = None, - ) -> torch.Tensor: - """Forward pass with attention sink for sliding window attention.""" - # Use the pre-compiled static block mask - block_mask = FlexAttention.block_masks[self.mask_key] - - # Run flex_attn and return LSE for sink computation - out, lse = FlexAttention.flex_attn( - q, - k, - v, - block_mask=block_mask, - return_lse=True, - scale=scale, - ) - - # Apply attention sink rescaling: rescale by σ(lse - w[h]) - # This is mathematically equivalent to concatenating learnable sink weights - if sink_weights is not None: - sink_scale = torch.sigmoid(lse - sink_weights.view(1, -1, 1)).unsqueeze( - -1 - ) # [B,H,S,1] - out = out * sink_scale + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, return_lse=return_lse, scale=scale) - return out.to(q.dtype) @staticmethod def _get_sliding_window_mask_mod(window: int): diff --git a/torchtitan/train.py b/torchtitan/train.py index 29ecb428fa..008a4eebba 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -329,7 +329,6 @@ def __init__(self, job_config: JobConfig): self.train_context = dist_utils.get_train_context( loss_parallel_enabled, parallelism_config.enable_compiled_autograd, - use_sdpa=not getattr(model_args, "use_flex_attn", False), ) self.maybe_enable_amp = dist_utils.maybe_enable_amp( parallel_dims, From bb8ee6f3e044fb4ce4ed0baeb394e88522636cab Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 29 Sep 2025 15:41:18 -0700 Subject: [PATCH 12/18] test --- .../experiments/gpt_oss/infra/parallelize.py | 10 +-- torchtitan/experiments/gpt_oss/model/model.py | 85 +++++++++++++------ 2 files changed, 66 insertions(+), 29 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index e3f3842c60..606e6620f8 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -249,11 +249,11 @@ def apply_non_moe_tp( # shard attention.sinks across heads # TODO(jianiw): Fix the sink implementation - attn = transformer_block.attention - attn.register_parameter( - "sinks", - nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])), - ) + # attn = transformer_block.attention + # attn.register_parameter( + # "sinks", + # nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])), + # ) if enable_async_tp: from torch.distributed._symmetric_memory import enable_symm_mem_for_group diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index 6684575a5c..5bd58f8312 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -87,6 +87,7 @@ def apply_rotary_emb( xk_out = (xk * cos) + (rotate_half(xk) * sin) return xq_out.type_as(xq), xk_out.type_as(xk) + def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" bs, slen, n_kv_heads, head_dim = x.shape @@ -109,7 +110,7 @@ def eager_attention_forward( scaling: float, dropout: float = 0.0, **kwargs, -): +): key_values = key.transpose(2, 3) # When TP is enabled, key should be shard() print(f"key_values : {key_values.placements} {key_values.shape}") print(f"query : {query.placements} {query.shape}") @@ 
-145,15 +146,20 @@ def eager_attention_forward( attn_output = torch.matmul(attn_weights, value) return attn_output + class Attention(nn.Module): """ Multi-head attention (MLA) module. """ - def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = False): + def __init__( + self, model_args: GptOssModelArgs, use_sliding_attention: bool = False + ): super().__init__() - self.sliding_window = model_args.sliding_window if use_sliding_attention else None + self.sliding_window = ( + model_args.sliding_window if use_sliding_attention else None + ) self.head_dim = model_args.head_dim self.n_heads = model_args.num_attention_heads self.n_kv_heads = model_args.num_key_value_heads @@ -161,16 +167,24 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa self.n_rep = self.n_heads // self.n_kv_heads self.wq = nn.Linear( - model_args.hidden_size, model_args.num_attention_heads * model_args.head_dim, bias=True + model_args.hidden_size, + model_args.num_attention_heads * model_args.head_dim, + bias=True, ) self.wk = nn.Linear( - model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + model_args.hidden_size, + model_args.num_key_value_heads * model_args.head_dim, + bias=True, ) self.wv = nn.Linear( - model_args.hidden_size, model_args.num_key_value_heads * model_args.head_dim, bias=True + model_args.hidden_size, + model_args.num_key_value_heads * model_args.head_dim, + bias=True, ) self.wo = nn.Linear( - model_args.num_attention_heads * model_args.head_dim, model_args.hidden_size, bias=True + model_args.num_attention_heads * model_args.head_dim, + model_args.hidden_size, + bias=True, ) self.sinks = nn.Parameter(torch.empty(model_args.num_attention_heads)) @@ -179,9 +193,15 @@ def __init__(self, model_args: GptOssModelArgs, use_sliding_attention: bool = Fa if self.use_flex_attn: # Only apply sliding window to every other layer if use_sliding_attention: - self.attn = build_attention(use_flex_attn=True, attn_mask_type="sliding_window", sliding_window=self.sliding_window) + self.attn = build_attention( + use_flex_attn=True, + attn_mask_type="sliding_window", + sliding_window=self.sliding_window, + ) else: - self.attn = build_attention(use_flex_attn=True, attn_mask_type=model_args.attn_mask_type) + self.attn = build_attention( + use_flex_attn=True, attn_mask_type=model_args.attn_mask_type + ) else: # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed self.attn = eager_attention_forward @@ -219,32 +239,39 @@ def forward( v = values.transpose(1, 2).contiguous() if self.use_flex_attn: - # FlexAttention + # FlexAttention output, lse = self.attn( - q, k, v, + q, + k, + v, scale=None, - return_lse=True, + return_lse=False, ) # Apply attention sink rescaling: rescale by σ(lse - w[h]) - # This is mathematically equivalent to concatenating learnable sink weights - sink_scale = torch.sigmoid(lse - self.sink.view(1, -1, 1)).unsqueeze( + # This is mathematically equivalent to concatenating learnable sink weights + sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze( -1 ) # [B,H,S,1] - output = output * sink_scale + output = output * sink_scale.to(output.dtype) else: # eager attention forward output = self.attn( - q, k, v, self.sinks, + q, + k, + v, + self.sinks, attention_mask=self.sliding_window_causal(seqlen, x.device), scaling=self.head_dim**-0.5, dropout=0.0, ) - output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) + output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, 
T, H, D) # Reshape and project output - output = output.reshape(bsz, seqlen, -1).contiguous() # (bsz, seqlen, n_heads * v_head_dim) + output = output.reshape( + bsz, seqlen, -1 + ).contiguous() # (bsz, seqlen, n_heads * v_head_dim) output = self.wo(output) # (bsz, seqlen, dim) return output @@ -263,7 +290,7 @@ def init_weights(self, init_std: float): # TODO: statically init the mask using train.seq_len def sliding_window_causal(self, seqlen, device): i = torch.arange(seqlen, device=device) - q_idx = i[:, None] + q_idx = i[:, None] kv_idx = i[None, :] causal_mask = q_idx >= kv_idx @@ -282,11 +309,17 @@ def __init__(self, layer_id: int, model_args: GptOssModelArgs): super().__init__() use_sliding_attention = layer_id % 2 == 0 - self.attention = Attention(model_args, use_sliding_attention=use_sliding_attention) - self.attention_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) + self.attention = Attention( + model_args, use_sliding_attention=use_sliding_attention + ) + self.attention_norm = nn.RMSNorm( + model_args.hidden_size, eps=model_args.norm_eps + ) self.ffn_norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) - self.moe = GptOssMoE(model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim) + self.moe = GptOssMoE( + model_args, dim=model_args.hidden_size, hidden_dim=model_args.moe_inter_dim + ) self.moe_enabled = True # for composability with load balancing self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 @@ -323,14 +356,18 @@ def __init__(self, model_args: GptOssModelArgs): super().__init__() self.model_args = model_args self.max_seq_len = model_args.max_seq_len - self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.hidden_size) + self.tok_embeddings = nn.Embedding( + model_args.vocab_size, model_args.hidden_size + ) self.register_buffer( "rope_cache", self._precompute_rope_cache(), persistent=False ) self.layers = torch.nn.ModuleDict() for layer_id in range(model_args.num_hidden_layers): - self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to(torch.bfloat16) + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to( + torch.bfloat16 + ) self.norm = nn.RMSNorm(model_args.hidden_size, eps=model_args.norm_eps) self.output = nn.Linear( From 07c0ff42f213a766b0b500df35a619763162f359 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 29 Sep 2025 20:50:00 -0700 Subject: [PATCH 13/18] fix sink --- .../experiments/gpt_oss/infra/parallelize.py | 35 ++++++++++--------- torchtitan/experiments/gpt_oss/model/model.py | 8 +++-- .../gpt_oss/train_configs/debug_model.toml | 8 +++-- 3 files changed, 30 insertions(+), 21 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 606e6620f8..5a378b5403 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -8,6 +8,7 @@ PrepareModuleInput, RowwiseParallel, SequenceParallel, + PrepareModuleOutput, ) if torch.__version__ >= "2.9": @@ -22,7 +23,6 @@ from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.experiments.llama4.infra.parallelize import ( apply_fsdp, - apply_moe_ep_tp, ) from torchtitan.tools.logging import logger @@ -212,18 +212,21 @@ def apply_non_moe_tp( Float8ColwiseParallel, Float8RowwiseParallel, PrepareFloat8ModuleInput, + PrepareFloat8ModuleOutput ) - rowwise_parallel, colwise_parallel, prepare_module_input = ( + rowwise_parallel, colwise_parallel, 
prepare_module_input, prepare_module_output = ( Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput, + PrepareFloat8ModuleOutput, ) else: - rowwise_parallel, colwise_parallel, prepare_module_input = ( + rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output= ( RowwiseParallel, ColwiseParallel, PrepareModuleInput, + PrepareModuleOutput, ) # Apply tensor + sequence parallelism to every transformer block @@ -231,30 +234,30 @@ def apply_non_moe_tp( layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), + input_layouts=(Shard(1), None), + desired_input_layouts=(Replicate(), None), ), - "attention.wq": colwise_parallel(use_local_output=False), - "attention.wk": colwise_parallel(use_local_output=False), - "attention.wv": colwise_parallel(use_local_output=False), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.attn": prepare_module_output(output_layouts=(Shard(1), Shard(1)), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False), "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } + # shard attention.sinks across heads + attn = transformer_block.attention + attn.register_parameter( + "sinks", + nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])), + ) + parallelize_module( module=transformer_block, device_mesh=tp_mesh, parallelize_plan=layer_plan, ) - # shard attention.sinks across heads - # TODO(jianiw): Fix the sink implementation - # attn = transformer_block.attention - # attn.register_parameter( - # "sinks", - # nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Replicate()])), - # ) - if enable_async_tp: from torch.distributed._symmetric_memory import enable_symm_mem_for_group diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index 5bd58f8312..b5ca968762 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -245,14 +245,18 @@ def forward( k, v, scale=None, - return_lse=False, + return_lse=True, ) # Apply attention sink rescaling: rescale by σ(lse - w[h]) # This is mathematically equivalent to concatenating learnable sink weights + # TODO: If attention part is, but self.sinks are registered as a DTensor, while lse is a plain tensor + # q, k, v are already sharded by TP: [batch, local_heads, seq_len, head_dim] (plain tensor) + # sinks shape needs to match: [local_heads], + # [rank0]:lse.shape torch.Size([8, 32, 2048]), sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze( -1 - ) # [B,H,S,1] + ) output = output * sink_scale.to(output.dtype) else: diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index 215ae69dd8..b628e2e1ac 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -46,10 +46,12 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -context_parallel_degree = 1 +expert_parallel_degree = 4 + + + 
[checkpoint] enable = false From a2727a68583d0afb1b41bf4cde61c9d95758fff8 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 29 Sep 2025 22:55:15 -0700 Subject: [PATCH 14/18] test EP --- torchtitan/experiments/gpt_oss/infra/parallelize.py | 7 +++++-- .../experiments/gpt_oss/train_configs/debug_model.toml | 2 -- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index 5a378b5403..beffc8b43b 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -102,7 +102,9 @@ def parallelize_gptoss( ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, ep_tp_mesh=( world_mesh["ep", "tp"] - if parallel_dims.tp_enabled and parallel_dims.ep_enabled and parallel_dims.etp_enabled + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled else None ), etp_enabled=parallel_dims.etp_enabled, @@ -145,9 +147,10 @@ def parallelize_gptoss( pp_enabled=parallel_dims.pp_enabled, cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, dp_mod_ep_mesh=( world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if dp_mod_ep_mesh_dim_names + if parallel_dims.ep_enabled else None ), ) diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index b628e2e1ac..a83d57ff0e 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -51,8 +51,6 @@ enable_async_tensor_parallel = false expert_parallel_degree = 4 - - [checkpoint] enable = false folder = "checkpoint" From e7f9a5620400749bc2056c9e70ae75bc4512c959 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Mon, 29 Sep 2025 23:45:07 -0700 Subject: [PATCH 15/18] working on ETP --- torchtitan/experiments/gpt_oss/infra/expert_parallel.py | 6 +++--- torchtitan/experiments/gpt_oss/model/moe.py | 2 ++ .../experiments/gpt_oss/train_configs/debug_model.toml | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py index 512bc8f6fd..850bfd0d6f 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -69,11 +69,11 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh): mod.register_parameter( "mlp1_bias", nn.Parameter(distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])), - ) # Row-wise sharding + ) # Column-wise sharding mod.register_parameter( "mlp2_weight", - nn.Parameter(distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])), - ) # Column-wise sharding + nn.Parameter(distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)])), + ) # Row-wise sharding mod.register_parameter( "mlp2_bias", nn.Parameter(distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Shard(1)])), diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py index 40fdac1887..dc9850ae76 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -124,6 +124,7 @@ def _run_experts_grouped_mm( h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) if offsets is not None: + # TODO(jianiw): 
check what is this doing b1 = mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) tail_slack = x.shape[0] - int(offsets[-1]) if tail_slack: @@ -131,6 +132,7 @@ def _run_experts_grouped_mm( h = h + b1.to(h.dtype) h = swiglu(h, limit=swiglu_limit) + # print(f"{h.shape} {mlp2_weight.shape}") # [rank0]:torch.Size([77507, 1440]) torch.Size([2, 2880, 128]) h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) if offsets is not None: b2 = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index a83d57ff0e..853415efa3 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -46,9 +46,10 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 +tensor_parallel_degree = 2 enable_async_tensor_parallel = false expert_parallel_degree = 4 +expert_tensor_parallel_degree = 2 [checkpoint] From ef146e15e77362bfbb9bbef4a4a64b7368be5ac4 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Tue, 30 Sep 2025 15:54:27 -0700 Subject: [PATCH 16/18] clean up --- torchtitan/distributed/utils.py | 1 + torchtitan/experiments/__init__.py | 4 +- torchtitan/experiments/gpt_oss/README.md | 18 +++ torchtitan/experiments/gpt_oss/README.py | 0 torchtitan/experiments/gpt_oss/__init__.py | 24 ++-- .../gpt_oss/infra/expert_parallel.py | 52 +++++-- .../experiments/gpt_oss/infra/parallelize.py | 99 ++++++-------- torchtitan/experiments/gpt_oss/model/args.py | 9 +- torchtitan/experiments/gpt_oss/model/model.py | 127 +++++------------- torchtitan/experiments/gpt_oss/model/moe.py | 69 +++++++--- .../gpt_oss/train_configs/debug_model.toml | 7 +- torchtitan/models/attention.py | 5 +- 12 files changed, 208 insertions(+), 207 deletions(-) create mode 100644 torchtitan/experiments/gpt_oss/README.md delete mode 100644 torchtitan/experiments/gpt_oss/README.py diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index b3ecf2c80a..72700fb1ab 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -206,6 +206,7 @@ def context(cp_context: Generator[None, None, None] | None = None): if SDPBackend.MATH in ScaledDotProductAttention.backends: ScaledDotProductAttention.backends.remove(SDPBackend.MATH) + stack.enter_context(cp_context) yield diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index b73ddc8458..2ce955de11 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -4,4 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-_supported_experiments = frozenset(["flux", "llama4", "qwen3", "simple_fsdp", "vlm", "gpt_oss"]) +_supported_experiments = frozenset( + ["flux", "llama4", "qwen3", "simple_fsdp", "vlm", "gpt_oss"] +) diff --git a/torchtitan/experiments/gpt_oss/README.md b/torchtitan/experiments/gpt_oss/README.md new file mode 100644 index 0000000000..4fd02246ec --- /dev/null +++ b/torchtitan/experiments/gpt_oss/README.md @@ -0,0 +1,18 @@ +# gpt-oss Model in torchtitan + +## Quick Start +```bash +CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh +``` + +## Supported Features +- FSDP/HSDP, TP, EP, ETP +- Grouped matrix multiplication for efficient computation +- SwiGLU activation +- Multi-head attention with sliding window mask and attention sink + + +## TODO +1. More parallelism support: CP, PP +2. Conversion between HF weights (StateDictAdapter) +3. Forward parity verification diff --git a/torchtitan/experiments/gpt_oss/README.py b/torchtitan/experiments/gpt_oss/README.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py index 532b1ccfcc..a93ba8769e 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -1,14 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader from torchtitan.models.moe import MoEArgs -from torchtitan.protocols.train_spec import register_train_spec, TrainSpec -from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.protocols.train_spec import TrainSpec from .infra.parallelize import parallelize_gptoss from .model.args import GptOssModelArgs @@ -26,7 +32,7 @@ "debugmodel": GptOssModelArgs( hidden_size=256, num_hidden_layers=4, - moe_args = MoEArgs( + moe_args=MoEArgs( num_experts=8, num_shared_experts=0, score_func="softmax", @@ -38,11 +44,11 @@ load_balance_coeff=1e-3, ), use_flex_attn=True, - attn_mask_type="causal" + attn_mask_type="causal", ), "20b": GptOssModelArgs( num_hidden_layers=24, - moe_args = MoEArgs( + moe_args=MoEArgs( num_experts=32, num_shared_experts=0, score_func="softmax", @@ -52,11 +58,11 @@ top_k=4, use_grouped_mm=True, load_balance_coeff=1e-3, - ) + ), ), "120b": GptOssModelArgs( num_hidden_layers=36, - moe_args = MoEArgs( + moe_args=MoEArgs( num_experts=128, num_shared_experts=0, score_func="softmax", @@ -66,7 +72,7 @@ top_k=4, use_grouped_mm=True, load_balance_coeff=1e-3, - ) + ), ), } @@ -78,7 +84,7 @@ def get_train_spec() -> TrainSpec: model_args=gptoss_configs, parallelize_fn=parallelize_gptoss, pipelining_fn=None, - build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_hf_dataloader, build_tokenizer_fn=build_hf_tokenizer, diff --git 
a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py index 850bfd0d6f..1d1c9e144e 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -1,10 +1,13 @@ -from functools import partial +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + from typing import Callable import torch -import torch.distributed as dist import torch.nn as nn -from torch.distributed._functional_collectives import all_to_all_single_autograd from torch.distributed.tensor import ( DeviceMesh, distribute_module, @@ -14,7 +17,6 @@ Shard, ) from torch.distributed.tensor.parallel import ParallelStyle -from torch.distributed.tensor.placement_types import Placement from torchtitan.distributed.expert_parallel import ExpertParallel @@ -22,7 +24,10 @@ class TensorParallel(ParallelStyle): def _partition_fn(self, name, module, device_mesh): module.register_parameter( - "mlp1_weight", nn.Parameter(distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)])) + "mlp1_weight", + nn.Parameter( + distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)]) + ), ) # Column-wise sharding module.register_parameter( "mlp1_bias", @@ -30,11 +35,15 @@ def _partition_fn(self, name, module, device_mesh): ) # Column-wise sharding module.register_parameter( "mlp2_weight", - nn.Parameter(distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)])), + nn.Parameter( + distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)]) + ), ) # Row-wise sharding module.register_parameter( "mlp2_bias", - nn.Parameter(distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])), + nn.Parameter( + distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()]) + ), ) # Replicate def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: @@ -44,6 +53,7 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: self._partition_fn, ) + # This class is for dp2ep with TP (without TP we can just use ExpertParallel) class ExpertTensorParallel(ExpertParallel): def __init__( @@ -64,20 +74,28 @@ def _token_dispatch(self, mod, inputs, device_mesh): def _partition_fn_2d(self, name, mod, ep_tp_mesh): mod.register_parameter( "mlp1_weight", - nn.Parameter(distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)])), + nn.Parameter( + distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)]) + ), ) # Column-wise sharding mod.register_parameter( "mlp1_bias", - nn.Parameter(distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])), + nn.Parameter( + distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)]) + ), ) # Column-wise sharding mod.register_parameter( "mlp2_weight", - nn.Parameter(distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)])), + nn.Parameter( + distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)]) + ), ) # Row-wise sharding mod.register_parameter( "mlp2_bias", - nn.Parameter(distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Shard(1)])), - ) # Row-wise sharding + nn.Parameter( + distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()]) + ), + ) # Replicate def _token_combine(self, mod, routed_output, device_mesh): # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh @@ -156,7 +174,15 @@ def wrapper( 
input_shape = x.shape x = x[permuted_indices, :] - out = func(mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias, swiglu_limit, x, num_tokens_per_expert) + out = func( + mlp1_weight, + mlp1_bias, + mlp2_weight, + mlp2_bias, + swiglu_limit, + x, + num_tokens_per_expert, + ) if num_tokens_per_expert is not None: out_unpermuted = out.new_empty(input_shape) diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py index beffc8b43b..bbde1a7515 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -1,38 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + import torch import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh -from torch.distributed.tensor import Replicate, Shard, distribute_tensor + +from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard from torch.distributed.tensor.parallel import ( ColwiseParallel, parallelize_module, PrepareModuleInput, + PrepareModuleInputOutput, + PrepareModuleOutput, RowwiseParallel, SequenceParallel, - PrepareModuleOutput, ) - -if torch.__version__ >= "2.9": - from torch.distributed.tensor.parallel import PrepareModuleInputOutput -else: - print(f"Since torch version {torch.__version__} < 2.9, PrepareModuleInputOutput is not available and MoE EP TP will fail.") - -from torchtitan.config.job_config import JobConfig from torchtitan.config import TORCH_DTYPE_MAP -from torchtitan.distributed import ParallelDims, NoParallel -from torchtitan.distributed.expert_parallel import ExpertParallel, ReordererSequenceParallel -from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp -from torchtitan.experiments.llama4.infra.parallelize import ( - apply_fsdp, +from torchtitan.config.job_config import JobConfig +from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed.expert_parallel import ( + ExpertParallel, + ReordererSequenceParallel, ) - +from torchtitan.experiments.llama4.infra.parallelize import apply_fsdp +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp from torchtitan.tools.logging import logger -from torch.distributed.tensor import Partial, Replicate, Shard - -from .expert_parallel import ( - ExpertTensorParallel, - TensorParallel, -) +from .expert_parallel import ExpertTensorParallel, TensorParallel # for selective op activation checkpointing @@ -67,7 +64,7 @@ def parallelize_gptoss( use_flex_attn = getattr(model.model_args, "use_flex_attn", False) if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: raise NotImplementedError("CP support for FlexAttention is still in progress.") - + if parallel_dims.tp_enabled: if job_config.parallelism.enable_async_tensor_parallel: raise NotImplementedError( @@ -109,7 +106,7 @@ def parallelize_gptoss( ), etp_enabled=parallel_dims.etp_enabled, ) - + model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -122,7 +119,7 @@ def parallelize_gptoss( use_flex_attn=use_flex_attn, save_list=_op_sac_save_list, ) - + dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel @@ -178,6 +175,7 @@ def parallelize_gptoss( return model + def apply_non_moe_tp( model: nn.Module, 
tp_mesh: DeviceMesh, @@ -207,30 +205,17 @@ def apply_non_moe_tp( }, ) - # Parallel styles used for transformer block linear weights and their - # inputs may be different for float8 linears with tensorwise scaling. - if enable_float8_tensorwise_tp: - # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there - from torchao.float8.float8_tensor_parallel import ( - Float8ColwiseParallel, - Float8RowwiseParallel, - PrepareFloat8ModuleInput, - PrepareFloat8ModuleOutput - ) - - rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output = ( - Float8RowwiseParallel, - Float8ColwiseParallel, - PrepareFloat8ModuleInput, - PrepareFloat8ModuleOutput, - ) - else: - rowwise_parallel, colwise_parallel, prepare_module_input, prepare_module_output= ( - RowwiseParallel, - ColwiseParallel, - PrepareModuleInput, - PrepareModuleOutput, - ) + ( + rowwise_parallel, + colwise_parallel, + prepare_module_input, + prepare_module_output, + ) = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + PrepareModuleOutput, + ) # Apply tensor + sequence parallelism to every transformer block for transformer_block in model.layers.values(): @@ -243,7 +228,11 @@ def apply_non_moe_tp( "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), "attention.wv": colwise_parallel(), - "attention.attn": prepare_module_output(output_layouts=(Shard(1), Shard(1)), desired_output_layouts=(Shard(1), Shard(1)), use_local_output=False), + "attention.attn": prepare_module_output( + output_layouts=(Shard(1), Shard(1)), + desired_output_layouts=(Shard(1), Shard(1)), + use_local_output=False, + ), "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } @@ -304,17 +293,7 @@ def apply_moe_ep_tp( # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) - if transformer_block.moe.shared_experts is not None: - # input Replicate, output Partial - moe_layer_plan.update( - { - "moe.shared_experts.w1": ColwiseParallel(), - "moe.shared_experts.w2": RowwiseParallel( - output_layouts=Partial() - ), - "moe.shared_experts.w3": ColwiseParallel(), - } - ) + parallelize_module( module=transformer_block, device_mesh=tp_mesh, diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py index 583dbc848f..c3c0e925b7 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -1,3 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
@@ -7,11 +13,10 @@ from torch import nn - from torchtitan.config.job_config import JobConfig +from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import BaseModelArgs from torchtitan.tools.logging import logger -from torchtitan.models.moe import MoEArgs from torchtitan.tools.utils import has_cuda_capability diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py index b5ca968762..38150a0d60 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -1,13 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import math -from typing import Tuple +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. import torch from torch import nn -from torch.distributed.tensor import DTensor -from torchtitan.experiments.simple_fsdp import model from torchtitan.models.attention import build_attention from torchtitan.protocols.train_spec import ModelProtocol @@ -100,53 +101,6 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: ) -# TODO(jianw): This is eager version from HuggingFace. Remove it once FlexAttention is ready. -def eager_attention_forward( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - sinks: torch.Tensor, - attention_mask: torch.Tensor, - scaling: float, - dropout: float = 0.0, - **kwargs, -): - key_values = key.transpose(2, 3) # When TP is enabled, key should be shard() - print(f"key_values : {key_values.placements} {key_values.shape}") - print(f"query : {query.placements} {query.shape}") - - # [rank0]:key_values : (Shard(dim=1),) torch.Size([8, 64, 64, 2048]) - # [rank0]:query : (Shard(dim=1),) torch.Size([8, 64, 2048, 64]) - - attn_weights = query @ key_values * scaling - if attention_mask is not None: - # attention_mask can be [Tq, Tk] or [B, H, Tq, Tk] - # Convert boolean "allowed" -> additive mask - if attention_mask.dtype == torch.bool: - m = attention_mask - add_mask = torch.zeros_like(m, dtype=attn_weights.dtype) - add_mask = add_mask.masked_fill(~m, -float("inf")) - else: - add_mask = attention_mask.to(attn_weights.dtype) - - # Truncate to current key length and add (broadcasts if needed) - add_mask = add_mask[..., : key.shape[-2]] - attn_weights = attn_weights + add_mask - - sinks = sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) - combined_logits = torch.cat([attn_weights, sinks], dim=-1) - - # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 - # when training with bsz>1 we clamp max values. - - combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values - probs = nn.functional.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) - scores = probs[..., :-1] # we drop the sink here - attn_weights = nn.functional.dropout(scores, p=dropout, training=False) - attn_output = torch.matmul(attn_weights, value) - return attn_output - - class Attention(nn.Module): """ Multi-head attention (MLA) module. 
@@ -190,21 +144,20 @@ def __init__( self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - # Only apply sliding window to every other layer - if use_sliding_attention: - self.attn = build_attention( - use_flex_attn=True, - attn_mask_type="sliding_window", - sliding_window=self.sliding_window, - ) - else: - self.attn = build_attention( - use_flex_attn=True, attn_mask_type=model_args.attn_mask_type - ) + if not self.use_flex_attn: + raise ValueError("Only support FlexAttention in Gpt-oss model") + + # Only apply sliding window to every other layer + if use_sliding_attention: + self.attn = build_attention( + use_flex_attn=True, + attn_mask_type="sliding_window", + sliding_window=self.sliding_window, + ) else: - # NOTE: sampling with FlexAttn seems broken; use TorchAttn if needed - self.attn = eager_attention_forward + self.attn = build_attention( + use_flex_attn=True, attn_mask_type=model_args.attn_mask_type + ) def forward( self, @@ -238,38 +191,20 @@ def forward( k = keys.transpose(1, 2).contiguous() v = values.transpose(1, 2).contiguous() - if self.use_flex_attn: - # FlexAttention - output, lse = self.attn( - q, - k, - v, - scale=None, - return_lse=True, - ) + # FlexAttention + output, lse = self.attn( + q, + k, + v, + scale=None, + return_lse=True, + ) - # Apply attention sink rescaling: rescale by σ(lse - w[h]) - # This is mathematically equivalent to concatenating learnable sink weights - # TODO: If attention part is, but self.sinks are registered as a DTensor, while lse is a plain tensor - # q, k, v are already sharded by TP: [batch, local_heads, seq_len, head_dim] (plain tensor) - # sinks shape needs to match: [local_heads], - # [rank0]:lse.shape torch.Size([8, 32, 2048]), - sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze( - -1 - ) - output = output * sink_scale.to(output.dtype) + # Apply attention sink rescaling: rescale by σ(lse - w[h]) + # This is mathematically equivalent to concatenating learnable sink weights + sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1) + output = output * sink_scale.to(output.dtype) - else: - # eager attention forward - output = self.attn( - q, - k, - v, - self.sinks, - attention_mask=self.sliding_window_causal(seqlen, x.device), - scaling=self.head_dim**-0.5, - dropout=0.0, - ) output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) # Reshape and project output @@ -289,7 +224,9 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std) for linear in linear_list: nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std) nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std) # TODO: statically init the mask using train.seq_len def sliding_window_causal(self, seqlen, device): diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py index dc9850ae76..2df093880a 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -1,15 +1,20 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
import torch -from torch.distributed.tensor import DTensor -import torch.nn.functional as F from torch import nn +from torch.distributed.tensor import DTensor from torchtitan.experiments.gpt_oss.infra.expert_parallel import expert_parallel -from torchtitan.protocols import model +from torchtitan.models.moe import MoE from .args import GptOssModelArgs -from torchtitan.models.moe import MoE, MoEArgs, GroupedExperts + def swiglu(x, alpha: float = 1.702, limit: float = 7.0): x_glu, x_linear = x[..., ::2], x[..., 1::2] @@ -20,6 +25,7 @@ def swiglu(x, alpha: float = 1.702, limit: float = 7.0): # Note we add an extra bias of 1 to the linear layer return out_glu * (x_linear + 1) + class GptOssGroupedExperts(nn.Module): def __init__( self, @@ -34,7 +40,7 @@ def __init__( self.use_grouped_mm = use_grouped_mm self.swiglu_limit = swiglu_limit - self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2))) # w1 and w3 + self.mlp1_weight = nn.Parameter(torch.empty((num_experts, dim, hidden_dim * 2))) self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2))) self.mlp2_weight = nn.Parameter(torch.empty((num_experts, hidden_dim, dim))) self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) @@ -46,11 +52,23 @@ def forward( ) -> torch.Tensor: if self.use_grouped_mm: return GptOssGroupedExperts._run_experts_grouped_mm( - self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, self.swiglu_limit, x, num_tokens_per_expert + self.mlp1_weight, + self.mlp1_bias, + self.mlp2_weight, + self.mlp2_bias, + self.swiglu_limit, + x, + num_tokens_per_expert, ) else: return GptOssGroupedExperts._run_experts_for_loop( - self.mlp1_weight, self.mlp1_bias, self.mlp2_weight, self.mlp2_bias, self.swiglu_limit, x, num_tokens_per_expert + self.mlp1_weight, + self.mlp1_bias, + self.mlp2_weight, + self.mlp2_bias, + self.swiglu_limit, + x, + num_tokens_per_expert, ) # TODO: keeping this for-loop implementation for comparison @@ -82,7 +100,10 @@ def _run_experts_for_loop( ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): - h = torch.matmul(x_expert, mlp1_weight[expert_idx]) + mlp1_bias[expert_idx] + h = ( + torch.matmul(x_expert, mlp1_weight[expert_idx]) + + mlp1_bias[expert_idx] + ) h = swiglu(h, limit=swiglu_limit) h = torch.matmul(h, mlp2_weight[expert_idx]) + mlp2_bias[expert_idx] out_experts_splits.append(h) @@ -98,7 +119,7 @@ def _run_experts_for_loop( return out - @expert_parallel # NOTE: EP currently reduces 20B MFU from 17.8% to 16.5%! 
+ @expert_parallel @staticmethod def _run_experts_grouped_mm( mlp1_weight: torch.Tensor, @@ -120,7 +141,12 @@ def _run_experts_grouped_mm( assert x.dim() == 3 if isinstance(mlp1_weight, DTensor): - mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias = mlp1_weight.to_local(), mlp1_bias.to_local(), mlp2_weight.to_local(), mlp2_bias.to_local() + mlp1_weight, mlp1_bias, mlp2_weight, mlp2_bias = ( + mlp1_weight.to_local(), + mlp1_bias.to_local(), + mlp2_weight.to_local(), + mlp2_bias.to_local(), + ) h = torch._grouped_mm(x.bfloat16(), mlp1_weight.bfloat16(), offs=offsets) if offsets is not None: @@ -132,12 +158,11 @@ def _run_experts_grouped_mm( h = h + b1.to(h.dtype) h = swiglu(h, limit=swiglu_limit) - # print(f"{h.shape} {mlp2_weight.shape}") # [rank0]:torch.Size([77507, 1440]) torch.Size([2, 2880, 128]) h = torch._grouped_mm(h, mlp2_weight.bfloat16(), offs=offsets) if offsets is not None: b2 = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) tail_slack = x.shape[0] - int(offsets[-1]) - if tail_slack: + if tail_slack: # padding b2 = torch.cat([b2, b2.new_zeros((tail_slack, b2.shape[-1]))], dim=0) h = h + b2.to(h.dtype) @@ -150,24 +175,26 @@ def init_weights(self, init_std: float): nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) def extra_repr(self): - return (f"num_experts={self.num_experts}, " - f"use_grouped_mm={self.use_grouped_mm}, " - f"mlp1_weight={tuple(self.mlp1_weight.shape)}, " - f"mlp1_bias={tuple(self.mlp1_bias.shape)}, " - f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " - f"mlp2_bias={tuple(self.mlp2_bias.shape)}") + return ( + f"num_experts={self.num_experts}, " + f"use_grouped_mm={self.use_grouped_mm}, " + f"mlp1_weight={tuple(self.mlp1_weight.shape)}, " + f"mlp1_bias={tuple(self.mlp1_bias.shape)}, " + f"mlp2_weight={tuple(self.mlp2_weight.shape)}, " + f"mlp2_bias={tuple(self.mlp2_bias.shape)}" + ) class GptOssMoE(MoE): """GptOss MoE implementation that inherits from the base MoE class.""" - + def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): # Convert GptOssModelArgs to MoEArgs for base class compatibility moe_args = model_args.moe_args - + # Initialize the base MoE class super().__init__(moe_args, dim, hidden_dim) - + # Override the base GroupedExperts with GptOssGroupedExperts self.experts = GptOssGroupedExperts( dim=dim, diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml index 853415efa3..bae6c304d7 100644 --- a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -46,11 +46,10 @@ dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 2 +tensor_parallel_degree = 1 enable_async_tensor_parallel = false -expert_parallel_degree = 4 -expert_tensor_parallel_degree = 2 - +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 [checkpoint] enable = false diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 90a2caced7..b34c70a157 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -95,8 +95,9 @@ def forward( # Regular path without sink block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, return_lse=return_lse, scale=scale) - + return FlexAttention.flex_attn( + q, k, v, 
block_mask=block_mask, return_lse=return_lse, scale=scale
+        )
 
     @staticmethod
     def _get_sliding_window_mask_mod(window: int):

From 2b47774777f16f96471950bc44b160789f470a68 Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 30 Sep 2025 15:54:27 -0700
Subject: [PATCH 17/18] clean up

---
 torchtitan/experiments/gpt_oss/README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torchtitan/experiments/gpt_oss/README.md b/torchtitan/experiments/gpt_oss/README.md
index 4fd02246ec..613e77003e 100644
--- a/torchtitan/experiments/gpt_oss/README.md
+++ b/torchtitan/experiments/gpt_oss/README.md
@@ -16,3 +16,4 @@ CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./
 1. More parallelism support: CP, PP
 2. Conversion between HF weights (StateDictAdapter)
 3. Forward parity verification
+4. CI support

From cd89d2641a0320f2a4cce9a2de199274717dfd3a Mon Sep 17 00:00:00 2001
From: Jiani Wang
Date: Tue, 30 Sep 2025 16:03:11 -0700
Subject: [PATCH 18/18] fix lint

---
 torchtitan/experiments/__init__.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py
index 45601e2d6f..6b28bf78b3 100644
--- a/torchtitan/experiments/__init__.py
+++ b/torchtitan/experiments/__init__.py
@@ -5,5 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 
 _supported_experiments = frozenset(
-    ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm", "gpt_oss"]
+    [
+        "flux",
+        "llama4",
+        "qwen3",
+        "simple_fsdp.llama3",
+        "simple_fsdp.deepseek_v3",
+        "vlm",
+        "gpt_oss",
+    ]
+)
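
Note on the attention-sink rescaling used in the gpt_oss Attention module: the patches above state that multiplying the FlexAttention output by sigma(lse - sinks[h]) is mathematically equivalent to concatenating a learnable sink logit per head before the softmax and dropping that column afterwards. Below is a minimal standalone sketch that checks this identity numerically; it uses plain PyTorch with illustrative shapes and no causal/sliding mask, and is not part of the torchtitan code paths changed in these patches.

```python
import torch

torch.manual_seed(0)
B, H, Tq, Tk, D = 2, 4, 5, 5, 8  # illustrative sizes only

q = torch.randn(B, H, Tq, D)
k = torch.randn(B, H, Tk, D)
v = torch.randn(B, H, Tk, D)
sinks = torch.randn(H)  # one learnable sink logit per head

scale = D ** -0.5
scores = (q @ k.transpose(-2, -1)) * scale          # [B, H, Tq, Tk]

# Reference form: append the sink logit as an extra "token", softmax, drop it.
sink_col = sinks.view(1, H, 1, 1).expand(B, H, Tq, 1)
probs = torch.softmax(torch.cat([scores, sink_col], dim=-1), dim=-1)
ref = probs[..., :-1] @ v                            # sink column carries no value vector

# Rescaling form: ordinary softmax attention, then multiply by sigma(lse - w_h),
# which is what the FlexAttention path does with the returned logsumexp.
attn = torch.softmax(scores, dim=-1) @ v
lse = torch.logsumexp(scores, dim=-1)                # [B, H, Tq]
out = attn * torch.sigmoid(lse - sinks.view(1, H, 1)).unsqueeze(-1)

print(torch.allclose(ref, out, atol=1e-6))           # True
```

The identity holds because sigma(lse - w) = sum(exp(s)) / (sum(exp(s)) + exp(w)): the sink only drains probability mass and never contributes a value vector, so dropping the sink column and rescaling the plain attention output give the same result, independent of whatever score mask is applied.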