Commit 89dcb96

Export MoE
1 parent d0170b1 commit 89dcb96

File tree

1 file changed: +105 -0 lines changed
Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
# torchrun --nproc-per-node 4 --standalone tracing.py

import torch
import torch.distributed as dist

from model import MoE
from model_config import deepseek_config_registry


def print0(*args, **kwargs):
    if dist.get_rank() == 0:
        print("\n")
        print(*args, **kwargs)

def setup_mesh():
    ep_size = dist.get_world_size()
    mesh_shape = (ep_size,)
    mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("ep",))
    return mesh

def setup_model(mesh):
    group_size = mesh["ep"].size()
    rank = mesh["ep"].get_local_rank()

    model_id = "deepseek-ai/DeepSeek-V2-Lite-Chat"
    config = deepseek_config_registry[model_id]
    config.ep_size = group_size

    device = torch.device("cuda", rank % torch.cuda.device_count())
    dtype = torch.bfloat16
    torch.set_default_dtype(dtype)

    # Initialize the model
    print0("Initializing MoE model...")
    with mesh, torch.device(device):
        moe = MoE(config)

    print0("Setting up Symmetric Memory ...")
    moe.setup_symm_mem(torch.bfloat16, device)

    return moe

def test_export(moe, mesh):
    seqlen = 256
    batch_size = 1
    config = moe.config

    rank = mesh["ep"].get_local_rank()
    device = torch.device("cuda", rank % torch.cuda.device_count())

    x = torch.randn(
        batch_size, seqlen, config.hidden_size, dtype=torch.bfloat16, device=device
    )
    y = moe(x)
    # print(y.shape)

    # Let's export the model
    print0("Exporting MoE model using torch.export...")

    # Put model in eval mode for export
    moe.eval()

    # Create example input for export
    example_input = (
        torch.randn(
            batch_size, seqlen, config.hidden_size, dtype=torch.bfloat16, device=device
        ),
    )

    # Export using torch.export.export
    exported_model = torch.export.export(moe, example_input)
    print0("Successfully exported the MoE model using torch.export")

    # Save the exported model
    # export_path = "exported_moe_model.pt"
    # torch.export.save(exported_model, export_path)
    # print(f"Exported model saved to: {export_path}")

    # Test the exported model
    print0("Testing exported model...")
    with torch.no_grad():
        original_output = moe(*example_input)
        exported_output = exported_model.module()(*example_input)

    # Check if outputs are close
    if torch.allclose(original_output, exported_output, rtol=1e-3, atol=1e-3):
        print0("✓ Exported model outputs match original model outputs")
    else:
        print0("⚠ Exported model outputs differ from original model")
        print0(
            f"Max difference: {torch.max(torch.abs(original_output - exported_output))}"
        )

    print0("Model export completed!\n")

    if rank == 0:
        exported_model.graph_module.print_readable()


if __name__ == "__main__":
    torch.distributed.init_process_group(backend="nccl")
    mesh = setup_mesh()
    moe = setup_model(mesh)
    test_export(moe, mesh)
    torch.distributed.destroy_process_group()
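Note: the script leaves the torch.export.save path commented out. For reference, a minimal single-process sketch of the save/load round trip that those commented lines point at; the TinyFFN toy module and the file name here are illustrative placeholders, not part of the commit:

import torch
import torch.nn as nn


class TinyFFN(nn.Module):
    def __init__(self, hidden=16):
        super().__init__()
        self.up = nn.Linear(hidden, 4 * hidden)
        self.down = nn.Linear(4 * hidden, hidden)

    def forward(self, x):
        return self.down(torch.relu(self.up(x)))


model = TinyFFN().eval()
example_input = (torch.randn(1, 8, 16),)

# Export, save to disk, reload, then check the round trip preserves outputs.
ep = torch.export.export(model, example_input)
torch.export.save(ep, "tiny_ffn.pt2")
reloaded = torch.export.load("tiny_ffn.pt2")
assert torch.allclose(ep.module()(*example_input), reloaded.module()(*example_input))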
