Skip to content

Commit 480539e

Browse files
committed
Export MoE
1 parent d0170b1 commit 480539e

File tree

1 file changed

+89
-0
lines changed

1 file changed

+89
-0
lines changed
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import torch
2+
import torch.distributed as dist
3+
4+
from model import MoE
5+
from model_config import deepseek_config_registry
6+
7+
8+
def print0(*args, **kwargs):
9+
if dist.get_rank() == 0:
10+
print("\n")
11+
print(*args, **kwargs)
12+
13+
14+
def run(ep_size):
15+
mesh_shape = (ep_size,)
16+
mesh = dist.init_device_mesh("cuda", mesh_shape, mesh_dim_names=("ep",))
17+
group_size = mesh["ep"].size()
18+
rank = mesh["ep"].get_local_rank()
19+
20+
model_id = "deepseek-ai/DeepSeek-V2-Lite-Chat"
21+
config = deepseek_config_registry[model_id]
22+
config.ep_size = group_size
23+
24+
device = torch.device("cuda", rank % torch.cuda.device_count())
25+
dtype = torch.bfloat16
26+
torch.set_default_dtype(dtype)
27+
28+
# Initialize the model
29+
print0("Initializing MoE model...")
30+
with mesh, torch.device(device):
31+
moe = MoE(config)
32+
33+
print0("Setting up Symmetric Memory ...")
34+
moe.setup_symm_mem(torch.bfloat16, device)
35+
36+
seqlen = 256
37+
batch_size = 1
38+
x = torch.randn(
39+
batch_size, seqlen, config.hidden_size, dtype=torch.bfloat16, device=device
40+
)
41+
y = moe(x)
42+
# print(y.shape)
43+
44+
# Let's export the model
45+
print0("Exporting MoE model using torch.export...")
46+
47+
# Put model in eval mode for export
48+
moe.eval()
49+
50+
# Create example input for export
51+
example_input = (
52+
torch.randn(
53+
batch_size, seqlen, config.hidden_size, dtype=torch.bfloat16, device=device
54+
),
55+
)
56+
57+
# Export using torch.export.export
58+
exported_model = torch.export.export(moe, example_input)
59+
print0("Successfully exported the MoE model using torch.export")
60+
61+
# Save the exported model
62+
# export_path = "exported_moe_model.pt"
63+
# torch.export.save(exported_model, export_path)
64+
# print(f"Exported model saved to: {export_path}")
65+
66+
# Test the exported model
67+
print0("Testing exported model...")
68+
with torch.no_grad():
69+
original_output = moe(*example_input)
70+
exported_output = exported_model.module()(*example_input)
71+
72+
# Check if outputs are close
73+
if torch.allclose(original_output, exported_output, rtol=1e-3, atol=1e-3):
74+
print0("✓ Exported model outputs match original model outputs")
75+
else:
76+
print0("⚠ Exported model outputs differ from original model")
77+
print0(
78+
f"Max difference: {torch.max(torch.abs(original_output - exported_output))}"
79+
)
80+
81+
print0("Model export completed!\n")
82+
83+
if rank == 0:
84+
exported_model.graph_module.print_readable()
85+
86+
87+
if __name__ == "__main__":
88+
run(4)
89+
torch.distributed.destroy_process_group()

0 commit comments

Comments
 (0)