diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 7efc04b784..2164826c6e 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -362,6 +362,12 @@ def apply_fsdp(
             transformer_block.moe.experts.set_gradient_divide_factor(
                 gradient_divide_factor,
             )
+        else:
+            fully_shard(
+                transformer_block._checkpoint_wrapped_module.feed_forward if hasattr(transformer_block, "_checkpoint_wrapped_module") else transformer_block.feed_forward,
+                **fsdp_config,
+                reshard_after_forward=reshard_after_forward,
+            )
 
         fully_shard(
             transformer_block,
diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
index f55cb301bb..f64a7c5fcb 100644
--- a/torchtitan/models/deepseek_v3/__init__.py
+++ b/torchtitan/models/deepseek_v3/__init__.py
@@ -83,8 +83,8 @@
         dim=2048,
         inter_dim=10944,
         moe_inter_dim=1408,
-        n_layers=27,
-        n_dense_layers=1,
+        n_layers=5,
+        n_dense_layers=5,
         n_heads=16,
         moe_args=MoEArgs(
             num_experts=64,
diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py
index 7182b1fca3..7a81b8d2aa 100644
--- a/torchtitan/models/deepseek_v3/infra/parallelize.py
+++ b/torchtitan/models/deepseek_v3/infra/parallelize.py
@@ -156,6 +156,9 @@ def parallelize_deepseekv3(
         else:
             logger.info("Applied FSDP to the model")
 
+        # import fbvscode
+        # fbvscode.set_trace()
+
         if parallel_dims.cp_enabled:
             logger.info("Applied Context Parallel to the model")
 
diff --git a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
index 7e9983a532..be4c96414f 100644
--- a/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
+++ b/torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml
@@ -4,10 +4,10 @@ description = "DeepSeek-V3 16B model training"
 print_args = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
 profile_freq = 10
-enable_memory_snapshot = false
+enable_memory_snapshot = true
 save_memory_snapshot_folder = "memory_snapshot"
 
 [metrics]
@@ -35,10 +35,10 @@ decay_type = "cosine"
 min_lr_factor = 0.1
 
 [training]
-local_batch_size = 8
-seq_len = 4096
+local_batch_size = 1
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 20
 dataset = "c4" # supported datasets: c4_test (2K), c4 (177M)
 
 [parallelism]
@@ -49,7 +49,7 @@ tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 pipeline_parallel_schedule = "Interleaved1F1B"
-expert_parallel_degree = 8
+expert_parallel_degree = 1
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -61,11 +61,11 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]"
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "full" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
-enable=true
+enable=false
 components = ["loss"] # ["model", "loss"]
 
 [float8]
diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py
index a34b4463f8..1ca6d3156f 100644
--- a/torchtitan/models/llama3/__init__.py
+++ b/torchtitan/models/llama3/__init__.py
@@ -42,7 +42,7 @@
     ),
     "8B": TransformerModelArgs(
         dim=4096,
-        n_layers=32,
+        n_layers=3,
         n_heads=32,
         n_kv_heads=8,
         ffn_dim_multiplier=1.3,
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 7d8aa76f0d..bd01b3fb20 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -298,6 +298,11 @@ def apply_fsdp(
             reshard_after_forward=reshard_after_forward,
         )
     for layer_id, transformer_block in model.layers.items():
+        fully_shard(
+            transformer_block.feed_forward if not hasattr(transformer_block, "_checkpoint_wrapped_module") else transformer_block._checkpoint_wrapped_module.feed_forward,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
         fully_shard(
             transformer_block,
             **fsdp_config,
diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index 85df53b664..3caaddc240 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -7,7 +7,9 @@ description = "Llama 3 8B training"
 [profiling]
 enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 100
+profile_freq = 10
+enable_memory_snapshot = true
+save_memory_snapshot_folder = "memory_snapshot"
 
 [metrics]
 log_freq = 10
@@ -30,9 +32,9 @@ warmup_steps = 200 # lr scheduler warm up
 
 [training]
 local_batch_size = 1
-seq_len = 8192
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 1000
+steps = 100
 dataset = "c4"
 
 [parallelism]
@@ -55,7 +57,7 @@ enable=false
 components = ["model", "loss"]
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "full" # ["none", "selective", "full"]
 selective_ac_option = "op" # "int" = ac every positive int layer or 'op', ac based on ops policy
 
 [float8]