gpt-oss model enablement #1754
base: main
Changes from all commits
@@ -0,0 +1,19 @@
# gpt-oss Model in torchtitan

## Quick Start
```bash
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh
```
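Parallelism degrees are controlled through torchtitan's job config rather than this script. As a hedged illustration (not part of this README; the option names are assumptions and may differ across torchtitan versions), a debug run with TP and EP enabled might look like:

```bash
# Hypothetical override flags; exact names follow torchtitan's job config.
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" \
./run_train.sh \
  --parallelism.tensor_parallel_degree 2 \
  --parallelism.expert_parallel_degree 4
```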
## Supported Features
- FSDP/HSDP, TP, EP, ETP
- Grouped matrix multiplication for efficient computation
- SwiGLU activation

> Review comment: This is not a feature, but part of the model itself.

- Multi-head attention with sliding window mask and attention sink

> Review comment: similar.
## TODO
1. More parallelism support: CP, PP
2. Conversion between HF and torchtitan weights (StateDictAdapter)
3. Forward parity verification
4. CI support
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# This source code is licensed under the BSD-style license found in the

> Review comment: same (duplicated license header).

# LICENSE file in the root directory of this source tree.
from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.moe import MoEArgs

from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
    "parallelize_gptoss",
    "GptOssModelArgs",
    "GptOssModel",
    "gptoss_configs",
]
gptoss_configs = {
    "debugmodel": GptOssModelArgs(
        hidden_size=256,
        num_hidden_layers=4,
        moe_args=MoEArgs(
            num_experts=8,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
        use_flex_attn=True,

> Review comment: default is already True, so can remove this.

        attn_mask_type="causal",
    ),
"20b": GptOssModelArgs( | ||
num_hidden_layers=24, | ||
moe_args=MoEArgs( | ||
num_experts=32, | ||
num_shared_experts=0, | ||
score_func="softmax", | ||
route_norm=False, | ||
route_scale=1.0, | ||
score_before_experts=False, | ||
top_k=4, | ||
use_grouped_mm=True, | ||
load_balance_coeff=1e-3, | ||
), | ||
), | ||
"120b": GptOssModelArgs( | ||
num_hidden_layers=36, | ||
moe_args=MoEArgs( | ||
num_experts=128, | ||
num_shared_experts=0, | ||
score_func="softmax", | ||
route_norm=False, | ||
route_scale=1.0, | ||
score_before_experts=False, | ||
top_k=4, | ||
use_grouped_mm=True, | ||
load_balance_coeff=1e-3, | ||
), | ||
), | ||
} | ||
|
||
|
||
def get_train_spec() -> TrainSpec:
    return TrainSpec(
        name="gpt_oss",
        model_cls=GptOssModel,
        model_args=gptoss_configs,
        parallelize_fn=parallelize_gptoss,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
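A small illustrative sketch (not in the PR) of how the flavor dict and train spec above are consumed; it assumes this file is the `torchtitan.experiments.gpt_oss` package `__init__`, and the printed values follow directly from the configs above.

```python
# Hypothetical usage; only names defined in this file are used.
from torchtitan.experiments.gpt_oss import get_train_spec, gptoss_configs

args = gptoss_configs["20b"]
print(args.num_hidden_layers, args.moe_args.num_experts, args.moe_args.top_k)  # 24 32 4

spec = get_train_spec()
print(spec.name)  # gpt_oss
```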
@@ -0,0 +1,194 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable

import torch
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.distributed.tensor.parallel import ParallelStyle
from torchtitan.distributed.expert_parallel import ExpertParallel
# implementation of Tensor Parallel for the GroupedExperts in MoE
class TensorParallel(ParallelStyle):
    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)])
            ),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp1_bias",
            nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)])
            ),
        )  # Row-wise sharding
        module.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])
            ),
        )  # Replicate

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            self._partition_fn,
        )
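For context, a hedged sketch of how a `ParallelStyle` like the one above is typically applied with `parallelize_module`; the module path `"moe.experts"` and the 8-way mesh are assumptions for illustration, not taken from this PR.

```python
# Hypothetical application of the TensorParallel style to a grouped-experts submodule.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module


def apply_experts_tp(transformer_block: nn.Module) -> nn.Module:
    # 1D tensor-parallel mesh over 8 GPUs (illustrative size).
    tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
    return parallelize_module(
        transformer_block,
        tp_mesh,
        {"moe.experts": TensorParallel()},  # assumed FQN of the GroupedExperts module
    )
```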
# This class is for dp2ep with TP (without TP we can just use ExpertParallel)
class ExpertTensorParallel(ExpertParallel):
    def __init__(
        self,
        tp_mesh: DeviceMesh,
        ep_mesh: DeviceMesh,
    ):
        super().__init__()
        # TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh,
        # as DeviceMesh doesn't support slicing from a submesh.
        self.tp_mesh = tp_mesh
        self.ep_mesh = ep_mesh

    def _token_dispatch(self, mod, inputs, device_mesh):
        # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
        return super()._token_dispatch(mod, inputs, self.ep_mesh)

    def _partition_fn_2d(self, name, mod, ep_tp_mesh):
        mod.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp1_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Row-wise sharding
        mod.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()])
            ),
        )  # Replicate

    def _token_combine(self, mod, routed_output, device_mesh):
        # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
        return super()._token_combine(mod, routed_output, self.ep_mesh)

    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
        return distribute_module(
            module,
            device_mesh,
            partition_fn=self._partition_fn_2d,
            input_fn=self._token_dispatch,
            output_fn=self._token_combine,
        )
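A hedged sketch (not in the PR) of how the 2D style above could be wired up: build an `[ep, tp]` mesh, pass the submeshes explicitly (the constructor's TODO explains why), and distribute the experts over the combined mesh. The mesh sizes, the `"ep"`/`"tp"` dimension names, and the direct use of `_apply` are assumptions for illustration.

```python
# Hypothetical wiring for dp2ep + TP: 4-way EP x 2-way TP over 8 GPUs.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh


def apply_experts_ep_tp(experts: nn.Module) -> nn.Module:
    ep_tp_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("ep", "tp"))
    style = ExpertTensorParallel(tp_mesh=ep_tp_mesh["tp"], ep_mesh=ep_tp_mesh["ep"])
    # Parameters are sharded over [ep, tp]; token dispatch/combine run on the EP submesh.
    return style._apply(experts, ep_tp_mesh)
```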
# TODO(jianiw): This needs to be merged with expert_parallel
def expert_parallel(func: Callable) -> Callable:

> Review comment: sorry I'll merge my refactor, and then please rebase.

    """
    This is a wrapper applied to the GroupedExperts computation, serving
    the following three purposes:
    1. Convert parameters from DTensors to plain Tensors, to work with
       dynamic-shape inputs which cannot be easily expressed as DTensors.
    2. In Expert Parallel, apply the generate_permute_indices kernel to
       permute the inputs to be ordered by local experts (see the _token_dispatch
       function in ExpertParallel) and permute the outputs back.
    3. In order to use torch._grouped_mm, we need to make sure the number of
       tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices
       kernel also helps achieve this via padding, without incurring synchronization
       between device and host. Note that this will create side effects when wrapping
       the for-loop implementation of GroupedExperts, as it does not need padding.

    Among the above:
    1 and 2 are needed only when expert_parallel_degree > 1.
    3 is needed even for single-device computation.
    2 can be moved to ExpertParallel _token_dispatch if not coupled with 3.
    """

    def wrapper(
        mlp1_weight: torch.Tensor,
        mlp1_bias: torch.Tensor,
        mlp2_weight: torch.Tensor,
        mlp2_bias: torch.Tensor,
        swiglu_limit: float,
        x: torch.Tensor,
        num_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
        if isinstance(mlp1_weight, DTensor):
            mlp1_weight = mlp1_weight.to_local()
            mlp1_bias = mlp1_bias.to_local()
            mlp2_weight = mlp2_weight.to_local()
            mlp2_bias = mlp2_bias.to_local()

        if num_tokens_per_expert is not None:
            from torchtitan.experiments.kernels.moe.indices import (
                generate_permute_indices,
            )

            experts_per_ep_rank = mlp1_weight.shape[0]
            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank

            ALIGN_SIZE_M = 16
            with torch.no_grad():
                (
                    permuted_indices,
                    num_tokens_per_expert,
                    _,  # offsets
                ) = generate_permute_indices(
                    num_tokens_per_expert,
                    experts_per_ep_rank,
                    num_ep_ranks,
                    x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
                    ALIGN_SIZE_M,
                )

            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
            input_shape = x.shape
            x = x[permuted_indices, :]

        out = func(
            mlp1_weight,
            mlp1_bias,
            mlp2_weight,
            mlp2_bias,
            swiglu_limit,
            x,
            num_tokens_per_expert,
        )

        if num_tokens_per_expert is not None:
            out_unpermuted = out.new_empty(input_shape)
            out_unpermuted[permuted_indices, :] = out
            out = out_unpermuted[:-1]

        return out

    return wrapper
> Review comment: nit: follow lexicographic order for readability.
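For illustration only, a hedged sketch of how a grouped-experts compute function might be wrapped by the decorator above. The function name, the per-expert loop body, the weight layouts noted in the comments, and the simplified activation are hypothetical stand-ins; the PR's real path uses torch._grouped_mm and gpt-oss's clamped SwiGLU.

```python
# Hypothetical sketch; assumes the expert_parallel decorator defined above is in scope.
import torch


@expert_parallel
def run_experts_for_loop(
    mlp1_weight: torch.Tensor,   # (num_experts, d_model, 2 * d_ff), assumed layout
    mlp1_bias: torch.Tensor,     # (num_experts, 2 * d_ff)
    mlp2_weight: torch.Tensor,   # (num_experts, d_ff, d_model)
    mlp2_bias: torch.Tensor,     # (num_experts, d_model)
    swiglu_limit: float,
    x: torch.Tensor,             # (total_tokens, d_model), already sorted by expert
    num_tokens_per_expert: torch.Tensor | None = None,
) -> torch.Tensor:
    # Assumes the grouped path, i.e. num_tokens_per_expert is supplied by the wrapper.
    outputs, start = [], 0
    for e, n in enumerate(num_tokens_per_expert.tolist()):
        xe = x[start : start + n]
        gate, up = (xe @ mlp1_weight[e] + mlp1_bias[e]).chunk(2, dim=-1)
        # Stand-in activation; the real kernel applies gpt-oss's clamped SwiGLU
        # parameterized by swiglu_limit.
        h = torch.nn.functional.silu(gate) * up
        outputs.append(h @ mlp2_weight[e] + mlp2_bias[e])
        start += n
    return torch.cat(outputs)
```

The decorator then handles DTensor-to-local conversion and, under expert parallelism, the ALIGN_SIZE_M padding and token permutation before this function runs.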