55 changes: 55 additions & 0 deletions torchtitan/experiments/autopartition/README.md
@@ -0,0 +1,55 @@
# Auto-Partition in torchtitan

## Overview

This folder provides an automatic pipeline-partitioning method that accounts for the computation cost of the embedding layers.
The method computes the floating-point operations (FLOPs) of the embedding layers and builds an array containing the FLOPs of both the transformer layers and the embedding layers. A heuristic algorithm then searches this array for a balanced pipeline partition.
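The sketch below illustrates the idea in Python; the actual implementation lives in `autopipe.cpp`, and the function name `balanced_partition` and its binary-search-over-bottleneck strategy are illustrative assumptions rather than the exact algorithm used here.

```python
# Illustrative sketch only: split a per-layer FLOPs array (embedding layers
# included) into at most `num_stages` contiguous stages whose bottleneck cost
# is approximately minimized.
def balanced_partition(flops: list[float], num_stages: int) -> list[list[int]]:
    def fits(limit: float) -> bool:
        # Can the layers be packed greedily into <= num_stages contiguous
        # groups if no group may exceed `limit` total FLOPs?
        groups, current = 1, 0.0
        for f in flops:
            if f > limit:
                return False
            if current + f > limit:
                groups, current = groups + 1, f
            else:
                current += f
        return groups <= num_stages

    # Binary-search the smallest feasible per-stage FLOPs budget.
    lo, hi = max(flops), sum(flops)
    for _ in range(64):
        mid = (lo + hi) / 2
        lo, hi = (lo, mid) if fits(mid) else (mid, hi)

    # Rebuild the stage assignment for the found budget.
    stages, current = [[]], 0.0
    for i, f in enumerate(flops):
        if stages[-1] and current + f > hi:
            stages.append([])
            current = 0.0
        stages[-1].append(i)
        current += f
    return stages


# Hypothetical usage: embedding + 6 transformer layers + output layer, 4 stages.
print(balanced_partition([2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0], 4))
```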

## Quick Start

### Compile

First, we need to compile `autopipe.cpp`.
```bash
pip install pybind11
cd ./torchtitan/experiments/autopartition/infra/cpp
mkdir build
cd build
cmake ..
make
mv *.so ../../
```
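
A quick, optional sanity check that the built extension is importable from Python. The module name `autopipe` and its location are assumptions inferred from the build steps above (the `mv` step leaves the `.so` in `torchtitan/experiments/autopartition/infra/`); adjust the path if your checkout differs.

```python
# Sanity check only; the module name "autopipe" and its location are assumed
# from the build steps above, not guaranteed by this PR.
import importlib.util
import sys

sys.path.insert(0, "./torchtitan/experiments/autopartition/infra")
spec = importlib.util.find_spec("autopipe")
print("autopipe extension found at:", spec.origin if spec else None)
```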

### Train

The following command uses Llama 3 as an example:

```bash
CONFIG_FILE="./torchtitan/experiments/autopartition/train_configs/debug_model.toml" ./run_train.sh
```

## Performance

Hardware configuration: 4x RTX 3090 (24 GB), with a pipeline parallelism degree of 4. TPS denotes training throughput in tokens per second.

### Llama 3 configuration comparison
| hidden size| layers | autopipe TPS| default TPS| Speedup |
| ---------- | ---- | ---------- | -----------| ----------- |
| dim=256 | 6 | 31,094 | 29,549 | +5.2% |
| dim=256 | 12 | 21,803 | 21,923 | -0.5% |
| dim=2048 | 12 | 3,348 | 2,616 | +28.0% |
| dim=4096 | 12 | 981 | 761 | +28.9% |

### DeepSeek-V3 (without MoE) configuration comparison

| hidden size| layers | autopipe TPS| default TPS| Speedup |
| ---------- | ---- | ---------- | -----------| ----------- |
| dim=256 | 6 | 13,373 | 13,059 | +2.4% |
| dim=256 | 12 | 7,714 | 6,859 | +12.5% |
| dim=2048 | 12 | 4,331 | 3,810 | +13.7% |
| dim=4096 | 12 | 2,888 | 2,561 | +12.8% |
| dim=4096 | 16 | 2,207 | 2,008 | +9.9% |
| dim=8192 | 16 | 4,331 | 3,935 | +10.1% |


### Known Issues

- **MoE not supported** - Auto-Partition needs the FLOPs of each layer, but the current profiler (from DeepSpeed) does not support computing FLOPs for MoE layers.
14 changes: 14 additions & 0 deletions torchtitan/experiments/autopartition/__init__.py
@@ -0,0 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = [
"get_deepseek_v3_train_spec",
"get_llama3_train_spec",
]


from .deepseek_v3_tain_spec import get_deepseek_v3_train_spec
from .llama3_tain_spec import get_llama3_train_spec
121 changes: 121 additions & 0 deletions torchtitan/experiments/autopartition/deepseek_v3/args.py
@@ -0,0 +1,121 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.


from dataclasses import dataclass, field

from torch import nn

from torchtitan.config import JobConfig
from torchtitan.models.moe import MoEArgs
from torchtitan.models.utils import get_moe_model_nparams_and_flops
from torchtitan.protocols.model import BaseModelArgs
from torchtitan.tools.logging import logger
from torchtitan.tools.utils import has_cuda_capability


# Reference: https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@dataclass
class DeepSeekV3ModelArgs(BaseModelArgs):
"""
Data class for defining model arguments and hyperparameters.

Attributes:
max_batch_size (int): Maximum batch size.
max_seq_len (int): Maximum sequence length.
vocab_size (int): Vocabulary size.
dim (int): Model dimension.
inter_dim (int): Intermediate dimension for MLP layers.
moe_inter_dim (int): Intermediate dimension for MoE layers.
n_layers (int): Number of transformer layers.
n_dense_layers (int): Number of dense layers in the model.
n_heads (int): Number of attention heads.
norm_eps (float): Epsilon value used for RMSNorm.
moe_args (MoEArgs): MoE configuration.
n_expert_groups (int): Number of expert groups.
n_limited_groups (int): Number of limited groups for MoE routing.
q_lora_rank (int): LoRA rank for query projections.
kv_lora_rank (int): LoRA rank for key-value projections.
qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings.
qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings.
v_head_dim (int): Dimension for value projections.
use_flex_attn (bool): Whether to use FlexAttention.
attn_mask_type (str): Type of attention mask.
original_seq_len (int): Original sequence length.
rope_theta (float): Base for rotary positional encoding.
rope_factor (float): Scaling factor for extended sequence lengths.
beta_fast (int): Fast beta correction factor.
        beta_slow (int): Slow beta correction factor.
        mscale (float): Attention scaling factor applied for extended sequence lengths (YaRN).
    """

max_batch_size: int = 8
max_seq_len: int = 4096 * 4
vocab_size: int = 102400
dim: int = 2048
inter_dim: int = 10944
moe_inter_dim: int = 1408
n_layers: int = 27
n_dense_layers: int = 1
n_heads: int = 16
norm_eps: float = 1e-5 # eps used for RMSNorm

# MoE
moe_args: MoEArgs = field(default_factory=MoEArgs)
# TODO: node-limited routing is not supported yet
n_expert_groups: int = 1
n_limited_groups: int = 1

# Multi-Head Latent Attention (MLA)
q_lora_rank: int = 0
kv_lora_rank: int = 512
qk_nope_head_dim: int = 128
qk_rope_head_dim: int = 64
v_head_dim: int = 128
use_flex_attn: bool = False
attn_mask_type: str = "causal"

# yarn
original_seq_len: int = 4096
rope_theta: float = 10000.0
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.0

def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
seq_len = job_config.training.seq_len
if seq_len > self.max_seq_len:
logger.warning(
f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}."
)
self.max_seq_len = seq_len

if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0):
logger.warning(
"Failed to use grouped mm, which is only supported on SM90 or later",
)
self.moe_args.use_grouped_mm = False

if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
raise NotImplementedError(
"CP support for FlexAttention is still in progress."
)

self.moe_args._debug_force_load_balance = (
job_config.debug.moe_force_load_balance
)

def get_nparams_and_flops(
self, model: nn.Module, seq_len: int
) -> tuple[int, float]:
return get_moe_model_nparams_and_flops(
self,
model,
self.qk_nope_head_dim + self.qk_rope_head_dim + self.v_head_dim,
seq_len,
)