Commit 0f7a7c9

[wip] current state working, needs cleanup
1 parent 71dea16 commit 0f7a7c9

File tree

8 files changed: +532, -33 lines

run_train.sh

Lines changed: 5 additions & 2 deletions
@@ -10,8 +10,11 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
-NGPU=${NGPU:-"8"}
-export LOG_RANK=${LOG_RANK:-0}
+# NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"4"}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3,4,5,6,7}
+# export LOG_RANK=${LOG_RANK:-0,1,2,3}
+export LOG_RANK=${LOG_RANK:-3}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}

torchtitan/config/job_config.py

Lines changed: 2 additions & 2 deletions
@@ -623,10 +623,10 @@ class MX:
 
 @dataclass
 class Comm:
-    init_timeout_seconds: int = 300
+    init_timeout_seconds: int = 30
     """Timeout for communication operations, during initialization and first train step."""
 
-    train_timeout_seconds: int = 100
+    train_timeout_seconds: int = 10
     """
     Timeout for communication operations after the first train step --
     usually a tighter bound than during initialization.

torchtitan/distributed/expert_parallel.py

Lines changed: 240 additions & 3 deletions
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 
-from typing import Callable, Literal
+from typing import Callable, Literal, Dict
 
 import torch
 import torch.nn as nn
@@ -22,6 +22,202 @@
 )
 from torch.distributed.tensor.parallel import ParallelStyle
 
+import threading
+import torch
+from typing import Optional
+import time
+
+class SimplifiedHookCoordinator:
+    """
+    TODO: this hangs because FWD is doing dispatch, and BWD is doing combine
+    the two communications are conflicting??
+    """
+
+    """Alternating forward/backward coordination with just 2 semaphores"""
+
+    def __init__(self):
+        self._lock = threading.Lock()
+
+        # Only need 2 semaphores - one for forward, one for backward
+        self._forward_semaphore = threading.Semaphore(0)  # Forward waits
+        self._backward_semaphore = threading.Semaphore(1)  # Backward starts first
+
+        # CUDA event tracking
+        self._forward_cuda_event = torch.cuda.Event()
+        self._backward_cuda_event = torch.cuda.Event()
+        self._forward_event_recorded = False
+        self._backward_event_recorded = False
+
+        # Store AsyncCollectiveTensors from previous operations
+        self._stored_forward_async_tensor = None
+        self._stored_backward_async_tensor = None
+
+        self._coordination_enabled = False
+        self._cycle_count = 0
+
+    def is_coordination_enabled(self) -> bool:
+        """Check if coordination is currently enabled"""
+        return self._coordination_enabled
+
+    def enable_coordination(self, num_layers: Optional[int] = None):
+        self._coordination_enabled = True
+        self._cycle_count = 0
+
+        # Reset semaphores
+        self._forward_semaphore = threading.Semaphore(0)
+        self._backward_semaphore = threading.Semaphore(1)
+
+        # Reset CUDA events
+        self._forward_cuda_event = torch.cuda.Event()
+        self._backward_cuda_event = torch.cuda.Event()
+        self._forward_event_recorded = False
+        self._backward_event_recorded = False
+
+        # num layers
+        self._num_layers = num_layers
+
+        print("[COORDINATION] Simplified hook coordination with CUDA events ENABLED")
+
+    def disable_coordination(self):
+        self._coordination_enabled = False
+        # Release both semaphores to unblock any waiting threads
+        try:
+            self._forward_semaphore.release()
+            self._backward_semaphore.release()
+        except ValueError:
+            pass
+        print("[COORDINATION] Simplified hook coordination DISABLED")
+
+    def acquire_forward_execution(self):
+        if not self._coordination_enabled:
+            return
+
+        print("[FORWARD] Attempting acquire forward execution")
+        self._forward_semaphore.acquire()
+
+        # 2. Wait for PREVIOUS FORWARD CUDA operations to complete
+        if self._forward_event_recorded:
+            print("[FORWARD] Waiting for previous FORWARD CUDA operations to complete...")
+            self._forward_cuda_event.wait()
+            print("[FORWARD] Previous FORWARD CUDA operations completed")
+
+        # Wait for forward's own previously stored AsyncCollectiveTensor
+        if self._stored_forward_async_tensor is not None:
+            from torch.distributed._functional_collectives import AsyncCollectiveTensor
+            if isinstance(self._stored_forward_async_tensor, AsyncCollectiveTensor):
+                print("[FORWARD] Waiting for forward's own previous AsyncCollectiveTensor...")
+                torch.ops._c10d_functional.wait_tensor(self._stored_forward_async_tensor)
+                print("[FORWARD] Forward's previous AsyncCollectiveTensor completed")
+            self._stored_forward_async_tensor = None  # Clear after waiting
+
+        print("[FORWARD] Acquired forward execution")
+
+    def release_forward_execution(self, async_tensor: Optional[torch.Tensor] = None):
+        if not self._coordination_enabled:
+            return
+
+        # 1. Record CUDA event for current forward operations
+        current_stream = torch.cuda.current_stream()
+        self._forward_cuda_event = torch.cuda.Event()  # Create new event
+        self._forward_cuda_event.record(current_stream)
+        self._forward_event_recorded = True
+        print("[FORWARD] Recorded forward CUDA completion event")
+
+        # Store forward's AsyncCollectiveTensor for forward's own future use
+        self._stored_forward_async_tensor = async_tensor
+        if async_tensor is not None:
+            from torch.distributed._functional_collectives import AsyncCollectiveTensor
+            if isinstance(async_tensor, AsyncCollectiveTensor):
+                print("[FORWARD] Stored forward AsyncCollectiveTensor for forward's future use")
+
+        print("[FORWARD] Releasing forward, signaling backward")
+        self._backward_semaphore.release()  # Signal backward can start
+
+        self._cycle_count += 1
+        print(f"cycle count {self._cycle_count}")
+        # TODO: better way to determine when to disable coordination
+        moe_multipler = 4
+        if self._num_layers is not None and self._cycle_count >= moe_multipler * self._num_layers:
+            print("[COORDINATION] Reached target number of cycles, disabling coordination")
+            self.disable_coordination()
+            return  # Exit early since coordination is now disabled
+
+    def acquire_backward_execution(self):
+        if not self._coordination_enabled:
+            return
+
+        print("[BACKWARD] Attempting acquire backward execution")
+        self._backward_semaphore.acquire()
+
+        # # 2. Wait for PREVIOUS BACKWARD CUDA operations to complete
+        if self._backward_event_recorded:
+            print("[BACKWARD] Waiting for previous BACKWARD CUDA operations to complete...")
+            self._backward_cuda_event.wait()
+            print("[BACKWARD] Previous BACKWARD CUDA operations completed")
+
+        # Wait for backward's own previously stored AsyncCollectiveTensor
+        if self._stored_backward_async_tensor is not None:
+            from torch.distributed._functional_collectives import AsyncCollectiveTensor
+            if isinstance(self._stored_backward_async_tensor, AsyncCollectiveTensor):
+                print("[BACKWARD] Waiting for backward's own previous AsyncCollectiveTensor...")
+                torch.ops._c10d_functional.wait_tensor(self._stored_backward_async_tensor)
+                print("[BACKWARD] Backward's previous AsyncCollectiveTensor completed")
+            self._stored_backward_async_tensor = None  # Clear after waiting
+
+        print("[BACKWARD] Acquired backward execution")
+
+    def release_backward_execution(self, async_tensor: Optional[torch.Tensor] = None):
+        if not self._coordination_enabled:
+            return
+
+        # 1. Record CUDA event for current backward operations
+        current_stream = torch.cuda.current_stream()
+        self._backward_cuda_event = torch.cuda.Event()  # Create new event
+        self._backward_cuda_event.record(current_stream)
+        self._backward_event_recorded = True
+        print("[BACKWARD] Recorded backward CUDA completion event")
+
+        # Store backward's AsyncCollectiveTensor for backward's own future use
+        self._stored_backward_async_tensor = async_tensor
+        if async_tensor is not None:
+            from torch.distributed._functional_collectives import AsyncCollectiveTensor
+            if isinstance(async_tensor, AsyncCollectiveTensor):
+                print("[BACKWARD] Stored backward AsyncCollectiveTensor for backward's future use")
+
+        print("[BACKWARD] Releasing backward, signaling next forward")
+        self._forward_semaphore.release()  # Signal next forward can start
+        # self._cycle_count += 1
+        # print(f"[CYCLE] Completed cycle {self._cycle_count}")
+
+# Global coordinator
+_hook_coordinator = SimplifiedHookCoordinator()
+
+class SyncHook(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x, hook_name):
+        ctx.hook_name = hook_name
+
+        _hook_coordinator.acquire_forward_execution()
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[FORWARD] {hook_name}_fwd")
+            return x
+        finally:
+            _hook_coordinator.release_forward_execution(x)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        hook_name = ctx.hook_name
+
+        _hook_coordinator.acquire_backward_execution()
+
+        try:
+            if _hook_coordinator.is_coordination_enabled():
+                print(f"[BACKWARD] {hook_name}_bwd")
+            return grad_output, None
+        finally:
+            _hook_coordinator.release_backward_execution(grad_output)
 
 TOKEN_GROUP_ALIGN_SIZE_M = 8
 ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
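
Note (annotation, not part of the diff): the scheme above boils down to a two-semaphore handshake — the forward side starts blocked, the backward side starts holding the token, and each side hands the token to the other when it releases. A CPU-only sketch of just that handshake, with illustrative names rather than the torchtitan API:

import threading

forward_sem = threading.Semaphore(0)   # forward waits for backward to go first
backward_sem = threading.Semaphore(1)  # backward holds the initial token

def forward_steps(n):
    for i in range(n):
        forward_sem.acquire()          # mirrors acquire_forward_execution()
        print(f"forward {i}")
        backward_sem.release()         # mirrors release_forward_execution()

def backward_steps(n):
    for i in range(n):
        backward_sem.acquire()         # mirrors acquire_backward_execution()
        print(f"backward {i}")
        forward_sem.release()          # mirrors release_backward_execution()

t_fwd = threading.Thread(target=forward_steps, args=(3,))
t_bwd = threading.Thread(target=backward_steps, args=(3,))
t_fwd.start(); t_bwd.start()
t_fwd.join(); t_bwd.join()
# Strictly alternates: backward 0, forward 0, backward 1, forward 1, ...

The real coordinator additionally gates each side on its own CUDA event and any pending AsyncCollectiveTensor before proceeding.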
@@ -109,11 +305,15 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             .to(torch.device("cpu"), non_blocking=True)
         )
         # NOTE: this would incur a device-to-host sync
+        # CPU-GPU sync here!!!
+        # start_time = time.time()
         output_splits = (
             num_tokens_per_expert_group.view(ep_size, -1)
             .sum(dim=1)
             .to(torch.device("cpu"), non_blocking=False)
         )
+        # sync_time = time.time() - start_time
+        # print(f"CPU-GPU sync took {sync_time:.4f}s")
         self.input_splits = input_splits.tolist()
         self.output_splits = output_splits.tolist()
 
@@ -125,6 +325,11 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             device_mesh.get_group(),
         )
 
+        # TODO: FIX NEEDING THIS???
+        routed_input = torch.ops._c10d_functional.wait_tensor(
+            routed_input
+        )
+
         # NOTE: After this all-to-all, the routed input is put on proper EP rank.
         # However, the num_tokens_per_expert_group is not of the final target format
         # [#tokens for local expert 0, #tokens for local expert 1, ...]
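
Note (annotation, not part of the diff): the explicit wait on routed_input follows the usual functional-collectives pattern — the all-to-all returns an AsyncCollectiveTensor, and wait_tensor forces the collective to finish before the result is consumed on another path. A minimal helper sketch of that pattern (wait_if_async is a hypothetical name, not a torchtitan helper):

import torch
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def wait_if_async(t: torch.Tensor) -> torch.Tensor:
    # Block until the collective backing `t` has completed; plain tensors pass through.
    if isinstance(t, AsyncCollectiveTensor):
        return torch.ops._c10d_functional.wait_tensor(t)
    return t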
@@ -152,16 +357,48 @@ def _token_combine(self, mod, routed_output, device_mesh):
             self.output_splits,
             device_mesh.get_group(),
         )
+        # TODO: FIX NEEDING THIS???
+        # CRITICAL: Wait for AsyncCollectiveTensor BEFORE coordination
+        from torch.distributed._functional_collectives import AsyncCollectiveTensor
+        if isinstance(routed_output, AsyncCollectiveTensor):
+            routed_output = torch.ops._c10d_functional.wait_tensor(routed_output)
+
         return routed_output
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
-        return distribute_module(
-            module,
+        """
+        hooks are called in the order they are registered:
+        A, dispatch, B (pre hooks)
+        C, combine, D (post hooks)
+        """
+        inner_wrapped_module = self._wrap_with_inner_hooks(module)
+        distributed_module = distribute_module(
+            inner_wrapped_module,
             device_mesh,
             partition_fn=ExpertParallel._partition_fn,
             input_fn=self._token_dispatch,
             output_fn=self._token_combine,
         )
+        final_module = self._wrap_with_outer_hooks(distributed_module)
+        return final_module
+
+    def _wrap_with_inner_hooks(self, module):
+        def inner_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "dispatch_A"),) + input[1:]
+        def inner_post_hook(module, input, output):
+            return SyncHook.apply(output, "combine_C")
+        module.register_forward_pre_hook(inner_pre_hook)
+        module.register_forward_hook(inner_post_hook)
+        return module
+
+    def _wrap_with_outer_hooks(self, module):
+        def outer_pre_hook(module, input):
+            return (SyncHook.apply(input[0], "dispatch_B"),) + input[1:]
+        def outer_post_hook(module, input, output):
+            return SyncHook.apply(output, "combine_D")
+        module.register_forward_pre_hook(outer_pre_hook)
+        module.register_forward_hook(outer_post_hook)
+        return module
 
 
 # This class is for dp2ep with TP (without TP we can just use ExpertParallel)
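
Note (annotation, not part of the diff): the _apply docstring relies on PyTorch running forward pre-hooks and forward hooks in registration order, which is why the inner hooks are registered before distribute_module and the outer hooks after it. A standalone sketch of that ordering, assuming distribute_module installs input_fn/output_fn as ordinary (non-prepended) hooks:

import torch
import torch.nn as nn

mod = nn.Identity()
for name in ("A (inner pre)", "dispatch (input_fn)", "B (outer pre)"):
    # Pre-hooks fire in the order they were registered.
    mod.register_forward_pre_hook(lambda m, inp, name=name: print(name))
for name in ("C (inner post)", "combine (output_fn)", "D (outer post)"):
    # Post-hooks likewise fire in registration order.
    mod.register_forward_hook(lambda m, inp, out, name=name: print(name))

mod(torch.zeros(1))
# Prints A, dispatch, B on the way in, then C, combine, D on the way out.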

torchtitan/models/deepseek_v3/__init__.py

Lines changed: 4 additions & 3 deletions
@@ -11,7 +11,7 @@
 from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
 from torchtitan.components.tokenizer import build_hf_tokenizer
 from torchtitan.datasets.hf_datasets import build_hf_dataloader
-from torchtitan.models.llama3.infra.pipeline import pipeline_llama
+from torchtitan.models.llama3.infra.pipeline import pipeline_llama, pipeline_llama_tracer
 from torchtitan.models.moe import MoEArgs
 
 from torchtitan.protocols.train_spec import register_train_spec, TrainSpec
@@ -32,10 +32,11 @@
 deepseekv3_configs = {
     "debugmodel": DeepSeekV3ModelArgs(
         vocab_size=2000,
-        dim=256,
+        # needs at least dim 8?
+        dim=8,
         inter_dim=1024,
         moe_inter_dim=256,
-        n_layers=6,
+        n_layers=16,
         n_dense_layers=1,
         n_heads=16,
         moe_args=MoEArgs(

torchtitan/models/deepseek_v3/train_configs/debug_model.toml

Lines changed: 10 additions & 9 deletions
@@ -4,9 +4,9 @@ description = "DeepSeek-V3 debug training"
 print_args = false
 
 [profiling]
-enable_profiling = false
+enable_profiling = true
 save_traces_folder = "profile_trace"
-profile_freq = 10
+profile_freq = 5
 enable_memory_snapshot = false
 save_memory_snapshot_folder = "memory_snapshot"
 
@@ -36,22 +36,23 @@ decay_type = "linear"
 min_lr_factor = 0.0
 
 [training]
-local_batch_size = 8
-seq_len = 2048
+local_batch_size = 4
+seq_len = 4
 max_norm = 1.0 # grad norm clipping
-steps = 10
+steps = 6
 dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M)
+# dataset = "c4"
 
 [parallelism]
 data_parallel_replicate_degree = 1
 data_parallel_shard_degree = -1
 fsdp_reshard_after_forward = "default" # default / never / always
 tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
-pipeline_parallel_degree = 1
-pipeline_parallel_schedule = "1F1B"
+pipeline_parallel_degree = 2
+expert_parallel_degree = 2
 context_parallel_degree = 1
-expert_parallel_degree = 1
+pipeline_parallel_schedule = "DualPipeV"
 expert_tensor_parallel_degree = 1
 
 [checkpoint]
@@ -63,7 +64,7 @@ export_dtype = "float32"
 async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
 
 [activation_checkpoint]
-mode = "selective" # ["none", "selective", "full"]
+mode = "none" # ["none", "selective", "full"]
 selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy
 
 [compile]
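
Note (rough sanity check, not stated anywhere in the commit): with NGPU=4 from run_train.sh, the new degrees are assumed to decompose the world size as sketched below, reading the -1 shard degree as "fill whatever is left".

world_size = 4                      # NGPU in run_train.sh
pipeline_parallel_degree = 2
tensor_parallel_degree = 1
context_parallel_degree = 1
data_parallel_replicate_degree = 1

# Assumed meaning of data_parallel_shard_degree = -1: the leftover factor.
dp_shard = world_size // (
    pipeline_parallel_degree
    * tensor_parallel_degree
    * context_parallel_degree
    * data_parallel_replicate_degree
)
print(dp_shard)  # 2; expert_parallel_degree = 2 is assumed to map onto these dp_shard ranks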
