diff --git a/torchtitan/experiments/autopartition/README.md b/torchtitan/experiments/autopartition/README.md
new file mode 100644
index 0000000000..6de77cdd28
--- /dev/null
+++ b/torchtitan/experiments/autopartition/README.md
@@ -0,0 +1,55 @@
+# Auto-Partition in torchtitan
+
+## Overview
+
+This folder provides an automatic pipeline-partitioning method that accounts for the computation cost of the embedding layers.
+The method computes the floating-point operations (FLOPs) of the embedding layers and builds an array containing the FLOPs of both the transformer layers and the embedding layers. A heuristic algorithm then searches for a balanced pipeline partition over this array.
+
+## Quick Start
+
+### Compile
+
+First, compile `autopipe.cpp` into a Python extension module:
+```bash
+pip install pybind11
+cd ./torchtitan/experiments/autopartition/infra/cpp
+mkdir build
+cd build
+cmake ..
+make
+mv *.so ../../
+```
+
+Then launch training. The following command uses Llama 3 as an example:
+
+```bash
+CONFIG_FILE="./torchtitan/experiments/autopartition/train_configs/debug_model.toml" ./run_train.sh
+```
+
+## Performance
+
+Hardware configuration: 4x RTX 3090 24GB, pipeline parallelism degree 4.
+
+### Llama 3 comparison
+| hidden size| layers | autopipe TPS| default TPS| Speedup |
+| ---------- | ---- | ---------- | -----------| ----------- |
+| dim=256 | 6 | 31,094 | 29,549 | +5.2% |
+| dim=256 | 12 | 21,803 | 21,923 | -0.5% |
+| dim=2048 | 12 | 3,348 | 2,616 | +28.0% |
+| dim=4096 | 12 | 981 | 761 | +28.9% |
+
+### DeepSeek-V3 (without MoE) comparison
+
+| hidden size| layers | autopipe TPS| default TPS| Speedup |
+| ---------- | ---- | ---------- | -----------| ----------- |
+| dim=256 | 6 | 13,373 | 13,059 | +2.4% |
+| dim=256 | 12 | 7,714 | 6,859 | +12.5% |
+| dim=2048 | 12 | 4,331 | 3,810 | +13.7% |
+| dim=4096 | 12 | 2,888 | 2,561 | +12.8% |
+| dim=4096 | 16 | 2,207 | 2,008 | +9.9% |
+| dim=8192 | 16 | 4,331 | 3,935 | +10.1% |
+
+
+### Known Issues
+
+- **MoE is not supported** - Auto-Partition needs FLOPs for each layer, but the current DeepSpeed-based profiler does not support computing FLOPs for MoE layers.
diff --git a/torchtitan/experiments/autopartition/__init__.py b/torchtitan/experiments/autopartition/__init__.py
new file mode 100644
index 0000000000..2102ec1b38
--- /dev/null
+++ b/torchtitan/experiments/autopartition/__init__.py
@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+__all__ = [
+    "get_deepseek_v3_train_spec",
+    "get_llama3_train_spec",
+]
+
+
+from .deepseek_v3_tain_spec import get_deepseek_v3_train_spec
+from .llama3_tain_spec import get_llama3_train_spec
diff --git a/torchtitan/experiments/autopartition/deepseek_v3/args.py b/torchtitan/experiments/autopartition/deepseek_v3/args.py
new file mode 100644
index 0000000000..48d4b5ece1
--- /dev/null
+++ b/torchtitan/experiments/autopartition/deepseek_v3/args.py
@@ -0,0 +1,121 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+ + +from dataclasses import dataclass, field + +from torch import nn + +from torchtitan.config import JobConfig +from torchtitan.models.moe import MoEArgs +from torchtitan.models.utils import get_moe_model_nparams_and_flops +from torchtitan.protocols.model import BaseModelArgs +from torchtitan.tools.logging import logger +from torchtitan.tools.utils import has_cuda_capability + + +# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py +@dataclass +class DeepSeekV3ModelArgs(BaseModelArgs): + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + inter_dim (int): Intermediate dimension for MLP layers. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + n_dense_layers (int): Number of dense layers in the model. + n_heads (int): Number of attention heads. + norm_eps (float): Epsilon value used for RMSNorm. + moe_args (MoEArgs): MoE configuration. + n_expert_groups (int): Number of expert groups. + n_limited_groups (int): Number of limited groups for MoE routing. + q_lora_rank (int): LoRA rank for query projections. + kv_lora_rank (int): LoRA rank for key-value projections. + qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. + qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. + v_head_dim (int): Dimension for value projections. + use_flex_attn (bool): Whether to use FlexAttention. + attn_mask_type (str): Type of attention mask. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + """ + + max_batch_size: int = 8 + max_seq_len: int = 4096 * 4 + vocab_size: int = 102400 + dim: int = 2048 + inter_dim: int = 10944 + moe_inter_dim: int = 1408 + n_layers: int = 27 + n_dense_layers: int = 1 + n_heads: int = 16 + norm_eps: float = 1e-5 # eps used for RMSNorm + + # MoE + moe_args: MoEArgs = field(default_factory=MoEArgs) + # TODO: node-limited routing is not supported yet + n_expert_groups: int = 1 + n_limited_groups: int = 1 + + # Multi-Head Latent Attention (MLA) + q_lora_rank: int = 0 + kv_lora_rank: int = 512 + qk_nope_head_dim: int = 128 + qk_rope_head_dim: int = 64 + v_head_dim: int = 128 + use_flex_attn: bool = False + attn_mask_type: str = "causal" + + # yarn + original_seq_len: int = 4096 + rope_theta: float = 10000.0 + rope_factor: float = 40 + beta_fast: int = 32 + beta_slow: int = 1 + mscale: float = 1.0 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): + logger.warning( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + self.moe_args.use_grouped_mm = False + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." 
+ ) + + self.moe_args._debug_force_load_balance = ( + job_config.debug.moe_force_load_balance + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + return get_moe_model_nparams_and_flops( + self, + model, + self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim, + seq_len, + ) diff --git a/torchtitan/experiments/autopartition/deepseek_v3/model.py b/torchtitan/experiments/autopartition/deepseek_v3/model.py new file mode 100644 index 0000000000..3cf56eb1b2 --- /dev/null +++ b/torchtitan/experiments/autopartition/deepseek_v3/model.py @@ -0,0 +1,434 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math + +import torch +from torch import nn + +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) +from torchtitan.models.moe import FeedForward, MoE +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import DeepSeekV3ModelArgs + + +# Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294 +def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor: + """ + Precomputes frequency-based complex exponential values for rotary positional embeddings. + + Args: + args (DeepSeekV3ModelArgs): Model arguments containing positional embedding parameters. + + Returns: + torch.Tensor: Precomputed complex exponential values for positional embeddings. + """ + dim = args.qk_rope_head_dim + seqlen = args.max_seq_len + beta_fast = args.beta_fast + beta_slow = args.beta_slow + base = args.rope_theta + factor = args.rope_factor + + def find_correction_dim( + num_rotations: float, dim: int, base: float, max_seq_len: int + ) -> float: + """ + Computes the correction dimension for a given number of rotations in the rotary positional embedding. + + Args: + num_rotations (float): Number of rotations to compute the correction for. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + float: The correction dimension based on the input parameters. + """ + return ( + dim + * math.log(max_seq_len / (num_rotations * 2 * math.pi)) + / (2 * math.log(base)) + ) + + def find_correction_range( + low_rot: float, high_rot: float, dim: int, base: float, max_seq_len: int + ) -> tuple[int, int]: + """ + Computes the range of correction dimensions for rotary positional embeddings. + + Args: + low_rot (float): Lower bound for the number of rotations. + high_rot (float): Upper bound for the number of rotations. + dim (int): Dimensionality of the embedding space. + base (float): Base value for the exponential computation. + max_seq_len (int): Maximum sequence length. + + Returns: + tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. 
+ """ + low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) + high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) + return max(low, 0), min(high, dim - 1) + + def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: + """ + Computes a linear ramp function used to smooth values between a minimum and maximum range. + + Args: + min (float): Minimum value for the ramp function. + max (float): Maximum value for the ramp function. + dim (int): Dimensionality of the ramp tensor. + + Returns: + torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, + clamped to the range [0, 1]. + """ + if min == max: + max += 0.001 + linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + # Basic RoPE frequency calculation + freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) + + # YaRN scaling for extended context. YaRN is used to extend the context length after pre-training. + if seqlen > args.original_seq_len: + low, high = find_correction_range( + beta_fast, beta_slow, dim, base, args.original_seq_len + ) + smooth = 1 - linear_ramp_factor(low, high, dim // 2) + freqs = freqs / factor * (1 - smooth) + freqs * smooth + + # Create position indices + t = torch.arange(seqlen) + + # Outer product: [positions] × [frequencies] + freqs = torch.outer(t, freqs) + + # Convert to complex exponentials: e^(i*freq*pos) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) + return freqs_cis + + +def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: + """ + Applies rotary positional embeddings to the input tensor. + + Args: + x (torch.Tensor): Input tensor with positional embeddings to be applied. + freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + + Returns: + torch.Tensor: Tensor with rotary embeddings applied. + """ + dtype = x.dtype + x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) + freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) + y = torch.view_as_real(x * freqs_cis).flatten(3) + return y.to(dtype) + + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. 
+ """ + + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__() + self.dim = model_args.dim + self.n_heads = model_args.n_heads + self.q_lora_rank = model_args.q_lora_rank + self.kv_lora_rank = model_args.kv_lora_rank + self.qk_nope_head_dim = model_args.qk_nope_head_dim + self.qk_rope_head_dim = model_args.qk_rope_head_dim + self.qk_head_dim = model_args.qk_nope_head_dim + model_args.qk_rope_head_dim + self.v_head_dim = model_args.v_head_dim + + if self.q_lora_rank == 0: + self.wq = nn.Linear(self.dim, self.n_heads * self.qk_head_dim, bias=False) + else: + self.wq_a = nn.Linear(self.dim, self.q_lora_rank, bias=False) + self.q_norm = nn.RMSNorm(self.q_lora_rank, eps=model_args.norm_eps) + self.wq_b = nn.Linear( + self.q_lora_rank, self.n_heads * self.qk_head_dim, bias=False + ) + self.wkv_a = nn.Linear( + self.dim, self.kv_lora_rank + self.qk_rope_head_dim, bias=False + ) + self.kv_norm = nn.RMSNorm(self.kv_lora_rank, eps=model_args.norm_eps) + self.wkv_b = nn.Linear( + self.kv_lora_rank, + self.n_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + ) + self.wo = nn.Linear(self.n_heads * self.v_head_dim, self.dim, bias=False) + self.softmax_scale = self.qk_head_dim**-0.5 + + if model_args.max_seq_len > model_args.original_seq_len: + mscale = 0.1 * model_args.mscale * math.log(model_args.rope_factor) + 1.0 + self.softmax_scale = self.softmax_scale * mscale * mscale + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + bsz, seqlen, _ = x.size() + + # Query projection + if self.q_lora_rank == 0: + q = self.wq(x) # (bsz, seqlen, n_heads * qk_head_dim) + else: + q = self.wq_a(x) + q = self.wq_b(self.q_norm(q)) + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of q and kv as TP may have sharded them after + # the above linear ops. 
+ q = q.view(bsz, seqlen, -1, self.qk_head_dim) + q_nope, q_pe = torch.split( + q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 + ) + q_pe = apply_rotary_emb(q_pe, freqs_cis) + q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) + + # Key-value projection + kv = self.wkv_a(x) # (bsz, seqlen, kv_lora_rank + qk_rope_head_dim) + kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + + k_pe = apply_rotary_emb( + k_pe.unsqueeze(2), freqs_cis + ) # (bsz, seqlen, 1, qk_rope_head_dim) + + kv = self.wkv_b( + self.kv_norm(kv) + ) # (bsz, seqlen, n_heads * (qk_nope_head_dim + v_head_dim)) + kv = kv.view(bsz, seqlen, -1, self.qk_nope_head_dim + self.v_head_dim) + k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k = torch.cat( + [k_nope, k_pe.expand(-1, -1, self.n_heads, -1)], dim=-1 + ) # (bsz, seqlen, n_heads, qk_head_dim) + + q = q.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) + v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + else: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) + + # Reshape and project output + output = output.transpose( + 1, 2 + ).contiguous() # (bsz, seqlen, n_heads, v_head_dim) + output = output.view(bsz, seqlen, -1) # (bsz, seqlen, n_heads * v_head_dim) + return self.wo(output) # (bsz, seqlen, dim) + + def init_weights(self, init_std: float): + linear_list = [ + self.wkv_a, + self.wkv_b, + ] + if self.q_lora_rank > 0: + linear_list.extend([self.wq_a, self.wq_b]) + else: + linear_list.append(self.wq) + + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + self.kv_norm.reset_parameters() + if self.q_lora_rank > 0: + self.q_norm.reset_parameters() + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: DeepSeekV3ModelArgs): + + super().__init__() + self.attention = Attention(model_args) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + self.moe_enabled = layer_id >= model_args.n_dense_layers + if self.moe_enabled: + self.moe = MoE( + model_args.moe_args, + dim=model_args.dim, + hidden_dim=model_args.moe_inter_dim, + ) + else: + self.feed_forward = FeedForward(model_args.dim, model_args.inter_dim) + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. 
+ """ + x = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + if self.moe_enabled: + x = x + self.moe(self.ffn_norm(x)) + else: + x = x + self.feed_forward(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + if self.moe_enabled: + self.moe.init_weights(self.weight_init_std, buffer_device) + else: + self.feed_forward.init_weights(self.weight_init_std) + + +class DeepSeekV3Model(nn.Module, ModelProtocol): + """ + DeepSeek-V3 Transformer model with attention and feed-forward layers. + """ + + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__() + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer( + "freqs_cis", precompute_freqs_cis(model_args), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + + self.norm = nn.RMSNorm(model_args.dim) + self.output = nn.Linear( + model_args.dim, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = precompute_freqs_cis(self.model_args) + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
+ """ + + h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis, attention_masks) + h = self.norm(h) if self.norm is not None else h + output = self.output(h) if self.output is not None else h + return output diff --git a/torchtitan/experiments/autopartition/deepseek_v3/state_dict_adapter.py b/torchtitan/experiments/autopartition/deepseek_v3/state_dict_adapter.py new file mode 100644 index 0000000000..fd4ec30284 --- /dev/null +++ b/torchtitan/experiments/autopartition/deepseek_v3/state_dict_adapter.py @@ -0,0 +1,206 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import re +from typing import Any + +import torch +from torch.distributed.checkpoint import HuggingFaceStorageReader + +from torch.distributed.tensor import DTensor +from torchtitan.models.utils import MoEStateDictAdapter + +from .args import DeepSeekV3ModelArgs + + +class DeepSeekV3StateDictAdapter(MoEStateDictAdapter): + """ + StateDictAdapter for DeepSeekV3 model. + """ + + def __init__( + self, + model_args: DeepSeekV3ModelArgs, + hf_assets_path: str | None, + ): + super().__init__(model_args, hf_assets_path) + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention Module + "model.layers.{}.self_attn.kv_a_proj_with_mqa.weight": "layers.{}.attention.wkv_a.weight", + "model.layers.{}.self_attn.kv_a_layernorm.weight": "layers.{}.attention.kv_norm.weight", + "model.layers.{}.self_attn.kv_b_proj.weight": "layers.{}.attention.wkv_b.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + # MLP Module + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + # Transformer Layer + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE Module + "model.layers.{}.mlp.experts.{}.gate_proj.weight": "layers.{}.moe.experts.w1", + "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3", + "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2", + "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight", + "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_experts.w1.weight", + "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_experts.w3.weight", + "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_experts.w2.weight", + "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.moe.expert_bias", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + # Adjustments for from_hf_map based on model architecture + if model_args.q_lora_rank != 0: + self.from_hf_map.update( + { + "model.layers.{}.self_attn.q_a_proj.weight": "layers.{}.attention.wq_a.weight", + "model.layers.{}.self_attn.q_a_layernorm.weight": "layers.{}.attention.q_norm.weight", + "model.layers.{}.self_attn.q_b_proj.weight": "layers.{}.attention.wq_b.weight", + } + ) + else: + self.from_hf_map.update( + { + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + } + ) + + def 
get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """ + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. + """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read DeepSeek-V3 671B model. + # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + BLOCK_SIZE = 128 + return QuantizedHuggingFaceStorageReader( + path=path, + target_dtype=torch.float32, + block_size=BLOCK_SIZE, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + 1. Convert between the HF shape and the torchtitan shape. + 2. Split the GroupedExperts' weight into separate expert's weight. + """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + + hf_state_dict = {} + + for key, value in state_dict.items(): + if "moe.experts" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_abstract_key = to_hf_map[abstract_key] + + # Store the GroupedExperts Weight metadata for from_hf() + if isinstance(value, DTensor): + self.grouped_expert_weight_placements[ + abstract_key + ] = value.placements + self.grouped_expert_weight_shape[abstract_key] = value.shape + + # Split GroupedExperts weight to local individual expert weights + local_expert_fqn = self._get_local_experts_weights( + new_abstract_key, + abstract_key, + layer_num, + value, + ) + hf_state_dict.update(local_expert_fqn) + + else: + # keep this path for offline conversion + split_values = self._split_experts_weights( + value, self.model_args.moe_args.num_experts + ) + + for expert_num in range(0, self.model_args.moe_args.num_experts): + new_key = new_abstract_key.format(layer_num, expert_num) + hf_state_dict[new_key] = split_values[expert_num].squeeze() + + elif "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + new_key = new_key.format(layer_num) + hf_state_dict[new_key] = value + + else: + new_key = to_hf_map[key] + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + 1. When loading from HF checkpoint, dequantize the weights from float8 to float32. + 2. Convert between the HF shape and the torchtitan shape. + 3. Concat separate expert's weight into GroupedExperts' weight. + """ + + state_dict = {} + expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} + + for key, value in hf_state_dict.items(): + if "mlp.experts" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=2) + layer_num, expert_num = re.findall(r"\d+", key) + titan_abstract_key = self.from_hf_map[abstract_key] + new_key = titan_abstract_key.format(layer_num) + + # Store the expert's weight in expert_weights_by_layer for concatenating later. 
+ if layer_num not in expert_weights_by_layer: + expert_weights_by_layer[layer_num] = {} + if titan_abstract_key not in expert_weights_by_layer[layer_num]: + expert_weights_by_layer[layer_num][titan_abstract_key] = {} + expert_weights_by_layer[layer_num][titan_abstract_key][ + int(expert_num) + ] = value + + if isinstance(value, DTensor): + stacked_value = self._concatenate_expert_weights_dtensor( + expert_weights_by_layer, + titan_abstract_key, + layer_num, + value.device_mesh, + ) + else: # keep this path to be compatible with offline conversion + stacked_value = self._concatenate_expert_weights( + expert_weights_by_layer, + titan_abstract_key, + layer_num, + self.model_args.moe_args.num_experts, + ) + + if stacked_value is not None: + state_dict[new_key] = stacked_value + + elif "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + new_key = new_key.format(layer_num) + state_dict[new_key] = value + + else: + new_key = self.from_hf_map[key] + state_dict[new_key] = value + + return state_dict diff --git a/torchtitan/experiments/autopartition/deepseek_v3_tain_spec.py b/torchtitan/experiments/autopartition/deepseek_v3_tain_spec.py new file mode 100644 index 0000000000..a11af94be4 --- /dev/null +++ b/torchtitan/experiments/autopartition/deepseek_v3_tain_spec.py @@ -0,0 +1,172 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.models.moe import MoEArgs +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize_deepseek_v3 import parallelize_deepseekv3 +from .infra.pipeline_parallel import pipeline_llm +from .deepseek_v3.args import DeepSeekV3ModelArgs +from .deepseek_v3.model import DeepSeekV3Model +from .deepseek_v3.state_dict_adapter import DeepSeekV3StateDictAdapter + +__all__ = [ + "parallelize_deepseekv3", + "DeepSeekV3ModelArgs", + "DeepSeekV3Model", + "deepseekv3_args", +] + + +deepseekv3_args = { + "debugmodel": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=4096, + inter_dim=1024, + moe_inter_dim=256, + n_layers=12, + n_dense_layers=12, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "debugmodel_flex_attn": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=6, + n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "16B": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=2048, + inter_dim=10944, + moe_inter_dim=1408, + n_layers=27, + 
n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=64, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "236B": DeepSeekV3ModelArgs( + vocab_size=102400, + dim=5120, + inter_dim=12288, + moe_inter_dim=1536, + n_layers=60, + n_dense_layers=1, + n_heads=128, + moe_args=MoEArgs( + num_experts=160, + num_shared_experts=2, + top_k=6, + score_func="softmax", + route_norm=False, + route_scale=16.0, + score_before_experts=False, + ), + n_expert_groups=8, + n_limited_groups=3, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "671B": DeepSeekV3ModelArgs( + vocab_size=129280, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=3, + n_heads=128, + moe_args=MoEArgs( + num_experts=256, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.5, + score_before_experts=False, + ), + n_expert_groups=8, + n_limited_groups=4, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + use_flex_attn=True, + attn_mask_type="block_causal", + ), +} + + +def get_deepseek_v3_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=deepseekv3_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autopartition/infra/cpp/CMakeLists.txt b/torchtitan/experiments/autopartition/infra/cpp/CMakeLists.txt new file mode 100644 index 0000000000..d9ffa4ffba --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/cpp/CMakeLists.txt @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. 
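+#
+# NOTE (illustrative sketch): this file builds `autopipe` as a standalone Python
+# extension module. After `make`, the README's Quick Start copies the resulting
+# `autopipe.*.so` two directories up (`mv *.so ../../`) so it sits next to the
+# infra code. The partitioner can then be called from Python roughly as below;
+# the per-layer FLOPs values are made-up illustration numbers, not measurements:
+#
+#   import autopipe
+#   forward_flops = [30, 10, 10, 10, 10, 5]    # per-layer forward cost
+#   backward_flops = [60, 20, 20, 20, 20, 10]  # per-layer backward cost
+#   stage_starts = autopipe.pipeline(forward_flops, backward_flops, num_stages=4)
+#   # -> index of the first layer assigned to each pipeline stage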
+
+
+cmake_minimum_required(VERSION 3.12)
+project(autopipe)
+
+# Use the simplest possible setup and avoid modern CMake features
+
+# Find Python
+find_package(PythonInterp REQUIRED)
+find_package(PythonLibs REQUIRED)
+
+# Query the Python extension suffix
+execute_process(
+    COMMAND ${PYTHON_EXECUTABLE} -c "import sysconfig; print(sysconfig.get_config_var('EXT_SUFFIX') or '.so')"
+    OUTPUT_VARIABLE PYTHON_MODULE_EXTENSION
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Query the pybind11 include directory
+execute_process(
+    COMMAND ${PYTHON_EXECUTABLE} -c "import pybind11; print(pybind11.get_include())"
+    OUTPUT_VARIABLE PYBIND11_INCLUDE_DIR
+    OUTPUT_STRIP_TRAILING_WHITESPACE
+)
+
+# Create the extension module
+add_library(autopipe MODULE autopipe.cpp)
+
+# Set target properties
+set_target_properties(autopipe PROPERTIES
+    PREFIX ""
+    SUFFIX ${PYTHON_MODULE_EXTENSION}
+    OUTPUT_NAME "autopipe"
+    LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}
+)
+
+# Include directories
+include_directories(
+    ${PYBIND11_INCLUDE_DIR}
+    ${PYTHON_INCLUDE_DIRS}
+)
+
+# Link libraries
+target_link_libraries(autopipe ${PYTHON_LIBRARIES})
diff --git a/torchtitan/experiments/autopartition/infra/cpp/autopipe.cpp b/torchtitan/experiments/autopartition/infra/cpp/autopipe.cpp
new file mode 100644
index 0000000000..f481a27362
--- /dev/null
+++ b/torchtitan/experiments/autopartition/infra/cpp/autopipe.cpp
@@ -0,0 +1,569 @@
+// Copyright (c) Meta Platforms, Inc. and affiliates.
+// All rights reserved.
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+//
+// Copyright (c) Meta Platforms, Inc. All Rights Reserved.
+
+// Algorithm for automatic pipeline partitioning along the critical path of a synchronized pipeline.
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+#include <algorithm>
+#include <limits>
+#include <numeric>
+#include <queue>
+#include <stdexcept>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+namespace py = pybind11;
+using namespace std;
+
+namespace torchpipe {
+
+// Constants
+constexpr long long kCommunicationOverhead = 0;
+constexpr long long kMaxLongLong = std::numeric_limits<long long>::max();
+constexpr int kMaxInt32 = std::numeric_limits<int>::max();
+
+// Pipeline partitioner
+class PipelinePartitioner {
+public:
+    static vector<int> merak_pipe(
+        const vector<long long>& forward_times,
+        const vector<long long>& backward_times,
+        int num_stages
+    );
+
+private:
+    struct PartitionResult {
+        vector<vector<int>> partition;
+        long long cost;
+        int critical_stage;
+    };
+
+    // Core algorithm functions
+    static vector<vector<int>> block_partition_algorithm(
+        const vector<int>& model,
+        int num_stages,
+        const vector<vector<long long>>& block_time_mapping
+    );
+
+    static void reconstruct_partitions(
+        const vector<int>& model,
+        const vector<long long>& prefix_sum,
+        const vector<vector<long long>>& dp,
+        int remaining_blocks,
+        int remaining_partitions,
+        vector<vector<int>>& partition
+    );
+
+    static pair<long long, int> calculate_training_time(
+        const vector<vector<int>>& partition,
+        const vector<vector<long long>>& block_time_mapping
+    );
+
+    static void calculate_stage_times(
+        const vector<vector<int>>& partition,
+        const vector<vector<long long>>& block_time_mapping,
+        vector<long long>& forward_time,
+        vector<long long>& backward_time,
+        vector<int>& last_microbatch
+    );
+
+    static pair<long long, int> calculate_steady_phase(
+        const vector<int>& last_batch,
+        const vector<long long>& forward_time,
+        const vector<long long>& backward_time
+    );
+
+    static long long calculate_cooldown_phase(
+        int num_stages,
+        int critical_stage,
+        long long last_forward_start,
+        const vector<long long>& forward_time,
+        const vector<long long>& backward_time
+    );
+
+    static PartitionResult find_best_partition(
+        const vector<vector<long long>>& block_time_mapping,
+        int num_stages,
+        const vector<vector<int>>& initial_partition,
+        const vector<long long>& prefix_sum,
+        const vector<vector<long long>>& dp_array
+    );
+
+    static void calculate_prefix_sum_and_dp(
+        const vector<int>& model,
+        int num_stages,
+        const vector<vector<long long>>& block_time_mapping,
+        vector<long long>& prefix_sum,
+        vector<vector<long long>>& dp_array
+    );
+};
+
+// Implementation
+void PipelinePartitioner::calculate_prefix_sum_and_dp(
+    const vector<int>& model,
+    int num_stages,
+    const vector<vector<long long>>& block_time_mapping,
+    vector<long long>& prefix_sum,
+    vector<vector<long long>>& dp_array
+) {
+    int num_blocks = model.size();
+    int max_partitions = min(num_blocks, num_stages);
+
+    // Compute prefix sums of per-block cost (forward + backward)
+    prefix_sum.clear();
+    prefix_sum.reserve(num_blocks + 1);
+    prefix_sum.push_back(0);
+
+    for (int i = 0; i < num_blocks; ++i) {
+        int block = model[i];
+        prefix_sum.push_back(prefix_sum.back() +
+                             block_time_mapping[0][block] +
+                             block_time_mapping[1][block]);
+    }
+
+    // DP table
+    dp_array.assign(num_blocks + 1, vector<long long>(max_partitions + 1, kMaxLongLong));
+    dp_array[0][0] = 0;
+
+    // Run the DP: minimize the maximum per-partition cost
+    for (int blocks = 1; blocks <= num_blocks; ++blocks) {
+        int max_p = min(blocks, max_partitions);
+        for (int partitions = 1; partitions <= max_p; ++partitions) {
+            long long min_val = kMaxLongLong;
+            for (int prev_blocks = 0; prev_blocks < blocks; ++prev_blocks) {
+                long long val = max(dp_array[prev_blocks][partitions - 1],
+                                    prefix_sum[blocks] - prefix_sum[prev_blocks]);
+                min_val = min(min_val, val);
+                if (min_val == 0) break;
+            }
+            dp_array[blocks][partitions] = min_val;
+        }
+    }
+}
+
+vector<vector<int>> PipelinePartitioner::block_partition_algorithm(
+    const vector<int>& model,
+    int num_stages,
+    const vector<vector<long long>>& block_time_mapping
+) {
+    vector<long long> prefix_sum;
+    vector<vector<long long>> dp_array;
+
+    calculate_prefix_sum_and_dp(model, num_stages, block_time_mapping, prefix_sum, dp_array);
+
+    vector<vector<int>> partition;
+    reconstruct_partitions(model, prefix_sum, dp_array,
+                           model.size(), num_stages, partition);
+    reverse(partition.begin(), partition.end());
+
+    return partition;
+}
+
+void PipelinePartitioner::reconstruct_partitions(
+    const vector<int>& model,
+    const vector<long long>& prefix_sum,
+    const vector<vector<long long>>& dp_array,
+    int remaining_blocks,
+    int remaining_partitions,
+    vector<vector<int>>& partition
+) {
+    if (remaining_blocks == 0 && remaining_partitions == 0) return;
+
+    if (remaining_blocks <= 0 || remaining_partitions <= 0 ||
+        remaining_blocks < remaining_partitions) {
+        throw runtime_error("Error during partition reconstruction");
+    }
+
+    int prev_end = 0;
+    while (prev_end < remaining_blocks &&
+           dp_array[remaining_blocks][remaining_partitions] !=
+               max(dp_array[prev_end][remaining_partitions - 1],
+                   prefix_sum[remaining_blocks] - prefix_sum[prev_end])) {
+        ++prev_end;
+    }
+
+    vector<int> current_partition;
+    current_partition.reserve(remaining_blocks - prev_end);
+    for (int i = prev_end + 1; i <= remaining_blocks; ++i) {
+        current_partition.push_back(model[i - 1]);
+    }
+    partition.push_back(move(current_partition));
+
+    reconstruct_partitions(model, prefix_sum, dp_array, prev_end,
+                           remaining_partitions - 1, partition);
+}
+
+void PipelinePartitioner::calculate_stage_times(
+    const vector<vector<int>>& partition,
+    const vector<vector<long long>>& block_time_mapping,
+    vector<long long>& forward_time,
+    vector<long long>& backward_time,
+    vector<int>& last_microbatch
+) {
+    int num_stages = partition.size();
+    int num_microbatches = num_stages * 2;
+
+    // Build the last-microbatch index for each stage
+    for (int i = 0; i < num_stages; ++i) {
+        last_microbatch[i] = num_microbatches - num_stages + i;
+    }
+
+    // Compute the forward and backward time of each stage
+    for (int i = 1; i <= num_stages; ++i) {
+        long long forward_sum = 0, backward_sum = 0;
+        for (int block_type : partition[i - 1]) {
+            forward_sum += block_time_mapping[0][block_type];
+            backward_sum += block_time_mapping[1][block_type];
+        }
+        forward_time[i] = forward_sum;
+        backward_time[i] = backward_sum;
+    }
+}
+
+pair<long long, int> PipelinePartitioner::calculate_steady_phase(
+    const vector<int>& last_batch,
+    const vector<long long>& forward_time,
+    const vector<long long>& backward_time
+) {
+    int num_stages = last_batch.size();
+    int num_microbatches = num_stages * 2;
+
+    // DP table
+    vector<vector<vector<long long>>> dp(num_stages + 2,
+        vector<vector<long long>>(num_microbatches,
+            vector<long long>(2, 0)));
+
+    // Initialization
+    long long initial_backward_start = 0;
+    for (int stage = 0; stage < num_stages; ++stage) {
+        initial_backward_start += forward_time[stage + 1];
+        if (stage != num_stages - 1) initial_backward_start += kCommunicationOverhead;
+    }
+
+    for (int stage = num_stages - 1; stage >= 0; --stage) {
+        dp[stage + 1][0][0] = kMaxLongLong;
+        dp[stage + 1][0][1] = initial_backward_start;
+        initial_backward_start += backward_time[stage + 1] + kCommunicationOverhead;
+    }
+
+    // Steady phase
+    for (int microbatch = 1; microbatch < num_microbatches; ++microbatch) {
+        // Forward pass
+        for (int stage = 0; stage < num_stages; ++stage) {
+            if (microbatch <= last_batch[stage]) {
+                dp[stage + 1][microbatch][0] = max(
+                    dp[stage][microbatch - 1][0] + forward_time[stage],
+                    dp[stage + 1][microbatch - 1][1] + backward_time[stage + 1]
+                );
+                if (stage != 0) dp[stage + 1][microbatch][0] += kCommunicationOverhead;
+            }
+        }
+
+        // Backward pass
+        for (int stage = num_stages - 1; stage >= 0; --stage) {
+            if (microbatch <= last_batch[stage]) {
+                dp[stage + 1][microbatch][1] = max(
+                    dp[stage + 2][microbatch][1] + backward_time[stage + 2],
+                    dp[stage + 1][microbatch][0] + forward_time[stage + 1]
+                );
+                if (stage != num_stages - 1) dp[stage + 1][microbatch][1] += kCommunicationOverhead;
+            }
+        }
+    }
+
+    // Find the critical-path stage
+    int critical_stage = num_stages - 1;
+    while (critical_stage >= 0) {
+        int microbatch;
+        long long forward_comm = (critical_stage != 0) ? kCommunicationOverhead : 0;
+        long long backward_comm = (critical_stage != num_stages - 1) ? kCommunicationOverhead : 0;
+
+        for (microbatch = 1; microbatch <= last_batch[critical_stage]; ++microbatch) {
+            if (dp[critical_stage + 1][microbatch][0] !=
+                dp[critical_stage + 1][microbatch - 1][1] +
+                    backward_time[critical_stage + 1] + forward_comm) {
+                break;
+            }
+
+            if (dp[critical_stage + 1][microbatch][1] !=
+                dp[critical_stage + 1][microbatch][0] +
+                    forward_time[critical_stage + 1] + backward_comm) {
+                break;
+            }
+        }
+
+        if (microbatch == last_batch[critical_stage] + 1) break;
+        --critical_stage;
+    }
+
+    if (critical_stage < 0) {
+        throw runtime_error("Failed to determine critical stage");
+    }
+
+    return make_pair(dp[critical_stage + 1][last_batch[critical_stage]][0],
+                     critical_stage);
+}
+
+long long PipelinePartitioner::calculate_cooldown_phase(
+    int num_stages,
+    int critical_stage,
+    long long last_forward_start,
+    const vector<long long>& forward_time,
+    const vector<long long>& backward_time
+) {
+    int vector_size = num_stages - critical_stage;
+    if (vector_size <= 0) return last_forward_start;
+
+    vector<vector<long long>> dp(vector_size, vector<long long>(vector_size, 0));
+    long long backward_start = last_forward_start;
+
+    // Initialization
+    for (int i = 0; i < vector_size; ++i) {
+        backward_start += forward_time[critical_stage + 1 + i];
+        if (critical_stage + i != num_stages - 1) {
+            backward_start += kCommunicationOverhead;
+        }
+        int j = vector_size - 1 - i;
+        dp[i][j] = backward_start;
+    }
+
+    // Run the DP
+    for (int col = vector_size - 2; col >= 0; --col) {
+        for (int row = vector_size - col - 2; row >= 0; --row) {
+            long long option1 = dp[row][col + 1] +
+                                backward_time[critical_stage + 1 + row] +
+                                kCommunicationOverhead;
+            long long option2 = dp[row + 1][col] +
+                                backward_time[critical_stage + 1 + row + 1] +
+                                kCommunicationOverhead;
+            dp[row][col] = max(option1, option2);
+
+            if (row > 0) {
+                dp[row][col] = max(dp[row][col], dp[row - 1][col + 1]);
+            }
+        }
+    }
+
+    return dp[0][0];
+}
+
+pair<long long, int> PipelinePartitioner::calculate_training_time(
+    const vector<vector<int>>& partition,
+    const vector<vector<long long>>& block_time_mapping
+) {
+    int num_stages = partition.size();
+    int num_microbatches = num_stages * 2;
+
+    vector<int> last_microbatch(num_stages);
+    vector<long long> forward_time(num_stages + 2, 0);
+    vector<long long> backward_time(num_stages + 2, 0);
+
+    // Compute per-stage times
+    for (int i = 0; i < num_stages; ++i) {
+        last_microbatch[i] = num_microbatches - num_stages + i;
+
+        long long forward_sum = 0, backward_sum = 0;
+        for (int block : partition[i]) {
+            forward_sum += block_time_mapping[0][block];
+            backward_sum += block_time_mapping[1][block];
+        }
+        forward_time[i + 1] = forward_sum;
+        backward_time[i + 1] = backward_sum;
+    }
+
+    auto steady_result = calculate_steady_phase(last_microbatch,
+                                                forward_time,
+                                                backward_time);
+
+    long long last_forward_start = steady_result.first;
+    int critical_stage = steady_result.second;
+
+    if (last_forward_start == kMaxLongLong) {
+        throw runtime_error("Failed to calculate steady phase");
+    }
+
+    long long last_backward_start = calculate_cooldown_phase(
+        num_stages, critical_stage, last_forward_start,
+        forward_time, backward_time);
+
+    long long pipeline_flush_time = last_backward_start;
+    for (int stage = critical_stage; stage > 0; --stage) {
+        pipeline_flush_time += backward_time[stage + 1] + kCommunicationOverhead;
+    }
+    pipeline_flush_time += backward_time[1];
+
+    return make_pair(pipeline_flush_time, critical_stage);
+}
+
+PipelinePartitioner::PartitionResult PipelinePartitioner::find_best_partition(
+    const vector<vector<long long>>& block_time_mapping,
+    int num_stages,
+    const vector<vector<int>>& initial_partition,
+    const vector<long long>& prefix_sum,
+    const vector<vector<long long>>& dp_array
+) {
+    // Hash functor for the unordered_set of partitions
+    struct VectorHash {
+        size_t operator()(const vector<vector<int>>& v) const {
+            size_t hash = 0;
+            for (const auto& inner : v) {
+                for (int val : inner) {
+                    hash ^= hash << 13;
+                    hash ^= hash >> 7;
+                    hash ^= hash << 17;
+                    hash ^= val + 0x9e3779b9 + (hash << 6) + (hash >> 2);
+                }
+            }
+            return hash;
+        }
+    };
+
+    struct VectorEqual {
+        bool operator()(const vector<vector<int>>& a, const vector<vector<int>>& b) const {
+            if (a.size() != b.size()) return false;
+            for (size_t i = 0; i < a.size(); ++i) {
+                if (a[i].size() != b[i].size()) return false;
+                for (size_t j = 0; j < a[i].size(); ++j) {
+                    if (a[i][j] != b[i][j]) return false;
+                }
+            }
+            return true;
+        }
+    };
+
+    vector<int> last_microbatch(num_stages, 0);
+    vector<long long> forward_time(num_stages + 2, 0);
+    vector<long long> backward_time(num_stages + 2, 0);
+
+    // Initialize the best result
+    PartitionResult best_result;
+    best_result.cost = kMaxLongLong;
+    best_result.critical_stage = kMaxInt32;
+
+    // Track partitions that have already been evaluated
+    unordered_set<vector<vector<int>>, VectorHash, VectorEqual> visited;
+    queue<vector<vector<int>>> partitions_queue;
+    partitions_queue.push(initial_partition);
+    visited.insert(initial_partition);
+
+    while (!partitions_queue.empty()) {
+        vector<vector<int>> current_partition = partitions_queue.front();
+        partitions_queue.pop();
+
+        // Evaluate the current partition
+        calculate_stage_times(current_partition, block_time_mapping,
+                              forward_time, backward_time, last_microbatch);
+
+        auto time_result = calculate_training_time(current_partition,
+                                                   block_time_mapping);
+        long long current_cost = time_result.first;
+        int current_critical = time_result.second;
+
+        // Update the best result
+        if (current_cost < best_result.cost) {
+            best_result.partition = current_partition;
+            best_result.cost = current_cost;
+            best_result.critical_stage = current_critical;
+        }
+
+        // Try to adjust the partition (simplified; the original logic is more involved)
+        if (current_critical > 0) {
+            // Collect the blocks before the critical stage
+            vector<int> blocks_before;
+            for (int stage = 0; stage < current_critical; ++stage) {
+                blocks_before.insert(blocks_before.end(),
+                                     current_partition[stage].begin(),
+                                     current_partition[stage].end());
+            }
+
+            // Add the first block of the critical stage
+            blocks_before.push_back(current_partition[current_critical][0]);
+
+            // Repartition
+            vector<vector<int>> new_partition;
+            reconstruct_partitions(blocks_before, prefix_sum, dp_array,
+                                   blocks_before.size(), current_critical,
+                                   new_partition);
+            reverse(new_partition.begin(), new_partition.end());
+            blocks_before.pop_back();
+
+            // Append the remaining stages
+            for (int stage = current_critical; stage < current_partition.size(); ++stage) {
+                new_partition.push_back(current_partition[stage]);
+            }
+            new_partition[current_critical].erase(new_partition[current_critical].begin());
+
+            // Enqueue if not visited yet
+            if (visited.find(new_partition) == visited.end()) {
+                partitions_queue.push(new_partition);
+                visited.insert(new_partition);
+            }
+        }
+    }
+
+    return best_result;
+}
+
+vector<int> PipelinePartitioner::merak_pipe(
+    const vector<long long>& forward_times,
+    const vector<long long>& backward_times,
+    int num_stages
+) {
+    // Input validation
+    if (forward_times.empty() || backward_times.empty()) {
+        throw invalid_argument("Input vectors cannot be empty");
+    }
+
+    if (forward_times.size() != backward_times.size()) {
+        throw invalid_argument("Forward and backward vectors must have same size");
+    }
+
+    if (num_stages <= 0 || num_stages > static_cast<int>(forward_times.size())) {
+        throw invalid_argument("Invalid number of pipeline stages");
+    }
+
+    // Prepare data
+    vector<vector<long long>> block_time_mapping = {forward_times, backward_times};
+    vector<int> model(forward_times.size());
+    iota(model.begin(), model.end(), 0);
+
+    // Run the algorithm
+    vector<vector<int>> initial_partition = block_partition_algorithm(
+        model, num_stages, block_time_mapping);
+
+    vector<long long> prefix_sum;
+    vector<vector<long long>> dp_array;
+    calculate_prefix_sum_and_dp(model, num_stages, block_time_mapping,
+                                prefix_sum, dp_array);
+
+    PartitionResult best_result = find_best_partition(
+        block_time_mapping, num_stages, initial_partition,
+        prefix_sum, dp_array);
+
+    // Return the index of the first block in each partition
+    vector<int> result;
+    for (const auto& partition : best_result.partition) {
+        result.push_back(partition[0]);
+    }
+
+    return result;
+}
+
+} // namespace torchpipe
+
+// Python bindings
+PYBIND11_MODULE(autopipe, m) {
+    m.doc() = "AutoPipe pipeline partition generator";
+
+    m.def("pipeline", &torchpipe::PipelinePartitioner::merak_pipe,
+          "Generate pipeline partition",
+          py::arg("forward_times"),
+          py::arg("backward_times"),
+          py::arg("num_stages"));
+}
diff --git a/torchtitan/experiments/autopartition/infra/parallelize_deepseek_v3.py b/torchtitan/experiments/autopartition/infra/parallelize_deepseek_v3.py
new file mode 100644
index 0000000000..0793820ffd
--- /dev/null
+++ b/torchtitan/experiments/autopartition/infra/parallelize_deepseek_v3.py
@@ -0,0 +1,288 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
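+
+# NOTE: This module applies the non-pipeline parallelisms (TP/EP, FSDP/HSDP,
+# activation checkpointing, and torch.compile) to the DeepSeek-V3 model used by
+# the auto-partition experiment; it is adapted from llama4/infra/parallelize.py.
+# Pipeline splitting itself is handled separately (see `pipeline_llm` in
+# infra/pipeline_parallel.py, referenced from this experiment's train specs).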
+ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.llama3.infra.parallelize import apply_ddp +from torchtitan.models.llama4.infra.parallelize import ( + apply_compile, + apply_fsdp, + apply_moe_ep_tp, +) +from torchtitan.tools.logging import logger + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_to_all_single.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + + +# Adapted from llama4/infra/parallelize.py +def parallelize_deepseekv3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). 
+ """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if enable_float8_tensorwise_tp: + # TODO(jianiw): This branch needs to be tested and enabled + raise NotImplementedError( + "Currently, float8 tensorwise TP is not tested for deepseekv3" + ) + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + use_flex_attn=use_flex_attn, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled + else None + ), + etp_enabled=parallel_dims.etp_enabled, + ) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, + ) + + if model_compile_enabled: + apply_compile(model, job_config.compile) + + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if parallel_dims.ep_enabled + else None + ), + gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + dp_mesh = world_mesh + apply_ddp( + model, + dp_mesh, 
+            enable_compile=model_compile_enabled,
+        )
+
+    return model
+
+
+def apply_non_moe_tp(
+    model: nn.Module,
+    tp_mesh: DeviceMesh,
+    loss_parallel: bool,
+    enable_float8_tensorwise_tp: bool,
+    use_flex_attn: bool,
+):
+    """Apply tensor parallelism."""
+    # 1. Parallelize the embedding and shard its outputs (which are the first
+    # transformer block's inputs)
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Parallelize the final linear output layer
+    parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "norm": SequenceParallel(),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Shard(-1) if loss_parallel else Replicate(),
+                use_local_output=not loss_parallel,
+            ),
+        },
+    )
+
+    rowwise_parallel, colwise_parallel, prepare_module_input = (
+        RowwiseParallel,
+        ColwiseParallel,
+        PrepareModuleInput,
+    )
+
+    if use_flex_attn:
+        attention_kernel_plan = prepare_module_input(
+            input_layouts=(Shard(1), Shard(1), Shard(1)),
+            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
+            use_local_output=True,
+        )
+    else:
+        attention_kernel_plan = prepare_module_input(
+            input_layouts=(Shard(1), Shard(1), Shard(1)),
+            desired_input_layouts=(Shard(1), Shard(1), Shard(1)),
+            use_local_output=True,
+        )
+    # Apply tensor + sequence parallelism to every transformer block
+    # NOTE: At the cost of model code change, we can accelerate Sequence Parallel
+    # by folding (and unfolding) the batch dimension and the sequence dimension.
+    # Examples can be found at https://github.com/pytorch/torchtitan/pull/437
+    for transformer_block in model.layers.values():
+        layer_plan = {
+            "attention_norm": SequenceParallel(),
+            "attention": prepare_module_input(
+                input_layouts=(Shard(1), Replicate(), None),
+                desired_input_layouts=(Replicate(), Replicate(), None),
+            ),
+            # NOTE: use_local_output=False makes the output a DTensor instead of a plain Tensor,
+            # so that the intermediate result k is generated as a DTensor and its gradient is
+            # correctly handled by the autograd engine.
+ "attention.wkv_a": NoParallel(use_local_output=False), + "attention.wkv_b": colwise_parallel(use_local_output=False), + "attention.kv_norm": NoParallel(use_local_output=False), + # NOTE: use_local_output=True so that the inputs to FlexAttention are plain Tensors + "attention.inner_attention": attention_kernel_plan, + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + if transformer_block.attention.q_lora_rank == 0: + layer_plan.update( + { + "attention.wq": colwise_parallel( + use_local_output=False + ), # This is only used when q_lora_rank==0 + } + ) + else: + layer_plan.update( + { + "attention.wq_a": NoParallel(use_local_output=False), + "attention.wq_b": colwise_parallel(use_local_output=False), + "attention.q_norm": NoParallel(use_local_output=False), + } + ) + + if not transformer_block.moe_enabled: + layer_plan.update( + { + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + ) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) diff --git a/torchtitan/experiments/autopartition/infra/parallelize_llama.py b/torchtitan/experiments/autopartition/infra/parallelize_llama.py new file mode 100644 index 0000000000..86ac3a6dfe --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/parallelize_llama.py @@ -0,0 +1,330 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + +import torch +import torch.nn as nn +from torch.distributed._composable.replicate import replicate + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.config.job_config import Compile as CompileConfig +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.tools.logging import logger + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. 
+ torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + + +def parallelize_llama( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, world_mesh["tp"]) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, + ) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if model_compile_enabled: + apply_compile(model, job_config.compile) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + + apply_fsdp( + model, + world_mesh[tuple(dp_mesh_dim_names)], + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + world_mesh, + enable_compile=model_compile_enabled, + ) + + return model + + +def apply_tp( + model: 
nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": prepare_module_input( + input_layouts=(Shard(1), None, None), + desired_input_layouts=(Replicate(), None, None), + ), + "attention.wq": colwise_parallel(), + "attention.wk": colwise_parallel(), + "attention.wv": colwise_parallel(), + "attention.wo": rowwise_parallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + "feed_forward": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": colwise_parallel(), + "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), + "feed_forward.w3": colwise_parallel(), + } + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_compile(model: nn.Module, compile_config: CompileConfig): + """ + Apply torch.compile to each TransformerBlock, which makes compilation efficient due to + repeated structure. Alternatively one can compile the whole model (after applying DP). + """ + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = torch.compile( + transformer_block, backend=compile_config.backend, fullgraph=True + ) + model.layers.register_module(layer_id, transformer_block) + + logger.info("Compiling each TransformerBlock with torch.compile") + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", +): + """ + Apply data parallelism (via FSDP2) to the model. 
+ + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + for layer_id, transformer_block in model.layers.items(): + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + fully_shard(model, **fsdp_config) + + +def apply_ddp( + model: nn.Module, + dp_mesh: DeviceMesh, + enable_compile: bool, +): + if enable_compile: + torch._dynamo.config.optimize_ddp = "ddp_optimizer" + + replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) + + logger.info("Applied DDP to the model") diff --git a/torchtitan/experiments/autopartition/infra/pipeline_parallel.py b/torchtitan/experiments/autopartition/infra/pipeline_parallel.py new file mode 100644 index 0000000000..912e9a058a --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/pipeline_parallel.py @@ -0,0 +1,598 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
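For reference, the `reshard_after_forward_policy` handling in `apply_fsdp` above reduces to a small lookup. A minimal standalone sketch of that resolution logic (not the torchtitan code itself, just the same decision table):

```python
# Minimal sketch of how apply_fsdp above maps reshard_after_forward_policy to a
# boolean; mirrors the match statement, standalone for illustration only.
def resolve_reshard_after_forward(policy: str, pp_enabled: bool) -> bool:
    if policy == "always":
        return True
    if policy == "never":
        return False
    if policy == "default":
        # With pipeline parallelism, keep params gathered after forward to
        # avoid per-microbatch all-gathers.
        return not pp_enabled
    raise ValueError(f"Invalid reshard_after_forward_policy: {policy}")


# e.g. under PP, the "default" policy disables resharding after forward
assert resolve_reshard_after_forward("default", pp_enabled=True) is False
```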
+import copy +import math +import os +from typing import Callable + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + _PipelineScheduleRuntime, + get_schedule_class, + PipelineScheduleMulti, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, +) +from torchtitan.components.loss import LossFunction, rescale_accumulated_loss +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.experiments.autopartition.infra.autopipe import pipeline +from torchtitan.experiments.autopartition.infra.profiler import FlopsProfiler +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction +from torchtitan.tools.logging import logger + +__all__ = [ + "pipeline_llm", + "build_pipeline_schedule", + "generate_llm_fqn_per_model_part", + "pipeline_module_split", +] + + +def autopipe_partition(model, num_stages, job_config): + """Partition layers based on automatic pipeline profiling. + + This method profiles each layer's computational cost (FLOPS) and + distributes layers to balance computation across stages. + + Args: + input_to_shard_dict: Dictionary containing input sharding information. + + Returns: + List of integers representing the number of layers assigned to each stage. + """ + + # Prepare input for profiling + tokenizer = build_hf_tokenizer(job_config) + + # build dataloader + dataloader = build_text_dataloader( + dp_world_size=1, + dp_rank=0, + tokenizer=tokenizer, + job_config=job_config, + ) + iterator = iter(dataloader) + inputs = next(iterator)[0].values() + + # Profile each layer's FLOPS + mflops_list = [] + for _, layer in enumerate(model): + prof = FlopsProfiler(layer) + prof.start_profile() + nparams_dense = 0 + for p in layer.parameters(): + nparams_dense += p.numel() + if isinstance(inputs, torch.Tensor): + inputs = layer(inputs) + else: + inputs = layer(*inputs) + mflops = prof.get_total_flops() / 10**6 # Convert to million FLOPS + mflops_list.append(round(mflops)) + prof.end_profile() + + logger.info(f"Autopipe partitioning with mflops: {mflops_list}") + + parts = pipeline( + mflops_list, + [i * 3 for i in mflops_list], # Assume backward is 3x forward + num_stages, + ) + parts.append(len(model)) # Add the total number of layers + return parts + + +def _build_module_for_profile(model, flatten_module_names): + # txd: merge autopipe + module_names_for_profile = [[item] for item in flatten_module_names] + + def _build_sequential_module( + module_names: list[str], + ) -> tuple[PipelineStage, nn.Module]: + + # Create a set of modules to keep for faster lookup + # modules_to_keep = set(module_names) + module_seq = nn.Sequential() + for mtk in module_names: + whole_model = copy.deepcopy(model) + modules_to_keep = set(mtk) + for module_name, module_value in whole_model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in 
layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(whole_model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(whole_model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(whole_model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(whole_model, module_name, None) + module_seq.append(copy.deepcopy(whole_model)) + return module_seq + + seq_module = _build_sequential_module(module_names_for_profile) + + return seq_module + + +def pipeline_llm( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: torch.device, + model_args: BaseModelArgs, + parallelize_fn: ParallelizeFunction, + loss_fn: LossFunction, +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + pp_mesh = parallel_dims.world_mesh["pp"] + + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage + if hasattr(model_args, "n_layers"): + num_layers = model_args.n_layers + else: + raise ValueError("Model does not have n_layers attribute.") + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers + output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers + + # Calculate number of virtual stages + if layers_per_stage is not None: + + # Calculate number of virtual stages needed (using ceiling division) + # This allows for unequal distribution where stages can differ by at most 1 layer + num_virtual_stages = math.ceil( + (num_layers + input_weight + output_weight) / layers_per_stage + ) + + # Validation: check stages per rank based on schedule type + model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}" + stage_distribution_info = ( + f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks" + ) + + if num_virtual_stages % parallel_dims.pp != 0: + raise ValueError( + f"Number of virtual stages ({num_virtual_stages}) must be divisible by " + f"pipeline parallel size ({parallel_dims.pp}). " + f"{model_config_info}. " + f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages " + f"divisible by {parallel_dims.pp}." + ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." 
+ ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." + ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + + # if job_config.custom_config.auto_partition: + flatten_module_names = [ + item for sublist in module_names_per_stage for item in sublist + ] + seq_modules = _build_module_for_profile(model, flatten_module_names) + parts = autopipe_partition(seq_modules, parallel_dims.pp, job_config) + module_names_per_stage = [ + flatten_module_names[parts[i] : parts[i + 1]] for i in range(parallel_dims.pp) + ] + + for i, stage_ms in enumerate(module_names_per_stage): + logger.debug(f"Stage {i}: {stage_ms}") + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, + ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for i, m in enumerate(model_parts): + # apply SPMD-style PT-D techniques + m = parallelize_fn(m, parallel_dims, job_config) + model_parts[i] = m + # NOTE: this is to update the model in the stage + # in case the model is modified e.g. by torch.compile + stages[i].submod = m + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, model_parts, has_first_stage, has_last_stage + + +def build_pipeline_schedule( + job_config: JobConfig, stages: list[PipelineStage], loss_fn: Callable +) -> _PipelineSchedule: + """Builds a pipeline schedule for the given job configuration and stages. + + Args: + job_config (JobConfig): The job configuration. + stages (list[PipelineStage]): The stages to be scheduled. + loss_fn (Callable): The loss function. + + Returns: + _PipelineSchedule: The pipeline schedule for the given stages. + """ + pp_schedule_csv = job_config.parallelism.pipeline_parallel_schedule_csv + + # Validate that pp_schedule_csv is a valid path + if pp_schedule_csv: + if not os.path.isfile(pp_schedule_csv): + raise FileNotFoundError( + f"The specified path {pp_schedule_csv} does not exist or is not a file." 
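To make the use of the `parts` boundaries returned by `autopipe_partition` concrete: assuming, as the slicing above implies, that `pipeline` returns `pp` ascending cut indices starting at 0 (with `len(model)` appended afterwards), the per-stage module lists are plain slices of the flattened FQN list. The numbers below are invented for illustration only:

```python
# Hypothetical illustration of turning autopipe cut points into
# module_names_per_stage; only the slicing pattern matches the code above.
flatten_module_names = (
    ["tok_embeddings"]
    + [f"layers.{i}" for i in range(6)]
    + ["norm", "output"]
)  # 9 profiled modules
pp_degree = 3
parts = [0, 2, 5, 9]  # made-up cut indices; last entry is len(flatten_module_names)

module_names_per_stage = [
    flatten_module_names[parts[i] : parts[i + 1]] for i in range(pp_degree)
]
# [['tok_embeddings', 'layers.0'],
#  ['layers.1', 'layers.2', 'layers.3'],
#  ['layers.4', 'layers.5', 'norm', 'output']]
```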
+ ) + schedule_class = _PipelineScheduleRuntime + else: + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + + looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) + microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size + batch_size = job_config.training.local_batch_size + # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training + if batch_size % microbatch_size != 0: + raise ValueError( + f"Batch size {job_config.training.local_batch_size} must be divisible by microbatch_size {microbatch_size}. " + "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size." + ) + n_microbatches = batch_size // microbatch_size + # We expect that the number of local stages (`len(stages)`) is the same across all ranks + num_total_stages = job_config.parallelism.pipeline_parallel_degree * len(stages) + if n_microbatches < num_total_stages: + logger.warning( + f"Number of microbatches ({n_microbatches}) is less than the total number " + f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." + ) + + schedule = schedule_class( + stages if looped_schedule else stages[0], + n_microbatches=n_microbatches, + loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches), + scale_grads=False, + ) + logger.info( + f"Using pipeline schedule {job_config.parallelism.pipeline_parallel_schedule} " + f"with {n_microbatches} microbatches and {num_total_stages} stages." + ) + + if pp_schedule_csv: + assert schedule_class in [ + PipelineScheduleSingle, + PipelineScheduleMulti, + _PipelineScheduleRuntime, + ], ( + "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), " + "and _PipelineScheduleRuntime support csv schedules" + ) + schedule._load_csv(pp_schedule_csv) + + return schedule + + +def generate_llm_fqn_per_model_part( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names model part, focused on LLMs models. + + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (tok_embeddings) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + + Returns: + List of lists containing module names for each model part + + Example: + generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Feasibility check: Ensure at least 1 layer in each PP stage + if layers_per_stage == 0: + raise ValueError( + f"Configuration would result in empty stages. 
" + f"With {num_stages} stages and {num_effective_layers} effective layers " + f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), " + f"each stage would get {layers_per_stage} layers on average. " + f"Reduce num_stages or increase num_layers/weights." + ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. 
Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_degree = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with None + setattr(model, module_name, None) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = ( + "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + ) + + def _get_stage_indices() -> tuple[int]: + """ + Compute the stage ids for the stages that will run on this pp rank + for either a looped or V style schedule + """ + assert ( + num_stages % pp_degree == 0 + ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}" + stages_per_rank = num_stages // pp_degree + if style == "loop": + return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank)) + elif style == "v": + assert ( + stages_per_rank == 2 + ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" + stage_v_pairs = list( + zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) + ) + return stage_v_pairs[pp_rank] + + for stage_idx in _get_stage_indices(): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models diff --git 
a/torchtitan/experiments/autopartition/infra/profiler.py b/torchtitan/experiments/autopartition/infra/profiler.py new file mode 100644 index 0000000000..f3a5ad6796 --- /dev/null +++ b/torchtitan/experiments/autopartition/infra/profiler.py @@ -0,0 +1,1371 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + + +import os +import sys +import time +from collections import OrderedDict +from functools import partial +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +Tensor = torch.Tensor + +module_flop_count = [] +module_mac_count = [] +old_functions = {} + +func_flops = {} + + +# Adapted from https://github.com/microsoft/DeepSpeed/blob/5218177922a4be5c14cf0db893dbfcb139179ba5/deepspeed/profiling/flops_profiler/profiler.py +class FlopsProfiler(object): + """Measures the latency, number of estimated floating-point operations and parameters of each module + in a PyTorch model. + + The flops-profiler profiles the forward pass of a PyTorch model and prints the model graph with the + measured profile attached to each module. It shows how latency, flops and parameters are spent in + the model and which modules or layers could be the bottleneck. It also outputs the names of the top + k modules in terms of aggregated latency, flops, and parameters at depth l with k and l specified + by the user. The output profile is computed for each batch of input. + + The DeepSpeed flops profiler can be used with the DeepSpeed runtime or as a standalone package. + When using DeepSpeed for model training, the flops profiler can be configured in the deepspeed_config + file and no user code change is required. + + If using the profiler as a standalone package, one imports the flops_profiler package and use the APIs. + + Here is an example for usage in a typical training workflow: + + .. code-block:: python + + model = Model() + prof = FlopsProfiler(model) + + for step, batch in enumerate(data_loader): + if step == profile_step: + prof.start_profile() + + loss = model(batch) + + if step == profile_step: + flops = prof.get_total_flops(as_string=True) + params = prof.get_total_params(as_string=True) + prof.print_model_profile(profile_step=profile_step) + prof.end_profile() + + loss.backward() + optimizer.step() + + To profile a trained model in inference, use the `get_model_profile` API. + + Args: + object (torch.nn.Module): The PyTorch model to profile. + """ + + def __init__(self, model, ds_engine=None): + self.model = model + self.ds_engine = ds_engine + self.started = False + self.func_patched = False + + def start_profile(self, ignore_list=None): + """Starts profiling. + + Extra attributes are added recursively to all the modules and the profiled torch.nn.functionals + are monkey patched. + + Args: + ignore_list (list, optional): the list of modules to ignore while profiling. Defaults to None. 
+ """ + self.reset_profile() + _patch_functionals() + _patch_tensor_methods() + + def register_module_hooks(module, ignore_list): + if ignore_list and type(module) in ignore_list: + return + + # if computing the flops of a module directly + if type(module) in MODULE_HOOK_MAPPING: + module.__flops_handle__ = module.register_forward_hook( + MODULE_HOOK_MAPPING[type(module)] + ) + return + + # if computing the flops of the functionals in a module + def pre_hook(module, input): + module_flop_count.append([]) + module_mac_count.append([]) + + module.__pre_hook_handle__ = module.register_forward_pre_hook(pre_hook) + + def post_hook(module, input, output): + if module_flop_count: + module.__flops__ += sum([elem[1] for elem in module_flop_count[-1]]) + module_flop_count.pop() + module.__macs__ += sum([elem[1] for elem in module_mac_count[-1]]) + module_mac_count.pop() + + module.__post_hook_handle__ = module.register_forward_hook(post_hook) + + def start_time_hook(module, input): + torch.cuda.synchronize() + module.__start_time__ = time.time() + + module.__start_time_hook_handle__ = module.register_forward_pre_hook( + start_time_hook + ) + + def end_time_hook(module, input, output): + torch.cuda.synchronize() + module.__duration__ += time.time() - module.__start_time__ + + module.__end_time_hook_handle__ = module.register_forward_hook( + end_time_hook + ) + + self.model.apply(partial(register_module_hooks, ignore_list=ignore_list)) + self.started = True + self.func_patched = True + + def stop_profile(self): + """Stop profiling. + + All torch.nn.functionals are restored to their originals. + """ + if self.started and self.func_patched: + _reload_functionals() + _reload_tensor_methods() + global old_functions + old_functions = {} + self.func_patched = False + + def remove_profile_attrs(module): + if hasattr(module, "__pre_hook_handle__"): + module.__pre_hook_handle__.remove() + del module.__pre_hook_handle__ + if hasattr(module, "__post_hook_handle__"): + module.__post_hook_handle__.remove() + del module.__post_hook_handle__ + if hasattr(module, "__flops_handle__"): + module.__flops_handle__.remove() + del module.__flops_handle__ + if hasattr(module, "__start_time_hook_handle__"): + module.__start_time_hook_handle__.remove() + del module.__start_time_hook_handle__ + if hasattr(module, "__end_time_hook_handle__"): + module.__end_time_hook_handle__.remove() + del module.__end_time_hook_handle__ + + self.model.apply(remove_profile_attrs) + + def reset_profile(self): + """Resets the profiling. + + Adds or resets the extra attributes. + """ + + def add_or_reset_attrs(module): + module.__flops__ = 0 + module.__macs__ = 0 + module.__params__ = sum(p.numel() for p in module.parameters()) + module.__start_time__ = 0 + module.__duration__ = 0 + + self.model.apply(add_or_reset_attrs) + + def end_profile(self): + """Ends profiling. + + The added attributes and handles are removed recursively on all the modules. + """ + if not self.started: + return + self.stop_profile() + self.started = False + + def remove_profile_attrs(module): + if hasattr(module, "__flops__"): + del module.__flops__ + if hasattr(module, "__macs__"): + del module.__macs__ + if hasattr(module, "__params__"): + del module.__params__ + if hasattr(module, "__start_time__"): + del module.__start_time__ + if hasattr(module, "__duration__"): + del module.__duration__ + + self.model.apply(remove_profile_attrs) + + def get_total_flops(self, as_string=False): + """Returns the total flops of the model. 
+ + Args: + as_string (bool, optional): whether to output the flops as string. Defaults to False. + + Returns: + The number of multiply-accumulate operations of the model forward pass. + """ + total_flops = get_module_flops(self.model) + return num_to_string(total_flops) if as_string else total_flops + + def get_total_macs(self, as_string=False): + """Returns the total MACs of the model. + + Args: + as_string (bool, optional): whether to output the flops as string. Defaults to False. + + Returns: + The number of multiply-accumulate operations of the model forward pass. + """ + total_macs = get_module_macs(self.model) + return macs_to_string(total_macs) if as_string else total_macs + + def get_total_duration(self, as_string=False): + """Returns the total duration of the model forward pass. + + Args: + as_string (bool, optional): whether to output the duration as string. Defaults to False. + + Returns: + The latency of the model forward pass. + """ + total_duration = get_module_duration(self.model) + return duration_to_string(total_duration) if as_string else total_duration + + def get_total_params(self, as_string=False): + """Returns the total parameters of the model. + + Args: + as_string (bool, optional): whether to output the parameters as string. Defaults to False. + + Returns: + The number of parameters in the model. + """ + return ( + params_to_string(self.model.__params__) + if as_string + else self.model.__params__ + ) + + def print_model_profile( + self, + profile_step=1, + module_depth=-1, + top_modules=1, + detailed=True, + output_file=None, + ): + """Prints the model graph with the measured profile attached to each module. + + Args: + profile_step (int, optional): + The global training step at which to profile. + Note that warm up steps are needed for accurate time measurement. + module_depth (int, optional): + The depth of the model to which to print the aggregated module information. + When set to -1, it prints information from the top to the innermost modules (the maximum depth). + top_modules (int, optional): Limits the aggregated profile output to the number of top modules specified. + detailed (bool, optional): Whether to print the detailed model profile. + output_file (str, optional): Path to the output file. If None, the profiler prints to stdout. 
+ """ + if not self.started: + return + + original_stdout = None + f = None + if output_file and output_file != "": + dir_path = os.path.dirname(output_file) + if not os.path.exists(dir_path): + os.makedirs(dir_path, exist_ok=True) + original_stdout = sys.stdout + f = open(output_file, "w") + sys.stdout = f + + total_flops = self.get_total_flops() + total_macs = self.get_total_macs() + total_duration = self.get_total_duration() + total_params = self.get_total_params() + + self.flops = total_flops + self.macs = total_macs + self.params = total_params + + print( + "\n-------------------------- DeepSpeed Flops Profiler --------------------------" + ) + print(f"Profile Summary at step {profile_step}:") + print( + "Notations:\ndata parallel size (dp_size), model parallel size(mp_size)," + "\nnumber of parameters (params), number of multiply-accumulate operations(MACs)," + "\nnumber of floating-point operations (flops), floating-point operations per second (FLOPS)," + "\nfwd latency (forward propagation latency), bwd latency (backward propagation latency)," + "\nstep (weights update latency), iter latency (sum of fwd, bwd and step latency)\n" + ) + if self.ds_engine: + print("{:<60} {:<8}".format("world size: ", self.ds_engine.world_size)) + print( + "{:<60} {:<8}".format( + "data parallel size: ", self.ds_engine.dp_world_size + ) + ) + print( + "{:<60} {:<8}".format( + "model parallel size: ", self.ds_engine.mp_world_size + ) + ) + print( + "{:<60} {:<8}".format( + "batch size per GPU: ", + self.ds_engine.train_micro_batch_size_per_gpu(), + ) + ) + + print( + "{:<60} {:<8}".format("params per gpu: ", params_to_string(total_params)) + ) + print( + "{:<60} {:<8}".format( + "params of model = params per GPU * mp_size: ", + params_to_string( + total_params + * ((self.ds_engine.mp_world_size) if self.ds_engine else 1) + ), + ) + ) + + print("{:<60} {:<8}".format("fwd MACs per GPU: ", macs_to_string(total_macs))) + + print("{:<60} {:<8}".format("fwd flops per GPU: ", num_to_string(total_flops))) + + print( + "{:<60} {:<8}".format( + "fwd flops of model = fwd flops per GPU * mp_size: ", + num_to_string( + total_flops + * ((self.ds_engine.mp_world_size) if self.ds_engine else 1) + ), + ) + ) + + fwd_latency = self.get_total_duration() + if self.ds_engine and self.ds_engine.wall_clock_breakdown(): + fwd_latency = self.ds_engine.timers("forward").elapsed(False) + print("{:<60} {:<8}".format("fwd latency: ", duration_to_string(fwd_latency))) + print( + "{:<60} {:<8}".format( + "fwd FLOPS per GPU = fwd flops per GPU / fwd latency: ", + flops_to_string(total_flops / fwd_latency), + ) + ) + + global func_flops + print("function flops", func_flops) + func_flops = {} + + if self.ds_engine and self.ds_engine.wall_clock_breakdown(): + bwd_latency = self.ds_engine.timers("backward").elapsed(False) + step_latency = self.ds_engine.timers("step").elapsed(False) + print( + "{:<60} {:<8}".format("bwd latency: ", duration_to_string(bwd_latency)) + ) + print( + "{:<60} {:<8}".format( + "bwd FLOPS per GPU = 2 * fwd flops per GPU / bwd latency: ", + flops_to_string(2 * total_flops / bwd_latency), + ) + ) + print( + "{:<60} {:<8}".format( + "fwd+bwd FLOPS per GPU = 3 * fwd flops per GPU / (fwd+bwd latency): ", + flops_to_string(3 * total_flops / (fwd_latency + bwd_latency)), + ) + ) + + print( + "{:<60} {:<8}".format( + "step latency: ", duration_to_string(step_latency) + ) + ) + + iter_latency = fwd_latency + bwd_latency + step_latency + print( + "{:<60} {:<8}".format( + "iter latency: ", duration_to_string(iter_latency) 
+ ) + ) + print( + "{:<60} {:<8}".format( + "FLOPS per GPU = 3 * fwd flops per GPU / iter latency: ", + flops_to_string(3 * total_flops / iter_latency), + ) + ) + + samples_per_iter = ( + self.ds_engine.train_micro_batch_size_per_gpu() + * self.ds_engine.world_size + ) + print( + "{:<60} {:<8.2f}".format( + "samples/second: ", samples_per_iter / iter_latency + ) + ) + + def flops_repr(module): + params = module.__params__ + flops = get_module_flops(module) + macs = get_module_macs(module) + items = [ + params_to_string(params), + "{:.2%} Params".format(params / total_params), + macs_to_string(macs), + "{:.2%} MACs".format(0.0 if total_macs == 0 else macs / total_macs), + flops_to_string(flops).lower(), + ] + duration = get_module_duration(module) + + items.append(duration_to_string(duration)) + items.append( + "{:.2%} latency".format( + 0.0 if total_duration == 0 else duration / total_duration + ) + ) + items.append(flops_to_string(0.0 if duration == 0 else flops / duration)) + items.append(module.original_extra_repr()) + return ", ".join(items) + + def add_extra_repr(module): + flops_extra_repr = flops_repr.__get__(module) + if module.extra_repr != flops_extra_repr: + module.original_extra_repr = module.extra_repr + module.extra_repr = flops_extra_repr + assert module.extra_repr != module.original_extra_repr + + def del_extra_repr(module): + if hasattr(module, "original_extra_repr"): + module.extra_repr = module.original_extra_repr + del module.original_extra_repr + + self.model.apply(add_extra_repr) + + print( + "\n----------------------------- Aggregated Profile per GPU -----------------------------" + ) + self.print_model_aggregated_profile( + module_depth=module_depth, top_modules=top_modules + ) + + if detailed: + print( + "\n------------------------------ Detailed Profile per GPU ------------------------------" + ) + print( + "Each module profile is listed after its name in the following order: " + "\nparams, percentage of total params, MACs, percentage of total MACs, fwd latency," + "percentage of total fwd latency, fwd FLOPS" + ) + print( + "\nNote: 1. A module can have torch.nn.module or torch.nn.functional to compute logits" + "(e.g. CrossEntropyLoss). They are not counted as submodules, thus not to be printed out." + "However they make up the difference between " + "a parent's MACs (or latency) and the sum of its submodules'." + "\n2. Number of floating-point operations is a theoretical estimation, thus FLOPS computed" + "using that could be larger than the maximum system throughput." + "\n3. The fwd latency listed in the top module's profile is directly captured at the module forward" + "function in PyTorch, thus it's less than the fwd latency shown above which is captured in DeepSpeed.\n" + ) + print(self.model) + + self.model.apply(del_extra_repr) + + print( + "------------------------------------------------------------------------------" + ) + + if output_file: + sys.stdout = original_stdout + f.close() + + def print_model_aggregated_profile(self, module_depth=-1, top_modules=1): + """Prints the names of the top top_modules modules in terms of aggregated time, flops, and parameters + at depth module_depth. + + Args: + module_depth (int, optional): the depth of the modules to show. Defaults to -1 (the innermost modules). + top_modules (int, optional): the number of top modules to show. Defaults to 1. 
+ """ + info = {} + if not hasattr(self.model, "__flops__"): + print( + "no __flops__ attribute in the model, call this function after start_profile and before end_profile" + ) + return + + def walk_module(module, curr_depth, info): + if curr_depth not in info: + info[curr_depth] = {} + if module.__class__.__name__ not in info[curr_depth]: + info[curr_depth][module.__class__.__name__] = [ + 0, + 0, + 0, + ] # macs, params, time + info[curr_depth][module.__class__.__name__][0] += get_module_macs(module) + info[curr_depth][module.__class__.__name__][1] += module.__params__ + info[curr_depth][module.__class__.__name__][2] += get_module_duration( + module + ) + has_children = len(module._modules.items()) != 0 + if has_children: + for child in module.children(): + walk_module(child, curr_depth + 1, info) + + walk_module(self.model, 0, info) + + depth = module_depth + if module_depth == -1: + depth = len(info) - 1 + + print( + f"Top {top_modules} modules in terms of params, MACs or fwd latency at different model depths:" + ) + + for d in range(depth): + num_items = min(top_modules, len(info[d])) + + sort_macs = { + k: macs_to_string(v[0]) + for k, v in sorted( + info[d].items(), key=lambda item: item[1][0], reverse=True + )[:num_items] + } + sort_params = { + k: params_to_string(v[1]) + for k, v in sorted( + info[d].items(), key=lambda item: item[1][1], reverse=True + )[:num_items] + } + sort_time = { + k: duration_to_string(v[2]) + for k, v in sorted( + info[d].items(), key=lambda item: item[1][2], reverse=True + )[:num_items] + } + + print(f"depth {d}:") + print(f" params - {sort_params}") + print(f" MACs - {sort_macs}") + print(f" fwd latency - {sort_time}") + + +def _prod(dims): + p = 1 + for v in dims: + p *= v + return p + + +def _linear_flops_compute(input, weight, bias=None): + out_features = weight.shape[0] + macs = torch.numel(input) * out_features + return 2 * macs, macs + + +def _relu_flops_compute(input, inplace=False): + return torch.numel(input), 0 + + +def _prelu_flops_compute(input: Tensor, weight: Tensor): + return torch.numel(input), 0 + + +def _elu_flops_compute(input: Tensor, alpha: float = 1.0, inplace: bool = False): + return torch.numel(input), 0 + + +def _leaky_relu_flops_compute( + input: Tensor, negative_slope: float = 0.01, inplace: bool = False +): + return torch.numel(input), 0 + + +def _relu6_flops_compute(input: Tensor, inplace: bool = False): + return torch.numel(input), 0 + + +def _silu_flops_compute(input: Tensor, inplace: bool = False): + return torch.numel(input), 0 + + +def _gelu_flops_compute(input, approximate=None): + return torch.numel(input), 0 + + +def _pool_flops_compute( + input, + kernel_size, + stride=None, + padding=0, + ceil_mode=False, + count_include_pad=True, + divisor_override=None, +): + return torch.numel(input), 0 + + +def _conv_flops_compute( + input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1 +): + assert weight.shape[1] * groups == input.shape[1] + + batch_size = input.shape[0] + in_channels = input.shape[1] + out_channels = weight.shape[0] + kernel_dims = list(weight.shape[2:]) + input_dims = list(input.shape[2:]) + + length = len(input_dims) + + paddings = padding if type(padding) is tuple else (padding,) * length + strides = stride if type(stride) is tuple else (stride,) * length + dilations = dilation if type(dilation) is tuple else (dilation,) * length + + output_dims = [] + for idx, input_dim in enumerate(input_dims): + output_dim = ( + input_dim + + 2 * paddings[idx] + - (dilations[idx] * (kernel_dims[idx] - 1) 
+ 1) + ) // strides[idx] + 1 + output_dims.append(output_dim) + + filters_per_channel = out_channels // groups + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel + active_elements_count = batch_size * int(_prod(output_dims)) + overall_conv_macs = conv_per_position_macs * active_elements_count + overall_conv_flops = 2 * overall_conv_macs + + bias_flops = 0 + if bias is not None: + bias_flops = out_channels * active_elements_count + + return int(overall_conv_flops + bias_flops), int(overall_conv_macs) + + +def _conv_trans_flops_compute( + input, + weight, + bias=None, + stride=1, + padding=0, + output_padding=0, + groups=1, + dilation=1, +): + batch_size = input.shape[0] + in_channels = input.shape[1] + out_channels = weight.shape[0] + kernel_dims = list(weight.shape[2:]) + input_dims = list(input.shape[2:]) + + length = len(input_dims) + + paddings = padding if type(padding) is tuple else (padding,) * length + strides = stride if type(stride) is tuple else (stride,) * length + dilations = dilation if type(dilation) is tuple else (dilation,) * length + + output_dims = [] + for idx, input_dim in enumerate(input_dims): + + output_dim = ( + input_dim + + 2 * paddings[idx] + - (dilations[idx] * (kernel_dims[idx] - 1) + 1) + ) // strides[idx] + 1 + output_dims.append(output_dim) + + paddings = padding if type(padding) is tuple else (padding, padding) + strides = stride if type(stride) is tuple else (stride, stride) + dilations = dilation if type(dilation) is tuple else (dilation, dilation) + + filters_per_channel = out_channels // groups + conv_per_position_macs = int(_prod(kernel_dims)) * in_channels * filters_per_channel + active_elements_count = batch_size * int(_prod(input_dims)) + overall_conv_macs = conv_per_position_macs * active_elements_count + overall_conv_flops = 2 * overall_conv_macs + + bias_flops = 0 + if bias is not None: + bias_flops = out_channels * batch_size * int(_prod(output_dims)) + + return int(overall_conv_flops + bias_flops), int(overall_conv_macs) + + +def _batch_norm_flops_compute( + input, + running_mean, + running_var, + weight=None, + bias=None, + training=False, + momentum=0.1, + eps=1e-05, +): + has_affine = weight is not None + if training: + # estimation + return torch.numel(input) * (5 if has_affine else 4), 0 + flops = torch.numel(input) * (2 if has_affine else 1) + return flops, 0 + + +def _layer_norm_flops_compute( + input: Tensor, + normalized_shape: List[int], + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +): + has_affine = weight is not None + # estimation + return torch.numel(input) * (5 if has_affine else 4), 0 + + +def _group_norm_flops_compute( + input: Tensor, + num_groups: int, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + eps: float = 1e-5, +): + has_affine = weight is not None + # estimation + return torch.numel(input) * (5 if has_affine else 4), 0 + + +def _instance_norm_flops_compute( + input: Tensor, + running_mean: Optional[Tensor] = None, + running_var: Optional[Tensor] = None, + weight: Optional[Tensor] = None, + bias: Optional[Tensor] = None, + use_input_stats: bool = True, + momentum: float = 0.1, + eps: float = 1e-5, +): + has_affine = weight is not None + # estimation + return torch.numel(input) * (5 if has_affine else 4), 0 + + +def _upsample_flops_compute( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + if size is not None: + if isinstance(size, tuple): + return int(_prod(size)), 0 + else: + return 
int(size), 0 + assert scale_factor is not None, "either size or scale_factor should be defined" + flops = torch.numel(input) + if isinstance(scale_factor, tuple) and len(scale_factor) == len(input): + flops * int(_prod(scale_factor)) + else: + flops * scale_factor ** len(input) + return flops, 0 + + +def _softmax_flops_compute(input, dim=None, _stacklevel=3, dtype=None): + return torch.numel(input), 0 + + +def _embedding_flops_compute( + input, + weight, + padding_idx=None, + max_norm=None, + norm_type=2.0, + scale_grad_by_freq=False, + sparse=False, +): + return 0, 0 + + +def _dropout_flops_compute(input, p=0.5, training=True, inplace=False): + return 0, 0 + + +def _matmul_flops_compute(input, other, *, out=None): + """ + Count flops for the matmul operation. + """ + macs = _prod(input.shape) * other.shape[-1] + # if torch.distributed.get_rank()==0: print(2*macs) + + return 2 * macs, macs + + +def _addmm_flops_compute(input, mat1, mat2, *, beta=1, alpha=1, out=None): + """ + Count flops for the addmm operation. + """ + macs = _prod(mat1.shape) * mat2.shape[-1] + return 2 * macs + _prod(input.shape), macs + + +def _baddbmm_flops_compute(input, mat1, mat2, *, beta=1, alpha=1, out=None): + """ + Count flops for the baddbmm operation. + """ + macs = _prod(mat1.shape) * mat2.shape[-1] + return 2 * macs + _prod(input.shape), macs + + +def _einsum_flops_compute(equation, *operands): + """ + Count flops for the einsum operation. + """ + equation = equation.replace(" ", "") + input_shapes = [o.shape for o in operands] + + # Re-map equation so that same equation with different alphabet + # representations will look the same. + letter_order = OrderedDict((k, 0) for k in equation if k.isalpha()).keys() + mapping = {ord(x): 97 + i for i, x in enumerate(letter_order)} + equation = equation.translate(mapping) + + np_arrs = [np.zeros(s) for s in input_shapes] + optim = np.einsum_path(equation, *np_arrs, optimize="optimal")[1] + for line in optim.split("\n"): + if "optimized flop" in line.lower(): + flop = int(float(line.split(":")[-1])) + return flop, 0 + raise NotImplementedError("Unsupported einsum operation.") + + +def _tensor_addmm_flops_compute(self, mat1, mat2, *, beta=1, alpha=1, out=None): + """ + Count flops for the tensor addmm operation. 
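As a sanity check on the matmul counting rule above (`macs = prod(input.shape) * other.shape[-1]`, `flops = 2 * macs`), a small worked example:

```python
# Worked example of the estimate produced by _matmul_flops_compute.
import torch

a = torch.randn(2, 8, 16)   # batched left operand
b = torch.randn(16, 32)     # right operand

macs = a.numel() * b.shape[-1]   # (2 * 8 * 16) * 32 = 8192
flops = 2 * macs                 # 16384
print(macs, flops)
```

The linear rule earlier in this file uses the same arithmetic, with `other.shape[-1]` replaced by the weight's `out_features`.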
+ """ + macs = _prod(mat1.shape) * mat2.shape[-1] + return 2 * macs + _prod(self.shape), macs + + +def _mul_flops_compute(input, other, *, out=None): + return _elementwise_flops_compute(input, other) + + +def _add_flops_compute(input, other, *, alpha=1, out=None): + return _elementwise_flops_compute(input, other) + + +def _elementwise_flops_compute(input, other): + if not torch.is_tensor(input): + if torch.is_tensor(other): + return _prod(other.shape), 0 + else: + return 1, 0 + elif not torch.is_tensor(other): + return _prod(input.shape), 0 + else: + dim_input = len(input.shape) + dim_other = len(other.shape) + max_dim = max(dim_input, dim_other) + + final_shape = [] + for i in range(max_dim): + in_i = input.shape[i] if i < dim_input else 1 + ot_i = other.shape[i] if i < dim_other else 1 + if in_i > ot_i: + final_shape.append(in_i) + else: + final_shape.append(ot_i) + flops = _prod(final_shape) + return flops, 0 + + +def wrapFunc(func, funcFlopCompute): + oldFunc = func + name = func.__str__ + func_name = func.__name__ + # print(name, oldFunc) + old_functions[name] = oldFunc + + def newFunc(*args, **kwds): + flops, macs = funcFlopCompute(*args, **kwds) + global func_flops # noqa: F824 # type: ignore + if module_flop_count: + if func_name not in func_flops: + func_flops[func_name] = flops + else: + func_flops[func_name] += flops + module_flop_count[-1].append((name, flops)) + if module_mac_count and macs: + module_mac_count[-1].append((name, macs)) + return oldFunc(*args, **kwds) + + newFunc.__str__ = func.__str__ + + return newFunc + + +def _patch_functionals(): + # FC + F.linear = wrapFunc(F.linear, _linear_flops_compute) + + # convolutions + F.conv1d = wrapFunc(F.conv1d, _conv_flops_compute) + F.conv2d = wrapFunc(F.conv2d, _conv_flops_compute) + F.conv3d = wrapFunc(F.conv3d, _conv_flops_compute) + + # conv transposed + F.conv_transpose1d = wrapFunc(F.conv_transpose1d, _conv_trans_flops_compute) + F.conv_transpose2d = wrapFunc(F.conv_transpose2d, _conv_trans_flops_compute) + F.conv_transpose3d = wrapFunc(F.conv_transpose3d, _conv_trans_flops_compute) + + # activations + F.relu = wrapFunc(F.relu, _relu_flops_compute) + F.prelu = wrapFunc(F.prelu, _prelu_flops_compute) + F.elu = wrapFunc(F.elu, _elu_flops_compute) + F.leaky_relu = wrapFunc(F.leaky_relu, _leaky_relu_flops_compute) + F.relu6 = wrapFunc(F.relu6, _relu6_flops_compute) + if hasattr(F, "silu"): + F.silu = wrapFunc(F.silu, _silu_flops_compute) + F.gelu = wrapFunc(F.gelu, _gelu_flops_compute) + + # Normalizations + F.batch_norm = wrapFunc(F.batch_norm, _batch_norm_flops_compute) + F.layer_norm = wrapFunc(F.layer_norm, _layer_norm_flops_compute) + F.instance_norm = wrapFunc(F.instance_norm, _instance_norm_flops_compute) + F.group_norm = wrapFunc(F.group_norm, _group_norm_flops_compute) + + # poolings + F.avg_pool1d = wrapFunc(F.avg_pool1d, _pool_flops_compute) + F.avg_pool2d = wrapFunc(F.avg_pool2d, _pool_flops_compute) + F.avg_pool3d = wrapFunc(F.avg_pool3d, _pool_flops_compute) + F.max_pool1d = wrapFunc(F.max_pool1d, _pool_flops_compute) + F.max_pool2d = wrapFunc(F.max_pool2d, _pool_flops_compute) + F.max_pool3d = wrapFunc(F.max_pool3d, _pool_flops_compute) + F.adaptive_avg_pool1d = wrapFunc(F.adaptive_avg_pool1d, _pool_flops_compute) + F.adaptive_avg_pool2d = wrapFunc(F.adaptive_avg_pool2d, _pool_flops_compute) + F.adaptive_avg_pool3d = wrapFunc(F.adaptive_avg_pool3d, _pool_flops_compute) + F.adaptive_max_pool1d = wrapFunc(F.adaptive_max_pool1d, _pool_flops_compute) + F.adaptive_max_pool2d = wrapFunc(F.adaptive_max_pool2d, 
_pool_flops_compute) + F.adaptive_max_pool3d = wrapFunc(F.adaptive_max_pool3d, _pool_flops_compute) + + # upsample + F.upsample = wrapFunc(F.upsample, _upsample_flops_compute) + F.interpolate = wrapFunc(F.interpolate, _upsample_flops_compute) + + # softmax + F.softmax = wrapFunc(F.softmax, _softmax_flops_compute) + + # embedding + F.embedding = wrapFunc(F.embedding, _embedding_flops_compute) + + +def _patch_tensor_methods(): + torch.matmul = wrapFunc(torch.matmul, _matmul_flops_compute) + torch.Tensor.matmul = wrapFunc(torch.Tensor.matmul, _matmul_flops_compute) + torch.mm = wrapFunc(torch.mm, _matmul_flops_compute) + torch.Tensor.mm = wrapFunc(torch.Tensor.mm, _matmul_flops_compute) + torch.bmm = wrapFunc(torch.bmm, _matmul_flops_compute) + torch.Tensor.bmm = wrapFunc(torch.Tensor.bmm, _matmul_flops_compute) + + torch.addmm = wrapFunc(torch.addmm, _addmm_flops_compute) + torch.Tensor.addmm = wrapFunc(torch.Tensor.addmm, _tensor_addmm_flops_compute) + + torch.mul = wrapFunc(torch.mul, _mul_flops_compute) + torch.Tensor.mul = wrapFunc(torch.Tensor.mul, _mul_flops_compute) + + torch.add = wrapFunc(torch.add, _add_flops_compute) + torch.Tensor.add = wrapFunc(torch.Tensor.add, _add_flops_compute) + + torch.einsum = wrapFunc(torch.einsum, _einsum_flops_compute) + + torch.baddbmm = wrapFunc(torch.baddbmm, _baddbmm_flops_compute) + + +def _reload_functionals(): + # torch.nn.functional does not support importlib.reload() + F.linear = old_functions[F.linear.__str__] + F.conv1d = old_functions[F.conv1d.__str__] + F.conv2d = old_functions[F.conv2d.__str__] + F.conv3d = old_functions[F.conv3d.__str__] + F.conv_transpose1d = old_functions[F.conv_transpose1d.__str__] + F.conv_transpose2d = old_functions[F.conv_transpose2d.__str__] + F.conv_transpose3d = old_functions[F.conv_transpose3d.__str__] + F.relu = old_functions[F.relu.__str__] + F.prelu = old_functions[F.prelu.__str__] + F.elu = old_functions[F.elu.__str__] + F.leaky_relu = old_functions[F.leaky_relu.__str__] + F.relu6 = old_functions[F.relu6.__str__] + if hasattr(F, "silu"): + F.silu = old_functions[F.silu.__str__] + F.gelu = old_functions[F.gelu.__str__] + F.batch_norm = old_functions[F.batch_norm.__str__] + F.layer_norm = old_functions[F.layer_norm.__str__] + F.instance_norm = old_functions[F.instance_norm.__str__] + F.group_norm = old_functions[F.group_norm.__str__] + F.avg_pool1d = old_functions[F.avg_pool1d.__str__] + F.avg_pool2d = old_functions[F.avg_pool2d.__str__] + F.avg_pool3d = old_functions[F.avg_pool3d.__str__] + F.max_pool1d = old_functions[F.max_pool1d.__str__] + F.max_pool2d = old_functions[F.max_pool2d.__str__] + F.max_pool3d = old_functions[F.max_pool3d.__str__] + F.adaptive_avg_pool1d = old_functions[F.adaptive_avg_pool1d.__str__] + F.adaptive_avg_pool2d = old_functions[F.adaptive_avg_pool2d.__str__] + F.adaptive_avg_pool3d = old_functions[F.adaptive_avg_pool3d.__str__] + F.adaptive_max_pool1d = old_functions[F.adaptive_max_pool1d.__str__] + F.adaptive_max_pool2d = old_functions[F.adaptive_max_pool2d.__str__] + F.adaptive_max_pool3d = old_functions[F.adaptive_max_pool3d.__str__] + F.upsample = old_functions[F.upsample.__str__] + F.interpolate = old_functions[F.interpolate.__str__] + F.softmax = old_functions[F.softmax.__str__] + F.embedding = old_functions[F.embedding.__str__] + + +def _reload_tensor_methods(): + torch.matmul = old_functions[torch.matmul.__str__] + torch.Tensor.matmul = old_functions[torch.Tensor.matmul.__str__] + torch.mm = old_functions[torch.mm.__str__] + torch.Tensor.mm = 
old_functions[torch.Tensor.mm.__str__]
+    torch.bmm = old_functions[torch.bmm.__str__]
+    torch.Tensor.bmm = old_functions[torch.Tensor.bmm.__str__]
+    torch.addmm = old_functions[torch.addmm.__str__]
+    torch.Tensor.addmm = old_functions[torch.Tensor.addmm.__str__]
+    torch.mul = old_functions[torch.mul.__str__]
+    torch.Tensor.mul = old_functions[torch.Tensor.mul.__str__]
+    torch.add = old_functions[torch.add.__str__]
+    torch.Tensor.add = old_functions[torch.Tensor.add.__str__]
+
+    torch.einsum = old_functions[torch.einsum.__str__]
+
+    torch.baddbmm = old_functions[torch.baddbmm.__str__]
+
+
+def _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size):
+    # matrix matrix mult ih state and internal state
+    flops += w_ih.shape[0] * w_ih.shape[1]
+    # matrix matrix mult hh state and internal state
+    flops += w_hh.shape[0] * w_hh.shape[1]
+    if isinstance(rnn_module, (nn.RNN, nn.RNNCell)):
+        # add both operations
+        flops += rnn_module.hidden_size
+    elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)):
+        # hadamard of r
+        flops += rnn_module.hidden_size
+        # adding operations from both states
+        flops += rnn_module.hidden_size * 3
+        # last two hadamard product and add
+        flops += rnn_module.hidden_size * 3
+    elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)):
+        # adding operations from both states
+        flops += rnn_module.hidden_size * 4
+        # two hadamard product and add for C state
+        flops += (
+            rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
+        )
+        # final hadamard
+        flops += (
+            rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size
+        )
+    return flops
+
+
+def _rnn_forward_hook(rnn_module, input, output):
+    flops = 0
+    # input is a tuple containing a sequence to process and (optionally) hidden state
+    inp = input[0]
+    batch_size = inp.shape[0]
+    seq_length = inp.shape[1]
+    num_layers = rnn_module.num_layers
+
+    for i in range(num_layers):
+        w_ih = rnn_module.__getattr__("weight_ih_l" + str(i))
+        w_hh = rnn_module.__getattr__("weight_hh_l" + str(i))
+        if i == 0:
+            input_size = rnn_module.input_size
+        else:
+            input_size = rnn_module.hidden_size
+        flops = _rnn_flops(flops, rnn_module, w_ih, w_hh, input_size)
+        if rnn_module.bias:
+            b_ih = rnn_module.__getattr__("bias_ih_l" + str(i))
+            b_hh = rnn_module.__getattr__("bias_hh_l" + str(i))
+            flops += b_ih.shape[0] + b_hh.shape[0]
+
+    flops *= batch_size
+    flops *= seq_length
+    if rnn_module.bidirectional:
+        flops *= 2
+    rnn_module.__flops__ += int(flops)
+
+
+def _rnn_cell_forward_hook(rnn_cell_module, input, output):
+    flops = 0
+    inp = input[0]
+    batch_size = inp.shape[0]
+    w_ih = rnn_cell_module.__getattr__("weight_ih")
+    w_hh = rnn_cell_module.__getattr__("weight_hh")
+    input_size = inp.shape[1]
+    flops = _rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size)
+    if rnn_cell_module.bias:
+        b_ih = rnn_cell_module.__getattr__("bias_ih")
+        b_hh = rnn_cell_module.__getattr__("bias_hh")
+        flops += b_ih.shape[0] + b_hh.shape[0]
+
+    flops *= batch_size
+    rnn_cell_module.__flops__ += int(flops)
+
+
+MODULE_HOOK_MAPPING = {
+    # RNN
+    nn.RNN: _rnn_forward_hook,
+    nn.GRU: _rnn_forward_hook,
+    nn.LSTM: _rnn_forward_hook,
+    nn.RNNCell: _rnn_cell_forward_hook,
+    nn.LSTMCell: _rnn_cell_forward_hook,
+    nn.GRUCell: _rnn_cell_forward_hook,
+}
+
+
+def num_to_string(num, precision=2):
+    if num // 10**9 > 0:
+        return str(round(num / 10.0**9, precision)) + " G"
+    elif num // 10**6 > 0:
+        return str(round(num / 10.0**6, precision)) + " M"
+    elif num // 10**3 > 0:
+        return str(round(num / 10.0**3, precision)) + " K"
+    else:
+        return 
str(num) + + +def macs_to_string(macs, units=None, precision=2): + if units is None: + if macs // 10**9 > 0: + return str(round(macs / 10.0**9, precision)) + " GMACs" + elif macs // 10**6 > 0: + return str(round(macs / 10.0**6, precision)) + " MMACs" + elif macs // 10**3 > 0: + return str(round(macs / 10.0**3, precision)) + " KMACs" + else: + return str(macs) + " MACs" + else: + if units == "GMACs": + return str(round(macs / 10.0**9, precision)) + " " + units + elif units == "MMACs": + return str(round(macs / 10.0**6, precision)) + " " + units + elif units == "KMACs": + return str(round(macs / 10.0**3, precision)) + " " + units + else: + return str(macs) + " MACs" + + +def number_to_string(num, units=None, precision=2): + if units is None: + if num // 10**9 > 0: + return str(round(num / 10.0**9, precision)) + " G" + elif num // 10**6 > 0: + return str(round(num / 10.0**6, precision)) + " M" + elif num // 10**3 > 0: + return str(round(num / 10.0**3, precision)) + " K" + else: + return str(num) + " " + else: + if units == "G": + return str(round(num / 10.0**9, precision)) + " " + units + elif units == "M": + return str(round(num / 10.0**6, precision)) + " " + units + elif units == "K": + return str(round(num / 10.0**3, precision)) + " " + units + else: + return str(num) + " " + + +def flops_to_string(flops, units=None, precision=2): + if units is None: + if flops // 10**12 > 0: + return str(round(flops / 10.0**12, precision)) + " TFLOPS" + if flops // 10**9 > 0: + return str(round(flops / 10.0**9, precision)) + " GFLOPS" + elif flops // 10**6 > 0: + return str(round(flops / 10.0**6, precision)) + " MFLOPS" + elif flops // 10**3 > 0: + return str(round(flops / 10.0**3, precision)) + " KFLOPS" + else: + return str(flops) + " FLOPS" + else: + if units == "TFLOPS": + return str(round(flops / 10.0**12, precision)) + " " + units + if units == "GFLOPS": + return str(round(flops / 10.0**9, precision)) + " " + units + elif units == "MFLOPS": + return str(round(flops / 10.0**6, precision)) + " " + units + elif units == "KFLOPS": + return str(round(flops / 10.0**3, precision)) + " " + units + else: + return str(flops) + " FLOPS" + + +def params_to_string(params_num, units=None, precision=2): + if units is None: + if params_num // 10**6 > 0: + return str(round(params_num / 10**6, 2)) + " M" + elif params_num // 10**3: + return str(round(params_num / 10**3, 2)) + " k" + else: + return str(params_num) + else: + if units == "M": + return str(round(params_num / 10.0**6, precision)) + " " + units + elif units == "K": + return str(round(params_num / 10.0**3, precision)) + " " + units + else: + return str(params_num) + + +def duration_to_string(duration, units=None, precision=2): + if units is None: + if duration > 1: + return str(round(duration, precision)) + " s" + elif duration * 10**3 > 1: + return str(round(duration * 10**3, precision)) + " ms" + elif duration * 10**6 > 1: + return str(round(duration * 10**6, precision)) + " us" + else: + return str(duration) + else: + if units == "us": + return str(round(duration * 10.0**6, precision)) + " " + units + elif units == "ms": + return str(round(duration * 10.0**3, precision)) + " " + units + else: + return str(round(duration, precision)) + " s" + + # can not iterate over all submodules using self.model.modules() + # since modules() returns duplicate modules only once + + +def get_module_flops(module): + sum = module.__flops__ + # iterate over immediate children modules + for child in module.children(): + sum += get_module_flops(child) + return sum + + +def 
get_module_macs(module):
+    sum = module.__macs__
+    # iterate over immediate children modules
+    for child in module.children():
+        sum += get_module_macs(child)
+    return sum
+
+
+def get_module_duration(module):
+    duration = module.__duration__
+    if duration == 0:  # e.g. ModuleList
+        for m in module.children():
+            duration += m.__duration__
+    return duration
+
+
+def get_model_profile(
+    model,
+    input_shape=None,
+    args=[],
+    kwargs={},
+    print_profile=True,
+    detailed=True,
+    module_depth=-1,
+    top_modules=1,
+    warm_up=1,
+    as_string=True,
+    output_file=None,
+    ignore_modules=None,
+):
+    """Returns the total floating-point operations, MACs, and parameters of a model.
+
+    Example:
+
+    .. code-block:: python
+
+        model = torchvision.models.alexnet()
+        batch_size = 256
+        flops, macs, params = get_model_profile(model=model, input_shape=(batch_size, 3, 224, 224))
+
+    Args:
+        model ([torch.nn.Module]): the PyTorch model to be profiled.
+        input_shape (tuple): input shape to the model. If specified, the model takes a tensor with this shape as the only positional argument.
+        args (list): list of positional arguments to the model.
+        kwargs (dict): dictionary of keyword arguments to the model.
+        print_profile (bool, optional): whether to print the model profile. Defaults to True.
+        detailed (bool, optional): whether to print the detailed model profile. Defaults to True.
+        module_depth (int, optional): the depth into the nested modules. Defaults to -1 (the innermost modules).
+        top_modules (int, optional): the number of top modules to print in the aggregated profile. Defaults to 1.
+        warm_up (int, optional): the number of warm-up steps before measuring the latency of each module. Defaults to 1.
+        as_string (bool, optional): whether to print the output as string. Defaults to True.
+        output_file (str, optional): path to the output file. If None, the profiler prints to stdout.
+        ignore_modules ([type], optional): the list of modules to ignore during profiling. Defaults to None.
+
+    Returns:
+        The number of floating-point operations, multiply-accumulate operations (MACs), and parameters in the model.
+    """
+    assert isinstance(model, nn.Module), "model must be a PyTorch module"
+    prof = FlopsProfiler(model)
+    model.eval()
+
+    if input_shape is not None:
+        assert type(input_shape) is tuple, "input_shape must be a tuple"
+        assert len(input_shape) >= 1, "input_shape must have at least one element"
+        try:
+            input = torch.ones(()).new_empty(
+                (*input_shape,),
+                dtype=next(model.parameters()).dtype,
+                device=next(model.parameters()).device,
+            )
+        except StopIteration:
+            input = torch.ones(()).new_empty((*input_shape,))
+
+        args = [input]
+
+    assert (len(args) > 0) or (
+        len(kwargs) > 0
+    ), "args and/or kwargs must be specified if input_shape is None"
+
+    for _ in range(warm_up):
+        _ = model(*args, **kwargs)
+
+    prof.start_profile(ignore_list=ignore_modules)
+
+    _ = model(*args, **kwargs)
+
+    flops = prof.get_total_flops()
+    macs = prof.get_total_macs()
+    params = prof.get_total_params()
+    if print_profile:
+        prof.print_model_profile(
+            profile_step=warm_up,
+            module_depth=module_depth,
+            top_modules=top_modules,
+            detailed=detailed,
+            output_file=output_file,
+        )
+
+    prof.end_profile()
+    if as_string:
+        return number_to_string(flops), macs_to_string(macs), params_to_string(params)
+
+    return flops, macs, params
diff --git a/torchtitan/experiments/autopartition/job_config.py b/torchtitan/experiments/autopartition/job_config.py
new file mode 100644
index 0000000000..063a23b905
--- /dev/null
+++ b/torchtitan/experiments/autopartition/job_config.py
@@ -0,0 +1,17 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class CustomConfig:
+    auto_partition: bool = True
+    """Whether to use the auto-partition method to split the model into pipeline stages. Defaults to True."""
+
+@dataclass
+class JobConfig:
+    custom_config: CustomConfig = field(default_factory=CustomConfig)
diff --git a/torchtitan/experiments/autopartition/llama3/args.py b/torchtitan/experiments/autopartition/llama3/args.py
new file mode 100644
index 0000000000..d83fb83102
--- /dev/null
+++ b/torchtitan/experiments/autopartition/llama3/args.py
@@ -0,0 +1,71 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+#
+# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
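+#
+# Note: TransformerModelArgs below mirrors torchtitan's Llama 3 model args so
+# the auto-partition experiment can construct the model variant in this folder.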
+ + +from dataclasses import dataclass, field + +from torch import nn + +from torchtitan.config import JobConfig +from torchtitan.models.utils import get_dense_model_nparams_and_flops +from torchtitan.protocols.model import BaseModelArgs +from torchtitan.tools.logging import logger + + +@dataclass +class RoPEScalingArgs: + scaling_factor: float = 8.0 + low_freq_factor: float = 1.0 + high_freq_factor: float = 4.0 + original_max_position_embeddings: int = 8192 + + +@dataclass +class TransformerModelArgs(BaseModelArgs): + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int = 128256 + multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + rope_scaling_args: RoPEScalingArgs = field(default_factory=RoPEScalingArgs) + + max_seq_len: int = 131072 + # If `True`, then each transformer block init uses its layer ID, and if + # `False`, each uses the total number of transformer blocks + depth_init: bool = True + + use_flex_attn: bool = False + attn_mask_type: str = "causal" + eos_id: int = 0 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." + ) + self.max_seq_len = seq_len + + if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + raise NotImplementedError( + "CP support for FlexAttention is still in progress." + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + return get_dense_model_nparams_and_flops( + self, + model, + 2 * (self.dim // self.n_heads), + seq_len, + ) diff --git a/torchtitan/experiments/autopartition/llama3/model.py b/torchtitan/experiments/autopartition/llama3/model.py new file mode 100644 index 0000000000..124153f14c --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3/model.py @@ -0,0 +1,503 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import math + +import torch +import torch.nn.functional as F +from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask + +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + ScaledDotProductAttentionWrapper, +) +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import RoPEScalingArgs, TransformerModelArgs + + +def precompute_freqs_cis( + dim: int, + end: int, + theta: float = 10000.0, + scaling_args: RoPEScalingArgs = RoPEScalingArgs(), +) -> torch.Tensor: + """ + Precompute the frequency tensor for complex exponentials (cis) with given dimensions. + + This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' + and the end index 'end'. The 'theta' parameter scales the frequencies. + The returned tensor contains complex values in complex64 data type. + + Args: + dim (int): Dimension of the frequency tensor. + end (int): End index for precomputing frequencies. 
+ theta (float | None): Scaling factor for frequency computation. Defaults to 10000.0. + scaling_args (RoPEScalingArgs | None): RoPE scaling arguments. Defaults to None. + scaling_factor (float): RoPE scaling multiplier; larger values + stretch positions to support longer contexts. Defaults to 8.0. + low_freq_factor (float): Extra scaling applied to the low-frequency + (long-wavelength) RoPE bands. Defaults to 1.0. + high_freq_factor (float): Extra scaling applied to the high-frequency + (short-wavelength) RoPE bands. Defaults to 4.0. + original_max_position_embeddings (int): Maximum position embeddings + for original model. Defaults to 8192. + Returns: + torch.Tensor: Precomputed frequency tensor with complex exponentials. + """ + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + + # apply rope scaling + scaling_factor = scaling_args.scaling_factor + low_freq_factor = scaling_args.low_freq_factor + high_freq_factor = scaling_args.high_freq_factor + original_max_position_embeddings = scaling_args.original_max_position_embeddings + wavelen = 2 * math.pi / freqs + high_freq_wavelen = original_max_position_embeddings / high_freq_factor + low_freq_wavelen = original_max_position_embeddings / low_freq_factor + # wavelen < high_freq_wavelen: do nothing + # wavelen > low_freq_wavelen: divide by scaling factor + freqs = torch.where(wavelen > low_freq_wavelen, freqs / scaling_factor, freqs) + # wavelen in between: linear interpolation of the scaled freqs and the original freqs + smooth_factor = (original_max_position_embeddings / wavelen - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smoothed_freqs = ( + 1 - smooth_factor + ) * freqs / scaling_factor + smooth_factor * freqs + is_medium_freqs = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen) + freqs = torch.where(is_medium_freqs, smoothed_freqs, freqs) + + t = torch.arange(end, device=freqs.device) + freqs = torch.outer(t, freqs).float() + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. 
The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. + xk (torch.Tensor): Key tensor to apply rotary embeddings. + freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. + """ + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention module. + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_kv_heads (int): Number of key and value heads. + n_heads (int): Number of query heads. + n_rep (int): Number of repetitions for local heads. + head_dim (int): Dimension size of each attention head. + wq (Linear): Linear transformation for queries. + wk (Linear): Linear transformation for keys. + wv (Linear): Linear transformation for values. + wo (Linear): Linear transformation for output. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.n_kv_heads = ( + model_args.n_heads + if model_args.n_kv_heads is None + else model_args.n_kv_heads + ) + self.n_rep = self.n_heads // self.n_kv_heads + self.head_dim = model_args.dim // model_args.n_heads + + self.wq = nn.Linear( + model_args.dim, model_args.n_heads * self.head_dim, bias=False + ) + self.wk = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wv = nn.Linear(model_args.dim, self.n_kv_heads * self.head_dim, bias=False) + self.wo = nn.Linear( + model_args.n_heads * self.head_dim, model_args.dim, bias=False + ) + + self.use_flex_attn = model_args.use_flex_attn + if self.use_flex_attn: + self.inner_attention = FlexAttentionWrapper() + else: + self.inner_attention = ScaledDotProductAttentionWrapper() + + def init_weights(self, init_std: float): + for linear in (self.wq, self.wk, self.wv): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass of the attention module. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed frequency tensor. + + Returns: + torch.Tensor: Output tensor after attention. + + """ + + bs, seqlen, _ = x.shape + xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) + + # Use -1 instead of `n_heads` (or `n_kv_heads`) to infer the actual + # local heads from sizes of xq, xk, and xv as TP may have sharded them + # after the above linear ops. 
+ xq = xq.view(bs, seqlen, -1, self.head_dim) + xk = xk.view(bs, seqlen, -1, self.head_dim) + xv = xv.view(bs, seqlen, -1, self.head_dim) + + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + values = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) + + xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) + + assert ( + isinstance(attention_masks, BlockMask) or attention_masks is None + ), attention_masks + + if self.use_flex_attn: + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + else: + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) + + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + output = output.view(bs, seqlen, -1) + return self.wo(output) + + +class FeedForward(nn.Module): + """ + FeedForward module + + Args: + dim (int): Input dimension. + hidden_dim (int): Hidden dimension of the feedforward layer. + multiple_of (int): Value to ensure hidden dimension is a multiple of this value. + ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None. + + Attributes: + w1 (Linear): Linear transformation for the first layer. + w2 (Linear): Linear transformation for the second layer. + w3 (Linear): Linear transformation for the third layer. + + """ + + def __init__( + self, + dim: int, + hidden_dim: int, + multiple_of: int, + ffn_dim_multiplier: float | None, + ): + super().__init__() + hidden_dim = int(2 * hidden_dim / 3) + # custom dim factor multiplier + if ffn_dim_multiplier is not None: + hidden_dim = int(ffn_dim_multiplier * hidden_dim) + hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of) + + self.w1 = nn.Linear(dim, hidden_dim, bias=False) + self.w2 = nn.Linear(hidden_dim, dim, bias=False) + self.w3 = nn.Linear(dim, hidden_dim, bias=False) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + for linear in (self.w2, self.w3): + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + + +class TransformerBlock(nn.Module): + """ + TransformerBlock Module + + Args: + layer_id (int): Identifier for the layer. + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + n_heads (int): Number of attention heads. + dim (int): Dimension size of the model. + head_dim (int): Dimension size of each attention head. + attention (Attention): Attention module. + feed_forward (FeedForward): FeedForward module. + layer_id (int): Identifier for the layer. + attention_norm (RMSNorm): Layer normalization for attention output. + ffn_norm (RMSNorm): Layer normalization for feedforward output. 
+ + """ + + def __init__(self, layer_id: int, model_args: TransformerModelArgs): + super().__init__() + self.n_heads = model_args.n_heads + self.dim = model_args.dim + self.attention = Attention(model_args) + self.feed_forward = FeedForward( + dim=model_args.dim, + hidden_dim=4 * model_args.dim, + multiple_of=model_args.multiple_of, + ffn_dim_multiplier=model_args.ffn_dim_multiplier, + ) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + if model_args.depth_init: + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + else: + self.weight_init_std = 0.02 / (2 * model_args.n_layers) ** 0.5 + + def forward( + self, + x: torch.Tensor, + freqs_cis: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Perform a forward pass through the TransformerBlock. + + Args: + x (torch.Tensor): Input tensor. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + Returns: + torch.Tensor: Output tensor after applying attention and feedforward layers. + + """ + h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + def init_weights(self): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.feed_forward.init_weights(self.weight_init_std) + + +class Transformer(nn.Module, ModelProtocol): + """ + Transformer Module + + Args: + model_args (TransformerModelArgs): Model configuration arguments. + + Attributes: + model_args (TransformerModelArgs): Model configuration arguments. + vocab_size (int): Vocabulary size. + n_layers (int): Number of layers in the model. + tok_embeddings (ParallelEmbedding): Token embeddings. + layers (torch.nn.ModuleList): List of Transformer blocks. + norm (RMSNorm): Layer normalization for the model output. + output (Linear): Linear layer for final output. + freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + + """ + + def __init__(self, model_args: TransformerModelArgs): + super().__init__() + self.model_args = model_args + self.vocab_size = model_args.vocab_size + self.n_layers = model_args.n_layers + + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + + self.register_buffer( + "freqs_cis", self._precompute_freqs_cis(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args) + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear(model_args.dim, model_args.vocab_size, bias=False) + + def init_weights( + self, + buffer_device: torch.device | None = None, + ): + """ + [Note: On ``init_weights`` vs. ``reset_parameters``] + Modules may define ``reset_parameters`` to initialize parameter values. + ``reset_parameters`` is meant to only initialize directly owned + parameters/buffers, not those of their child modules, and it can be + used to give the initial values for these tensors. + Separately, users may want custom initialization for their modules, + different from that in ``reset_parameters``. For this, we define + ``init_weights``. We only call it in the constructor of this + ``Transformer`` root module to avoid reinitializing tensors. 
+ """ + buffer_device = buffer_device or self.freqs_cis.device + with torch.device(buffer_device): + self.freqs_cis = self._precompute_freqs_cis() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights() + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_freqs_cis(self) -> torch.Tensor: + return precompute_freqs_cis( + self.model_args.dim // self.model_args.n_heads, + # Need to compute until at least the max token limit for generation + # TODO: explain in docs/composability.md why we removed the 2x + # relaxing in our CP enablement PR + self.model_args.max_seq_len, + self.model_args.rope_theta, + self.model_args.rope_scaling_args, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + case "block_causal": + B = input_batch.shape[0] + mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + return create_attention_mask( + and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] + ) + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType | None = None, + ): + """ + Perform a forward pass through the Transformer model. + + Args: + tokens (torch.Tensor): Input token indices if pipeline parallelism is not enabled. + If pipeline parallelism is enabled, this will be the input token indices + for the ranks on the first pipeline stage. This will be the activation of the + previous pipeline stage if the current rank is not on the first stage. + + Returns: + torch.Tensor: Output logits after applying the Transformer model. + + """ + # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens + + for layer in self.layers.values(): + h = layer(h, self.freqs_cis, attention_masks=attention_masks) + + h = self.norm(h) if self.norm else h + output = self.output(h) if self.output else h + return output diff --git a/torchtitan/experiments/autopartition/llama3/state_dict_adapter.py b/torchtitan/experiments/autopartition/llama3/state_dict_adapter.py new file mode 100644 index 0000000000..2c386ece0d --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3/state_dict_adapter.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
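+#
+# Note: this adapter maps between torchtitan's native Llama 3 parameter names
+# and the HuggingFace checkpoint layout; wq/wk weights are additionally
+# permuted (see _permute below) to bridge the two RoPE implementations.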
+ +import logging +import re +from typing import Any + +logger = logging.getLogger() + +from torchtitan.protocols.state_dict_adapter import StateDictAdapter + +from .args import TransformerModelArgs + + +class Llama3StateDictAdapter(StateDictAdapter): + def __init__( + self, + model_args: TransformerModelArgs, + hf_assets_path: str | None, + ): + super().__init__(model_args, hf_assets_path) + + self.model_args = model_args + self.hf_assets_path = hf_assets_path + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.rotary_emb.inv_freq": None, + "model.layers.{}.mlp.gate_proj.weight": "layers.{}.feed_forward.w1.weight", + "model.layers.{}.mlp.up_proj.weight": "layers.{}.feed_forward.w3.weight", + "model.layers.{}.mlp.down_proj.weight": "layers.{}.feed_forward.w2.weight", + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + # HuggingFace permutation function (exact copy from their conversion script) + def _permute(self, w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return ( + w.view(n_heads_arg, dim1 // n_heads_arg // 2, 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + .clone() + ) + + def _reverse_permute(self, w, n_heads_arg, dim1=None, dim2=None): + if dim1 is None: + dim1 = w.shape[0] + if dim2 is None: + dim2 = w.shape[1] + return ( + w.view(n_heads_arg, 2, dim1 // n_heads_arg // 2, dim2) + .transpose(1, 2) + .reshape(dim1, dim2) + ) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + + n_heads = self.model_args.n_heads + n_kv_heads = ( + self.model_args.n_kv_heads + if self.model_args.n_kv_heads is not None + else n_heads + ) + dim = self.model_args.dim + head_dim = dim // n_heads + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = to_hf_map[abstract_key] + # We need to permute the weights in wq and wk layer in order to account for the difference between + # the native Llama and huggingface RoPE implementation. 
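+                # (torchtitan stores each head's rotary pairs interleaved, while HF keeps
+                # two contiguous half-blocks per head, hence the row regrouping in _permute.)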
+ if abstract_key == "layers.{}.attention.wq.weight": + value = self._permute(value, n_heads) + if abstract_key == "layers.{}.attention.wk.weight": + key_value_dim = head_dim * n_kv_heads + value = self._permute(value, n_kv_heads, key_value_dim, dim) + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = to_hf_map[key] + + hf_state_dict[new_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + n_heads = self.model_args.n_heads + n_kv_heads = ( + self.model_args.n_kv_heads + if self.model_args.n_kv_heads is not None + else n_heads + ) + dim = self.model_args.dim + head_dim = dim // n_heads + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + layer_num = re.search(r"\d+", key).group(0) + new_key = self.from_hf_map[abstract_key] + + # We need to permute the weights in wq and wk layer in order to account for the difference between + # the native Llama and huggingface RoPE implementation. + if abstract_key == "model.layers.{}.self_attn.q_proj.weight": + value = self._reverse_permute(value, n_heads) + if abstract_key == "model.layers.{}.self_attn.k_proj.weight": + key_value_dim = head_dim * n_kv_heads + value = self._reverse_permute(value, n_kv_heads, key_value_dim, dim) + + if new_key is None: + continue + new_key = new_key.format(layer_num) + else: + new_key = self.from_hf_map[key] + + state_dict[new_key] = value + return state_dict diff --git a/torchtitan/experiments/autopartition/llama3_tain_spec.py b/torchtitan/experiments/autopartition/llama3_tain_spec.py new file mode 100644 index 0000000000..ca861ec3f7 --- /dev/null +++ b/torchtitan/experiments/autopartition/llama3_tain_spec.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
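+#
+# Note: this module wires the auto-partition Llama 3 components (model, args,
+# state dict adapter, parallelize_llama, pipeline_llm) into the TrainSpec
+# returned by get_llama3_train_spec().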
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize_llama import parallelize_llama +from .infra.pipeline_parallel import pipeline_llm +from .llama3.args import TransformerModelArgs +from .llama3.model import Transformer +from .llama3.state_dict_adapter import Llama3StateDictAdapter + +__all__ = [ + "parallelize_llama", + "TransformerModelArgs", + "Transformer", + "llama3_args", +] + + +llama3_args = { + "debugmodel": TransformerModelArgs( + dim=4096, n_layers=16, n_heads=16, vocab_size=2048, rope_theta=500000 + ), + "debugmodel_flex_attn": TransformerModelArgs( + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2048, + rope_theta=500000, + use_flex_attn=True, + attn_mask_type="block_causal", + ), + "8B": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=1024, + rope_theta=500000, + ), + "70B": TransformerModelArgs( + dim=8192, + n_layers=80, + n_heads=64, + n_kv_heads=8, + ffn_dim_multiplier=1.3, + multiple_of=4096, + rope_theta=500000, + ), + "405B": TransformerModelArgs( + dim=16384, + n_layers=126, + n_heads=128, + n_kv_heads=8, + ffn_dim_multiplier=1.2, + multiple_of=4096, + rope_theta=500000, + ), +} + + +def get_llama3_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=Transformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autopartition/train.py b/torchtitan/experiments/autopartition/train.py new file mode 100644 index 0000000000..d8d0c4a9c4 --- /dev/null +++ b/torchtitan/experiments/autopartition/train.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import os + +import torch +from torch.distributed.elastic.multiprocessing.errors import record + +from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.ft import FTManager +from torchtitan.components.loss import rescale_accumulated_loss +from torchtitan.components.metrics import ( + build_metrics_processor, + ensure_pp_loss_visible, +) +from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import utils as dist_utils +from torchtitan.protocols.model_converter import build_model_converters +from torchtitan.tools import utils +from torchtitan.tools.logging import init_logger, logger +from torchtitan.train import Trainer +from . 
import ( # noqa: F401 # type: ignore + get_deepseek_v3_train_spec, + get_llama3_train_spec, +) + + +class AotoPartitionTrainer(Trainer): + + # Enable debug tracing on failure: https://pytorch.org/docs/stable/elastic/errors.html + @record + def __init__(self, job_config: JobConfig): + torch._C._log_api_usage_once("torchtitan.train") + + self.job_config = job_config + + logger.info(f"Starting job: {job_config.job.description}") + + if job_config.experimental.custom_import: + importlib.import_module(job_config.experimental.custom_import) + + device_module, device_type = utils.device_module, utils.device_type + self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") + # Device has to be set before creating TorchFT manager. + device_module.set_device(self.device) + + job_config.maybe_log() + + # init distributed and build meshes + self.parallel_dims = parallel_dims = self.init_distributed() + + world_mesh = parallel_dims.world_mesh + if parallel_dims.dp_enabled: + dp_mesh = world_mesh["dp"] + dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + else: + dp_degree, dp_rank = 1, 0 + + self.ft_manager = FTManager(job_config.fault_tolerance) + dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + + # take control of garbage collection to avoid stragglers + self.gc_handler = utils.GarbageCollection( + gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug + ) + + # Set random seed, and maybe enable deterministic mode + # (mainly for debugging, expect perf loss). + dist_utils.set_determinism( + world_mesh, + self.device, + job_config.debug, + distinct_seed_mesh_dims=["pp"], + ) + self.train_spec = get_llama3_train_spec() + + # build tokenizer and dataloader + self.tokenizer = ( + self.train_spec.build_tokenizer_fn(job_config) + if self.train_spec.build_tokenizer_fn is not None + else None + ) + + self.dataloader = self.train_spec.build_dataloader_fn( + dp_world_size=dp_degree, + dp_rank=dp_rank, + tokenizer=self.tokenizer, + job_config=job_config, + ) + + # build model (using meta init) + model_args = self.train_spec.model_args[job_config.model.flavor] + # set the model args from training job configs + model_args.update_from_config(job_config) + self.model_args = model_args + + logger.info( + f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" + ) + with ( + torch.device("meta"), + utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), + ): + model = self.train_spec.model_cls(model_args) + + # Build the collection of model converters. 
No-op if `model.converters` empty + model_converters = build_model_converters(job_config, parallel_dims) + model_converters.convert(model) + + # metrics logging + build_metrics_processor_fn = ( + build_metrics_processor + if self.train_spec.build_metrics_processor_fn is None + else self.train_spec.build_metrics_processor_fn + ) + self.metrics_processor = build_metrics_processor_fn( + job_config, parallel_dims, model_args + ) + color = self.metrics_processor.color + + # calculate model size and flops per token + ( + model_param_count, + self.metrics_processor.num_flops_per_token, + ) = model_args.get_nparams_and_flops(model, job_config.training.seq_len) + + logger.info( + f"{color.blue}Model {job_config.model.name} {job_config.model.flavor} " + f"{color.red}size: {model_param_count:,} total parameters{color.reset}" + ) + + # move sharded model to CPU/GPU and initialize weights via DTensor + if job_config.checkpoint.create_seed_checkpoint: + init_device = "cpu" + buffer_device = None + elif job_config.training.enable_cpu_offload: + init_device = "cpu" + buffer_device = device_type + else: + init_device = device_type + buffer_device = None + + self.loss_fn = self.train_spec.build_loss_fn( + job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager + ) + + # verify batch sizes + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + global_batch_size = job_config.training.local_batch_size * dp_degree + assert global_batch_size > 0 + assert ( + global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + ), ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({global_batch_size} " + f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + ) + + # calculate gradient accumulation steps + self.gradient_accumulation_steps = global_batch_size // ( + job_config.training.local_batch_size * dp_degree + ) + assert self.gradient_accumulation_steps > 0 + self.loss_fn = rescale_accumulated_loss( + self.loss_fn, self.gradient_accumulation_steps + ) + + # apply parallelisms and initialization + if parallel_dims.pp_enabled: + if not self.train_spec.pipelining_fn: + raise RuntimeError( + f"Pipeline Parallel is enabled but {job_config.model.name} " + f"does not support pipelining" + ) + + # apply both PT-D Pipeline Parallel and SPMD-style PT-D techniques + ( + self.pp_schedule, + self.model_parts, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) = self.train_spec.pipelining_fn( + model, + parallel_dims, + job_config, + self.device, + model_args, + self.train_spec.parallelize_fn, + self.loss_fn, + ) + # when PP is enabled, `model` obj is no longer used after this point, + # model_parts is used instead + del model + + for m in self.model_parts: + m.to_empty(device=init_device) + with torch.no_grad(): + m.init_weights(buffer_device=buffer_device) + m.train() + + # confirm that user will be able to view loss metrics on the console + ensure_pp_loss_visible(parallel_dims, job_config, color) + else: + # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel + model = self.train_spec.parallelize_fn(model, parallel_dims, job_config) + + model.to_empty(device=init_device) + with torch.no_grad(): + model.init_weights(buffer_device=buffer_device) + model.train() + + self.model_parts = [model] + + self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) + + # initialize device memory monitor and get peak 
flops for MFU calculation + device_memory_monitor = self.metrics_processor.device_memory_monitor + gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) + logger.info(f"Peak FLOPS used for computing MFU: {gpu_peak_flops:.3e}") + device_mem_stats = device_memory_monitor.get_peak_stats() + logger.info( + f"{device_type.upper()} memory usage for model: " + f"{device_mem_stats.max_reserved_gib:.2f}GiB" + f"({device_mem_stats.max_reserved_pct:.2f}%)" + ) + + # build optimizer after applying parallelisms to the model + self.optimizers = self.train_spec.build_optimizers_fn( + self.model_parts, job_config.optimizer, parallel_dims, self.ft_manager + ) + self.lr_schedulers = self.train_spec.build_lr_schedulers_fn( + self.optimizers, job_config.lr_scheduler, job_config.training.steps + ) + # Post optimizer step model converters hook. + # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 + # where it issues a single all-reduce for all parameters at once for better performance + self.optimizers.register_step_post_hook( + lambda *args, **kwargs: model_converters.post_optimizer_hook( + self.model_parts + ) + ) + self.metrics_processor.optimizers = self.optimizers + self.metrics_processor.model_parts = self.model_parts + + # Initialize trainer states that will be saved in checkpoint. + # These attributes must be initialized before checkpoint loading. + self.step = 0 + self.ntokens_seen = 0 + + self.checkpointer = CheckpointManager( + dataloader=self.dataloader, + model_parts=self.model_parts, + optimizers=self.optimizers, + lr_schedulers=self.lr_schedulers, + states={"train_state": self}, + checkpoint_config=job_config.checkpoint, + sd_adapter=( + self.train_spec.state_dict_adapter( + model_args, job_config.model.hf_assets_path + ) + if self.train_spec.state_dict_adapter + else None + ), + base_folder=job_config.job.dump_folder, + ft_manager=self.ft_manager, + ) + + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + self.train_context = dist_utils.get_train_context(loss_parallel_enabled) + self.maybe_enable_amp = dist_utils.maybe_enable_amp( + parallel_dims, + job_config.training.mixed_precision_param, + device_type, + ) + + # Build validator if validation is configured + if job_config.validation.enable: + assert self.train_spec.build_validator_fn is not None + + pp_schedule, pp_has_first_stage, pp_has_last_stage = ( + ( + self.pp_schedule, + self.pp_has_first_stage, + self.pp_has_last_stage, + ) + if parallel_dims.pp_enabled + else (None, None, None) + ) + + self.validator = self.train_spec.build_validator_fn( + job_config=job_config, + dp_world_size=dp_degree, + dp_rank=dp_rank, + tokenizer=self.tokenizer, + parallel_dims=parallel_dims, + loss_fn=self.loss_fn, + validation_context=self.train_context, + maybe_enable_amp=self.maybe_enable_amp, + metrics_processor=self.metrics_processor, + pp_schedule=pp_schedule, + pp_has_first_stage=pp_has_first_stage, + pp_has_last_stage=pp_has_last_stage, + ) + + logger.info( + "Trainer is initialized with " + f"local batch size {job_config.training.local_batch_size}, " + f"global batch size {global_batch_size}, " + f"gradient accumulation steps {self.gradient_accumulation_steps}, " + f"sequence length {job_config.training.seq_len}, " + f"total steps {job_config.training.steps} " + f"(warmup {job_config.lr_scheduler.warmup_steps})" + ) + + +def main(trainer_class: type[Trainer]) -> None: + """Main entry point for training with a specified trainer class. 
+ + Args: + trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer) + """ + init_logger() + config_manager = ConfigManager() + config = config_manager.parse_args() + trainer: Trainer | None = None + + try: + trainer = trainer_class(config) + + if config.checkpoint.create_seed_checkpoint: + assert ( + int(os.environ["WORLD_SIZE"]) == 1 + ), "Must create seed checkpoint using a single device, to disable sharding." + assert ( + config.checkpoint.enable + ), "Must enable checkpointing when creating a seed checkpoint." + trainer.checkpointer.save(curr_step=0, last_step=True) + logger.info("Created seed checkpoint") + else: + trainer.train() + except Exception: + if trainer: + trainer.close() + raise + else: + trainer.close() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + logger.info("Process group destroyed") + + +if __name__ == "__main__": + main(AotoPartitionTrainer) diff --git a/torchtitan/experiments/autopartition/train_configs/debug_model.toml b/torchtitan/experiments/autopartition/train_configs/debug_model.toml new file mode 100644 index 0000000000..1b0cd08848 --- /dev/null +++ b/torchtitan/experiments/autopartition/train_configs/debug_model.toml @@ -0,0 +1,81 @@ +[job] +dump_folder = "./outputs" +description = "Llama 3 debug training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "llama3" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 32 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 4 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +pipeline_parallel_microbatch_size = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/autopartition/train_configs/debug_model_deepseekv3.toml b/torchtitan/experiments/autopartition/train_configs/debug_model_deepseekv3.toml new file mode 100644 index 0000000000..09cf0e5ac7 --- /dev/null +++ 
b/torchtitan/experiments/autopartition/train_configs/debug_model_deepseekv3.toml
@@ -0,0 +1,79 @@
+[job]
+dump_folder = "./outputs"
+description = "DeepSeek-V3 debug training"
+print_config = false
+
+[profiling]
+enable_profiling = false
+save_traces_folder = "profile_trace"
+profile_freq = 10
+enable_memory_snapshot = false
+save_memory_snapshot_folder = "memory_snapshot"
+
+[metrics]
+log_freq = 1
+disable_color_printing = false
+enable_tensorboard = false
+save_tb_folder = "tb"
+enable_wandb = false
+
+[model]
+name = "deepseek_v3"
+flavor = "debugmodel"
+# test tokenizer, for debug purposes only
+hf_assets_path = "./tests/assets/tokenizer"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 8e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
+decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
+decay_type = "linear"
+min_lr_factor = 0.0
+
+[training]
+local_batch_size = 8
+seq_len = 2048
+max_norm = 1.0 # grad norm clipping
+steps = 10
+dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+fsdp_reshard_after_forward = "default" # default / never / always
+tensor_parallel_degree = 1
+enable_async_tensor_parallel = false
+pipeline_parallel_degree = 4
+pipeline_parallel_schedule = "1F1B"
+context_parallel_degree = 1
+expert_parallel_degree = 1
+expert_tensor_parallel_degree = 1
+
+[checkpoint]
+enable = false
+folder = "checkpoint"
+interval = 10
+last_save_model_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "selective" # ["none", "selective", "full"]
+selective_ac_option = 'op' # 'int' = ac every positive int layer, or 'op' = ac based on ops policy
+
+[compile]
+enable=false
+components = ["model", "loss"]
+
+[quantize.linear.float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output", "router.gate"]
+
+[quantize.grouped_mm.float8]
+fqns = ["experts"]
diff --git a/torchtitan/experiments/autopartition/train_configs/llama3_405b.toml b/torchtitan/experiments/autopartition/train_configs/llama3_405b.toml
new file mode 100644
index 0000000000..48c669e404
--- /dev/null
+++ b/torchtitan/experiments/autopartition/train_configs/llama3_405b.toml
@@ -0,0 +1,70 @@
+# NOTE: this toml config is a preset for 128 H100 GPUs.
+
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 405B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "405B"
+hf_assets_path = "./assets/hf/Llama-3.1-405B"
+converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 8e-5
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps
+
+[training]
+local_batch_size = 2
+seq_len = 8192
+max_norm = 1.0 # grad norm clipping
+steps = 3000
+dataset = "c4"
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+tensor_parallel_degree = 8 # 8-way TP
+enable_async_tensor_parallel = true
+pipeline_parallel_degree = 1
+context_parallel_degree = 1
+
+[checkpoint]
+enable = false
+folder = "checkpoint"
+interval = 500
+last_save_model_only = true
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "full" # ["none", "selective", "full"]
+
+[compile]
+enable=true
+components = ["model", "loss"]
+
+[quantize.linear.float8]
+enable_fsdp_float8_all_gather = true
+precompute_float8_dynamic_scale_for_fsdp = true
+filter_fqns = ["output"]
+
+[validation]
+enable = false
+dataset = "c4_validation"
+freq = 500
+steps = 1200 # Recommended value for c4_validation with world-size=8 and seq_len=8192
diff --git a/torchtitan/experiments/autopartition/train_configs/llama3_70b.toml b/torchtitan/experiments/autopartition/train_configs/llama3_70b.toml
new file mode 100644
index 0000000000..37fd35b5cb
--- /dev/null
+++ b/torchtitan/experiments/autopartition/train_configs/llama3_70b.toml
@@ -0,0 +1,69 @@
+# NOTE: this toml config is a preset for 64 A100 GPUs.
+
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 70B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "70B"
+hf_assets_path = "./assets/hf/Llama-3.1-70B"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 1.5e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps
+
+[training]
+local_batch_size = 8
+seq_len = 8192
+max_norm = 1.0 # grad norm clipping
+steps = 1000
+dataset = "c4"
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+tensor_parallel_degree = 8 # 8-way TP
+pipeline_parallel_degree = 1
+context_parallel_degree = 1
+
+[checkpoint]
+enable = false
+folder = "checkpoint"
+interval = 500
+last_save_model_only = true
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[activation_checkpoint]
+mode = "full"
+
+[compile]
+enable=false
+components = ["model", "loss"]
+
+[quantize.linear.float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output"]
+
+[validation]
+enable = false
+dataset = "c4_validation"
+freq = 500
+steps = 1200 # Recommended value for c4_validation with world-size=8 and seq_len=8192
diff --git a/torchtitan/experiments/autopartition/train_configs/llama3_8b.toml b/torchtitan/experiments/autopartition/train_configs/llama3_8b.toml
new file mode 100644
index 0000000000..ef86d783bf
--- /dev/null
+++ b/torchtitan/experiments/autopartition/train_configs/llama3_8b.toml
@@ -0,0 +1,70 @@
+# NOTE: this toml config is a preset for 64 A100 GPUs.
+
+[job]
+dump_folder = "./outputs"
+description = "Llama 3 8B training"
+
+[profiling]
+enable_profiling = true
+save_traces_folder = "profile_trace"
+profile_freq = 100
+
+[metrics]
+log_freq = 10
+enable_tensorboard = true
+save_tb_folder = "tb"
+
+[model]
+name = "llama3"
+flavor = "8B"
+hf_assets_path = "./assets/hf/Llama-3.1-8B"
+# converters = ["float8"]
+
+[optimizer]
+name = "AdamW"
+lr = 3e-4
+eps = 1e-8
+
+[lr_scheduler]
+warmup_steps = 200 # lr scheduler warm up
+
+[training]
+local_batch_size = 1
+seq_len = 8192
+max_norm = 1.0 # grad norm clipping
+steps = 1000
+dataset = "c4"
+
+[parallelism]
+data_parallel_replicate_degree = 1
+data_parallel_shard_degree = -1
+tensor_parallel_degree = 1
+pipeline_parallel_degree = 1
+context_parallel_degree = 1
+
+[checkpoint]
+enable = false
+folder = "checkpoint"
+interval = 500
+last_save_model_only = true
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
+
+[compile]
+enable=false
+components = ["model", "loss"]
+
+[activation_checkpoint]
+mode = "selective" # ["none", "selective", "full"]
+selective_ac_option = "op" # "int" = ac every positive int layer, or "op" = ac based on ops policy
+
+[quantize.linear.float8]
+enable_fsdp_float8_all_gather = false
+precompute_float8_dynamic_scale_for_fsdp = false
+filter_fqns = ["output"]
+
+[validation]
+enable = false
+dataset = "c4_validation"
+freq = 500
+steps = 1200 # Recommended value for c4_validation with world-size=8 and seq_len=8192