Skip to content

Commit 78da644

Browse files
committed
add joint graph runner deepseek_v3 experiment
1 parent db82f8b commit 78da644

File tree

3 files changed

+417
-1
lines changed

3 files changed

+417
-1
lines changed

torchtitan/experiments/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
_supported_experiments = frozenset(
8-
["flux", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
8+
[
9+
"flux",
10+
"simple_fsdp.llama3",
11+
"simple_fsdp.deepseek_v3",
12+
"vlm",
13+
"joint_graph_runner.deepseek_v3",
14+
]
915
)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
#
7+
# Copyright (c) Meta Platforms, Inc. All Rights Reserved.
8+
9+
from torchtitan.components.loss import build_cross_entropy_loss
10+
from torchtitan.components.lr_scheduler import build_lr_schedulers
11+
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
12+
from torchtitan.components.tokenizer import build_hf_tokenizer
13+
from torchtitan.datasets.hf_datasets import build_hf_dataloader
14+
from torchtitan.distributed.pipeline_parallel import pipeline_llm
15+
16+
from torchtitan.experiments.simple_fsdp.deepseek_v3.model import (
17+
SimpleFSDPDeepSeekV3Model,
18+
)
19+
from torchtitan.models.deepseek_v3 import deepseekv3_args
20+
from torchtitan.protocols.train_spec import TrainSpec
21+
22+
from .parallelize import parallelize_deepseekv3
23+
24+
25+
def get_train_spec() -> TrainSpec:
26+
return TrainSpec(
27+
model_cls=SimpleFSDPDeepSeekV3Model,
28+
model_args=deepseekv3_args,
29+
parallelize_fn=parallelize_deepseekv3,
30+
pipelining_fn=pipeline_llm,
31+
build_optimizers_fn=build_optimizers_with_moe_load_balancing,
32+
build_lr_schedulers_fn=build_lr_schedulers,
33+
build_dataloader_fn=build_hf_dataloader,
34+
build_tokenizer_fn=build_hf_tokenizer,
35+
build_loss_fn=build_cross_entropy_loss,
36+
)

0 commit comments

Comments
 (0)