2 changes: 1 addition & 1 deletion torchtitan/experiments/auto_parallel/__init__.py
@@ -39,7 +39,7 @@
)
register_train_spec(
TrainSpec(
name="deepseekv3_auto_parallel",
name="deepseek_v3_auto_parallel",
model_cls=DeepSeekV3Model,
model_args=deepseekv3_configs,
parallelize_fn=parallelize_deepseekv3,
38 changes: 29 additions & 9 deletions torchtitan/experiments/auto_parallel/parallelize_deepseekv3.py
@@ -19,13 +19,12 @@
from torchtitan.tools.logging import logger


def apply_local_map_to_moe():
def apply_local_map_to_moe(mesh):
"""
TODO: fix HOPs not restoring the original signature.
TODO: fix tracing with local shapes so that we can use Shard placements

Current HOP signature we get:

Current HOP signature we get (9 inputs):
class subgraph_0(torch.nn.Module):
def forward(self,
rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0",
@@ -43,28 +42,42 @@ def forward(self,
moe._moe_forward = local_map(
moe._moe_forward,
out_placements=(
(Replicate(),), # (Shard(0),),
# (Shard(0),),
(Replicate(),),
(Replicate(),),
),
in_placements=(
(Replicate(),), # (Shard(0),),
# x
# (Shard(0),),
(Replicate(),),
# router weight
(Replicate(),),
# expert bias
(Replicate(),),
# expert
# (Shard(0),),
# (Shard(0),),
# (Shard(0),),
(Replicate(),),
(Replicate(),),
(Replicate(),),
# shared
(Replicate(),),
(Replicate(),),
(Replicate(),),
None,
None
),
redistribute_inputs=True,
in_grad_placements=None,
device_mesh=None,
device_mesh=mesh,
)
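For reference, here is a minimal sketch (not part of this PR) of the same local_map pattern applied to a hypothetical two-output toy function, showing how in_placements and out_placements line up with the wrapped signature; _toy_forward and apply_local_map_to_toy are illustrative names only.

import torch
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.experimental import local_map


def _toy_forward(x, weight):
    # Plain tensor math; local_map unwraps DTensor args to local tensors, runs
    # this body, and re-wraps the outputs according to out_placements.
    y = torch.nn.functional.linear(x, weight)
    return y, y.sum()


def apply_local_map_to_toy(mesh):
    # Hypothetical helper mirroring apply_local_map_to_moe above.
    return local_map(
        _toy_forward,
        out_placements=(
            (Replicate(),),  # y
            (Replicate(),),  # y.sum()
        ),
        in_placements=(
            (Replicate(),),  # x
            (Replicate(),),  # weight
        ),
        redistribute_inputs=True,
        in_grad_placements=None,
        device_mesh=mesh,
    )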


# Run workflow with:
# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_auto_parallel
# Run AP with:
# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseek_v3_auto_parallel
# Non-AP command:
# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseek_v3
def parallelize_deepseekv3(
model,
parallel_dims: ParallelDims,
@@ -100,7 +113,7 @@ def input_fn():
assert parallel_dims.pp_enabled is False, "PP not supported yet"

# apply local_map to MoE
apply_local_map_to_moe()
apply_local_map_to_moe(world_mesh)

# torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = (
# lambda bucket_idx: 500 / parallel_dims.tp
@@ -217,9 +230,16 @@ def get_moe_modules(model):
# Copy custom attributes from original to parallel MoE modules
# This is fine to do since these attributes are read only
for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):

if hasattr(orig_moe, 'moe_enabled'):
par_moe.load_balance_coeff = orig_moe.load_balance_coeff

# Piggyback some asserts needed by our modified forward; these can't live in
# the model code because export thinks they are data-dependent
assert orig_moe.score_before_experts
assert orig_moe.experts.use_grouped_mm
assert orig_moe.shared_experts is not None

# Copy load_balance_coeff
if hasattr(orig_moe, 'load_balance_coeff'):
par_moe.load_balance_coeff = orig_moe.load_balance_coeff
79 changes: 57 additions & 22 deletions torchtitan/models/moe.py
@@ -31,6 +31,8 @@ class MoEArgs:
use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation
load_balance_coeff: float | None = 1e-3

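# Stateless SwiGLU feed-forward, (silu(x @ w1.T) * (x @ w3.T)) @ w2.T, taking
# plain weight tensors so it can be applied to the shared-expert weights inside
# _moe_forward.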
def functional_feed_forward(w1, w2, w3, x):
return F.linear(F.silu(F.linear(x, w1)) * F.linear(x, w3), w2)

# can be used as dense FFN layer or shared experts in MoE layers
class FeedForward(nn.Module):
@@ -191,7 +193,7 @@ def __init__(
self.route_scale = route_scale

def forward(
self, x: torch.Tensor, expert_bias: torch.Tensor | None = None
self, x: torch.Tensor, gate_weight: torch.nn.Parameter, expert_bias: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Args:
Expand All @@ -209,7 +211,8 @@ def forward(
Number of tokens assigned to each expert with shape ``(num_experts,)``.
"""
# scores shape (bs*slen, num_experts)
scores = self.gate(x)
# scores = self.gate(x)
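# Apply the explicitly passed gate weight instead of self.gate, so the weight is
# an explicit tensor argument rather than captured module state.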
scores = torch.nn.functional.linear(x, gate_weight)

# By default, sigmoid or softmax is performed in float32 to avoid loss explosion
if self.score_func == "sigmoid":
@@ -311,7 +314,29 @@
num_tokens_per_expert,
)

def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts):
# def forward(self,
# rms_norm_5: "f32[64, 2048, 256][524288, 256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__router____modules__gate____parameters__weight: "f32[8, 256][256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____buffers__expert_bias: "f32[8][1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w1: "f32[8, 256, 256][65536, 256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w3: "f32[8, 256, 256][65536, 256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__experts____parameters__w2: "f32[8, 256, 256][65536, 256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w1____parameters__weight: "f32[512, 256][256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w3____parameters__weight: "f32[512, 256][256, 1]cuda:0",
# self____modules__layers____modules__1____modules__moe____modules__shared_experts____modules__w2____parameters__weight: "f32[256, 512][512, 1]cuda:0"):
def _moe_forward(
x,
router_gate_weight,
expert_bias,
experts_w1,
experts_w3,
experts_w2,
shared_w1,
shared_w3,
shared_w2,
router, # None
reorderer, # None
):
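# The nine leading arguments are plain tensors, mirroring the flat HOP signature
# captured in the comment above; router and reorderer are passed through local_map
# with None placements (see apply_local_map_to_moe).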
# x: 64, 2048, 256
bs, slen, dim = x.shape
x = x.view(-1, dim)
@@ -322,7 +347,7 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, experts, shared_experts):
top_scores,
selected_experts_indices,
num_tokens_per_expert,
) = router(x, expert_bias)
) = router(x, router_gate_weight, expert_bias)

# tokens_per_expert will be used to update the expert bias for load balancing.
# and also to count the expert usage
@@ -354,27 +379,23 @@ def _moe_forward(x, router, expert_bias, reorderer, score_before_experts, expert

# shape (bs*slen*top_k, dim)
routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted)

if score_before_experts:
routed_input = (
routed_input.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)
routed_input = (
routed_input.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)

# shape (bs*slen*top_k, dim)
routed_output = experts(routed_input, num_tokens_per_expert)

if not score_before_experts:
routed_output = (
routed_output.to(torch.float32)
* top_scores_experts_sorted.reshape(-1, 1)
).to(x.dtype)
# routed_output = experts(routed_input, num_tokens_per_expert)
routed_output = _run_experts_grouped_mm(
experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert
)

# shared expert
if shared_experts is not None:
out = shared_experts(x)
else:
out = torch.zeros_like(x)
# if shared_experts is not None:
# out = shared_experts(x)
# else:
# out = torch.zeros_like(x)
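# Shared experts are assumed present (asserted in parallelize_deepseekv3), so the
# shared path runs unconditionally on the plain shared-expert weight tensors.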
out = functional_feed_forward(shared_w1, shared_w2, shared_w3, x)

out = out.scatter_add(
dim=0, index=token_indices_experts_sorted, src=routed_output
@@ -439,7 +460,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
Returns:
out (torch.Tensor): Output tensor with shape ``(bs, slen, dim)``.
"""
out, num_tokens_per_expert = _moe_forward(x, self.router, self.expert_bias, self.reorderer, self.score_before_experts, self.experts, self.shared_experts)
experts_w1, experts_w2, experts_w3 = self.experts.parameters()
shared_w1, shared_w2, shared_w3 = self.shared_experts.parameters()
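# NOTE: the unpacking above relies on parameters() yielding w1, w2, w3 in
# registration order.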
out, num_tokens_per_expert = _moe_forward(
x,
self.router.gate.weight,
self.expert_bias,
experts_w1,
experts_w3,
experts_w2,
shared_w1,
shared_w3,
shared_w2,
self.router, # None
self.reorderer, # None
)

# HOPs don't support buffer mutations, keep this outside
with torch.no_grad():