Skip to content

Commit b7aebb5

Browse files
authored
Update tensor_parallel_example.py (#1324)
1 parent 2c435c7 commit b7aebb5

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

distributed/tensor_parallelism/tensor_parallel_example.py

+6-3
Original file line number | Diff line number | Diff line change
@@ -91,9 +91,6 @@ def forward(self, x):
9191
# create model and move it to GPU - init_device_mesh has already mapped GPU ids.
9292
tp_model = ToyModel().to("cuda")
9393

94-
# Create an optimizer for the parallelized module.
95-
lr = 0.25
96-
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
9794

9895
# Custom parallelization plan for the model
9996
tp_model = parallelize_module(
@@ -104,6 +101,12 @@ def forward(self, x):
104101
"out_proj": RowwiseParallel(),
105102
},
106103
)
104+
105+
# Create an optimizer for the parallelized module.
106+
lr = 0.25
107+
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
108+
109+
107110
# Perform a num of iterations of forward/backward
108111
# and optimizations for the sharded module.
109112
num_iters = 10

0 commit comments

Comments
 (0)