1 change: 1 addition & 0 deletions torchtitan/experiments/__init__.py
@@ -15,5 +15,6 @@
"transformers_modeling_backend",
"autoparallel.llama3",
"autoparallel.deepseek_v3",
"autoparallel.local_map_deepseek_v3",
]
)
6 changes: 6 additions & 0 deletions torchtitan/experiments/autoparallel/README.md
@@ -17,3 +17,9 @@ Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://github.com/meta-pytorch/autoparallel)
**DeepSeekv3**

`CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config`

**DeepSeekv3 local_map**

This is a variant of torchtitan's DeepSeek-V3 that uses `local_map` for the expert-parallel region. It currently supports only a 2D mesh. NOTE: the mesh/parallelism flags below are only there to reuse torchtitan's trainer mesh setup code; Autoparallel itself is not bound to dp2ep. A minimal sketch of the `local_map` pattern follows the run command below.

`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml tlp ./run_train.sh --model.name autoparallel.local_map_deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config --parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2`
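
For orientation, here is a minimal, hypothetical sketch of the `local_map` pattern this variant relies on. The real wrapping lives in autoparallel's DSv3 test model; the function, placements, and sharding choices below are illustrative assumptions, not the actual implementation.

```python
import torch
from torch.distributed.tensor.experimental import local_map
from torch.distributed.tensor.placement_types import Replicate, Shard


def run_experts(tokens: torch.Tensor, expert_weights: torch.Tensor) -> torch.Tensor:
    # Placeholder for the per-rank expert compute on local shards
    # (the real model runs grouped expert GEMMs here).
    return torch.bmm(tokens, expert_weights)


# Hypothetical wrapping: run the expert compute on local shards of the 2D
# (dp_shard_mod_ep, dp_shard_in_ep) mesh, with experts sharded along dp_shard_in_ep.
run_experts_ep = local_map(
    run_experts,
    out_placements=(Replicate(), Shard(0)),
    in_placements=(
        (Replicate(), Shard(0)),  # routed tokens, sharded over the expert dim
        (Replicate(), Shard(0)),  # expert weights, sharded over the expert dim
    ),
    redistribute_inputs=True,
)
```

The idea is that Autoparallel can treat the wrapped body as a fixed local region over those mesh dims while it solves placements for the rest of the model.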
@@ -257,8 +257,6 @@ def set_torchtitan_fields(orig, new):
block.moe_enabled = hasattr(block, "moe")


# Run workflow with:
# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ./run_train.sh --model.name deepseekv3_autoparallel
def parallelize_deepseekv3(
model,
parallel_dims: ParallelDims,
57 changes: 57 additions & 0 deletions torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py
@@ -0,0 +1,57 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

import copy

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.distributed.pipeline_parallel import pipeline_llm
from torchtitan.hf_datasets.text_datasets import build_text_dataloader

from torchtitan.models.deepseek_v3 import deepseekv3_args
from torchtitan.models.deepseek_v3.model.state_dict_adapter import (
DeepSeekV3StateDictAdapter,
)
from torchtitan.protocols.train_spec import TrainSpec

from .args import DeepSeekV3ModelArgs, get_sample_config

from .model import DeepSeekV3Model
from .parallelize_deepseekv3 import parallelize_deepseekv3


def get_model_args() -> DeepSeekV3ModelArgs:
model_args = copy.deepcopy(deepseekv3_args)
# TODO: Align configs between AP and Titan
for config in model_args.keys():
# Replace each torchtitan config flavor with the AP sample config, but keep
# torchtitan's update_from_config and get_nparams_and_flops helpers
override = get_sample_config()
override.update_from_config = model_args[config].update_from_config
override.get_nparams_and_flops = model_args[config].get_nparams_and_flops
model_args[config] = override

return model_args


def get_train_spec() -> TrainSpec:
model_args = get_model_args()

return TrainSpec(
model_cls=DeepSeekV3Model,
model_args=model_args,
parallelize_fn=parallelize_deepseekv3,
pipelining_fn=pipeline_llm,
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
build_lr_schedulers_fn=build_lr_schedulers,
build_dataloader_fn=build_text_dataloader,
build_tokenizer_fn=build_hf_tokenizer,
build_loss_fn=build_cross_entropy_loss,
state_dict_adapter=DeepSeekV3StateDictAdapter,
)
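# Usage note (an assumption about the wiring, based on the entry added to
# torchtitan/experiments/__init__.py above): passing
# --model.name autoparallel.local_map_deepseek_v3 should make torchtitan import
# this package and call get_train_spec() to assemble the trainer from the
# components listed in the TrainSpec above.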
49 changes: 49 additions & 0 deletions torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from dataclasses import dataclass

from autoparallel._testing.models.dsv3 import (
DeepSeekV3ModelArgs as _DeepSeekV3ModelArgs,
MoEArgs as _MoEArgs,
)
from torchtitan.protocols.model import BaseModelArgs


# Need to share same base class with torchtitan models
@dataclass
class DeepSeekV3ModelArgs(_DeepSeekV3ModelArgs, BaseModelArgs):
pass


def get_sample_config() -> DeepSeekV3ModelArgs:
return DeepSeekV3ModelArgs(
vocab_size=2048,
max_seq_len=2048,
dim=256,
inter_dim=1024,
moe_inter_dim=256,
n_layers=4,
n_dense_layers=0,
n_heads=16,
moe_args=_MoEArgs(
num_experts=4,
num_shared_experts=2,
top_k=2,
score_func="softmax",
route_norm=False,
score_before_experts=False,
mesh=None,
),
q_lora_rank=0,
kv_lora_rank=512,
qk_nope_head_dim=128,
qk_rope_head_dim=64,
v_head_dim=128,
mscale=0.70,
)
18 changes: 18 additions & 0 deletions torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
#
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.

from autoparallel._testing.models.dsv3 import DeepSeekV3Model as _DeepSeekV3Model
from torchtitan.protocols.train_spec import ModelProtocol

from .args import DeepSeekV3ModelArgs


# Need to share same base class with torchtitan models
class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol):
def __init__(self, model_args: DeepSeekV3ModelArgs):
super().__init__(model_args)
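
# Minimal usage sketch (not part of the PR; single-process assumption): build the
# debug-sized model from the sample config on the meta device, which is what
# parallelize_deepseekv3 prefers to receive.
import torch

from torchtitan.experiments.autoparallel.local_map_deepseek_v3.args import get_sample_config
from torchtitan.experiments.autoparallel.local_map_deepseek_v3.model import DeepSeekV3Model

with torch.device("meta"):
    model = DeepSeekV3Model(get_sample_config())
print(f"parameter count: {sum(p.numel() for p in model.parameters()):,}")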
182 changes: 182 additions & 0 deletions torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py
@@ -0,0 +1,182 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
from autoparallel.api import AutoParallel
from autoparallel.auto_bucketing import configure_inductor_for_autobucketing

from torch.distributed.tensor.placement_types import Shard
from torchtitan.config import JobConfig
from torchtitan.distributed import ParallelDims

from torchtitan.tools.logging import logger


# TODO: Autoparallel should transparently wrap the original nn.Module
# but I don't know how to do that.
def set_torchtitan_fields(orig, new):
assert isinstance(new.layers, torch.nn.ModuleDict)
for block in new.layers.values():
block.moe_enabled = hasattr(block, "moe")


def parallelize_deepseekv3(
model,
parallel_dims: ParallelDims,
job_config: JobConfig,
):
"""
Apply Autoparallel to the model

NOTE: the passed-in model should preferably be on the meta device; otherwise
it must fit in GPU or CPU memory.
"""

# TODO(whc)
# I do this because otherwise sometimes inductor will skip re-running passes like comms reordering
torch._inductor.config.force_disable_caches = True
# this is necessary for working with reordering passes. Just leave it set for all the jobs for now.
torch._inductor.config.allow_buffer_reuse = False

# allow configuring inductor comms optimizations from torchtitan commandline
configure_inductor_for_autobucketing(
job_config.experimental.comms_bucket_reorder_strategy
)

world_mesh = parallel_dims.world_mesh

# Update me when changing dsv3.py
assert world_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims"
assert world_mesh.mesh_dim_names == (
"dp_shard_mod_ep",
"dp_shard_in_ep",
), "Current setup assumes these specific meshes"

# Provide AP MoE with mesh
for layer in model.layers.values():
if layer.moe_enabled:
layer.moe.mesh = world_mesh
layer.moe.axis_name = "dp_shard_in_ep"
Review comment (Contributor): This seems okay for now, but when we enable TP this should change, so just add a comment to modify this when enabling TP.

Reply (Member Author): I added an assert to check for dp2ep specifically, because I don't think we're handling the mesh setup well...


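# input_fn hands AutoParallel an example global batch of token ids; how that
# batch gets sharded is decided by AutoParallel, subject to the input
# constraints added below (x_sharding).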
def input_fn():
global_batch_size = job_config.training.global_batch_size
if global_batch_size < 0:
# This global batch size results in 1 gradient accumulation
# step.
dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard
global_batch_size = job_config.training.local_batch_size * dp_degree
return (
torch.randint(
0,
model.model_args.vocab_size,
(global_batch_size, job_config.training.seq_len),
device=torch.device("cuda"),
),
)

should_compile = job_config.compile.enable
if should_compile:
# TODO: support more options in AP API
assert job_config.compile.components == ["model"]
assert job_config.compile.backend == "inductor"

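# No mixed-precision policy is configured here (mp_policy=None), so AutoParallel
# presumably keeps the model's default dtype.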
mp_policy = None
with AutoParallel(
model,
input_fn,
world_mesh,
mp_policy=mp_policy,
compile=should_compile,
dynamic=True,
) as autop:
autop.add_parameter_memory_constraint(low=None, high=None)

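# Constrain the token inputs (and model outputs) to be sharded along the batch
# dim on both mesh dims, i.e. fully batch-sharded across
# dp_shard_mod_ep x dp_shard_in_ep.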
x_sharding = (Shard(0), Shard(0))
loss_parallel_enabled = (
parallel_dims.tp_enabled
and not job_config.parallelism.disable_loss_parallel
)
assert not loss_parallel_enabled
Review comment (Contributor): Instead of failing here, why not disable loss parallel and give a warning?

Reply (Member Author): Warnings are easy to miss; you won't know loss parallel is disabled if you don't spot it before titan starts dumping its logs.

autop.add_input_constraints([x_sharding])
autop.add_output_constraints([x_sharding])
t0 = time.time()
sharding_placement = autop.optimize_placement()
t1 = time.time()
logger.info(f"AutoParallel took {t1 - t0} seconds")
parallel_mod = autop.apply_placement(sharding_placement)

set_torchtitan_fields(model, parallel_mod)

if loss_parallel_enabled:

# PyTorch's current implementation of loss parallel assumes
# that the DTensor has a 1d device mesh. This is not true
# in our case, but we can work around it by casting the
# output to a DTensor on a 1d device mesh.
# We should just use AutoParallel to do this for us, but
# it would require putting the loss inside the model as well.
def _return_as_dtensor_for_loss_parallel(module, args, output):
return torch.distributed.tensor.DTensor.from_local(
output, world_mesh["tp"], (Shard(2),)
)

# not keeping a reference to the hook, don't plan on
# removing it at any point
parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel)

_preserve_moe_attributes(model, parallel_mod)

return parallel_mod


def _preserve_moe_attributes(original_model, parallel_model):
"""
Preserve MoE custom attributes from the original model to the parallel model.
This is only needed for attributes that aren't used in the graph, so they aren't
lifted as graph inputs and fetched by the pre-graph runtime wrapper.

`moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify
this block as a moe block. This should be safe as they are read-only.
"""

def get_moe_modules(model):
"""Extract all MoE modules from the model."""
moe_modules = []
if hasattr(model, "layers"):
if isinstance(model.layers, torch.nn.ModuleDict):
# regular torchtitan structure
blocks = model.layers.values()
else:
# autoparallel might change structure
blocks = (
model.layers.children() if hasattr(model.layers, "children") else []
)

for block in blocks:
if (
hasattr(block, "moe_enabled")
and block.moe_enabled
and hasattr(block, "moe")
):
moe_modules.append(block.moe)
elif hasattr(block, "moe"): # fallback for autoparallel
moe_modules.append(block.moe)
return moe_modules

original_moe_modules = get_moe_modules(original_model)
parallel_moe_modules = get_moe_modules(parallel_model)

# Copy custom attributes from original to parallel MoE modules
# This is fine to do since these attributes are read only
for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules):
# Copy moe_enabled, which marks this module as belonging to a MoE block
if hasattr(orig_moe, "moe_enabled"):
par_moe.moe_enabled = orig_moe.moe_enabled

# Copy load_balance_coeff
if hasattr(orig_moe, "load_balance_coeff"):
par_moe.load_balance_coeff = orig_moe.load_balance_coeff