Commit 54d2a8b

add script to train with ft
Summary: the script adds configuration options to run training locally with fault tolerance (FT) enabled.
1 parent 9e9890a commit 54d2a8b

File tree

2 files changed: +97 -2 lines changed

docs/torchft.md
torchtitan/models/llama3_ft/train_configs/debug_model.toml

docs/torchft.md

Lines changed: 2 additions & 2 deletions
@@ -68,12 +68,12 @@ The `--training.global_batch_size` parameter refers to global batch size that wi
 
 #### Replica Group 0
 ```bash
-CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
+CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=0
 ```
 
 #### Replica Group 1
 ```bash
-CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1 --fault_tolerance.semi_sync_method="diloco" --experimental.custom_args_module=torchtitan.components.ft.config
+CONFIG_FILE=./torchtitan/models/llama3_ft/train_configs/debug_model.toml CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh --parallelism.data_parallel_shard_degree=4 --fault_tolerance.enable --fault_tolerance.group_size=2 --fault_tolerance.replica_id=1
 ```
 
 ## Fault Tolerance Configuration Options
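
For quick local experimentation on a single 8-GPU host, the two replica-group commands above can be launched from one shell. A minimal sketch, assuming the torchtitan repo root as the working directory, the config path added by this commit, and that any torchft coordination service described earlier in torchft.md is already running:

```bash
#!/usr/bin/env bash
# Launch two fault-tolerant replica groups (4 GPUs each) on one 8-GPU host.
# Both commands mirror the docs above; only CUDA_VISIBLE_DEVICES and
# --fault_tolerance.replica_id differ between the two groups.
set -euo pipefail

CONFIG=./torchtitan/models/llama3_ft/train_configs/debug_model.toml

CONFIG_FILE=$CONFIG CUDA_VISIBLE_DEVICES=0,1,2,3 NGPU=4 ./run_train.sh \
  --parallelism.data_parallel_shard_degree=4 \
  --fault_tolerance.enable \
  --fault_tolerance.group_size=2 \
  --fault_tolerance.replica_id=0 &

CONFIG_FILE=$CONFIG CUDA_VISIBLE_DEVICES=4,5,6,7 NGPU=4 ./run_train.sh \
  --parallelism.data_parallel_shard_degree=4 \
  --fault_tolerance.enable \
  --fault_tolerance.group_size=2 \
  --fault_tolerance.replica_id=1 &

wait
```
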
torchtitan/models/llama3_ft/train_configs/debug_model.toml

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
[job]
dump_folder = "./outputs"
description = "Llama 3 debug training"
print_args = false

[profiling]
enable_profiling = true
save_traces_folder = "profile_trace"
profile_freq = 10
profiler_active = 10
profiler_warmup = 0
enable_memory_snapshot = false
save_memory_snapshot_folder = "memory_snapshot"

[metrics]
log_freq = 1
disable_color_printing = false
enable_tensorboard = false
save_tb_folder = "tb"
enable_wandb = false

[model]
name = "llama3"
flavor = "debugmodel"
# test folder with tokenizer.json, for debug purpose only
hf_assets_path = "./tests/assets/tokenizer"
# converters = ["float8"]

[optimizer]
name = "AdamW"
lr = 8e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps
decay_type = "linear"
min_lr_factor = 0.0

[training]
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 100
dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)

[parallelism]
data_parallel_replicate_degree = 1
data_parallel_shard_degree = -1
fsdp_reshard_after_forward = "default" # default / never / always
tensor_parallel_degree = 1
enable_async_tensor_parallel = false
pipeline_parallel_degree = 1
context_parallel_degree = 1

[checkpoint]
enable = false
folder = "checkpoint"
interval = 10
last_save_model_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

[activation_checkpoint]
mode = "selective" # ["none", "selective", "full"]
selective_ac_option = "2" # 'int' = ac every positive int layer or 'op', ac based on ops policy

[compile]
enable = false
components = ["model", "loss"]

[quantize.linear.float8]
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output"]

[validation]
enable = false
dataset = "c4_validation"
freq = 5
steps = 10

[comm]
train_timeout_seconds = 15

[fault_tolerance]
enable = true
sync_steps = 10
num_fragments = 2
semi_sync_method = "diloco"
process_group = "nccl"
process_group_timeout_ms = 10000

[experimental]
custom_args_module = "torchtitan.components.ft.config"
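
Before launching, it can help to confirm that the new config parses and that the fault-tolerance knobs match what the run expects. A minimal sketch, assuming the file path above and a Python 3.11+ interpreter (for the standard-library tomllib):

```bash
# Parse the new config and print the fault-tolerance and experimental sections.
# Assumes the torchtitan repo root as the working directory and Python >= 3.11.
python3 - <<'EOF'
import tomllib

path = "./torchtitan/models/llama3_ft/train_configs/debug_model.toml"
with open(path, "rb") as f:
    cfg = tomllib.load(f)

# Expect enable = true, semi_sync_method = "diloco", sync_steps = 10, num_fragments = 2.
print("fault_tolerance:", cfg["fault_tolerance"])
# Expect custom_args_module = "torchtitan.components.ft.config".
print("experimental:", cfg["experimental"])
EOF
```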
