Commit f47d26d

Annotate EP with dispatch/compute/combine
ghstack-source-id: faa8a54
Pull Request resolved: #1907
1 parent c5aa247

2 files changed (+3, -0 lines)

torchtitan/distributed/expert_parallel.py

Lines changed: 2 additions & 0 deletions
@@ -73,6 +73,7 @@ def __init__(self):
         self.permuted_indices = None
 
     # performing all-to-all dispatch on the input
+    @torch.fx.traceback.annotate_fn({"EP": "dispatch"})
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
@@ -145,6 +146,7 @@ def _partition_fn(name, mod, device_mesh):
         mod.register_parameter(name, dist_param)
 
     # performing all-to-all combine on the output
+    @torch.fx.traceback.annotate_fn({"EP": "combine"})
     def _token_combine(self, mod, routed_output, device_mesh):
         routed_output = _unpermute(
             routed_output, self.input_shape, self.permuted_indices

torchtitan/models/moe/moe.py

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ def __init__(
         self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
         self.use_grouped_mm = use_grouped_mm
 
+    @torch.fx.traceback.annotate_fn({"EP": "compute"})
     def forward(
         self,
         x: torch.Tensor,
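
For context, a minimal sketch of the annotation pattern this commit applies, assuming a PyTorch build where torch.fx.traceback.annotate_fn is available as a decorator (the standalone functions below are illustrative stand-ins, not torchtitan code):

import torch.fx.traceback as fx_traceback

# Illustrative stand-ins for _token_dispatch, GroupedExperts.forward, and
# _token_combine. Each region is tagged with an {"EP": ...} label so that,
# once the model is traced or compiled, graph passes and debugging tools can
# tell the dispatch, compute, and combine phases of Expert Parallel apart.

@fx_traceback.annotate_fn({"EP": "dispatch"})
def dispatch_tokens(x):
    return x  # placeholder for the all-to-all dispatch

@fx_traceback.annotate_fn({"EP": "compute"})
def expert_compute(x):
    return x * 2  # placeholder for the grouped expert computation

@fx_traceback.annotate_fn({"EP": "combine"})
def combine_tokens(x):
    return x  # placeholder for the all-to-all combine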
